17 İşlemeler f6ce01b3ea ... af24764e67

Yazar SHA1 Mesaj Tarih
  phillip.rothenbeck af24764e67 add png, pdf, and gif 1 yıl önce
  phillip.rothenbeck 3b05d7d641 add generalized data transformation algorithm 1 yıl önce
  phillip.rothenbeck 7555f1b41e clean up 1 yıl önce
  phillip.rothenbeck 5f34dd8418 clean up reduced 1 yıl önce
  phillip.rothenbeck 0a7b829650 add paper layout + scatter function 1 yıl önce
  phillip.rothenbeck c38d74dc4c add scaling, norm, optimizer and scheduler choosing choosing 1 yıl önce
  phillip.rothenbeck b724dac3ad add norms and scaling 1 yıl önce
  phillip.rothenbeck 3cd104c9ae training pipelines 1 yıl önce
  phillip.rothenbeck eefda9b551 plot skripts for thesis 1 yıl önce
  phillip.rothenbeck c005bdab3e seperatly train R_t for Germany 1 yıl önce
  phillip.rothenbeck 483b1f70b3 preprocess all data in one notebook 1 yıl önce
  phillip.rothenbeck 53d929930a add data 1 yıl önce
  phillip.rothenbeck e37bc59bbb dont push training results 1 yıl önce
  phillip.rothenbeck 56254e4fea delete uni-purpose graph file 1 yıl önce
  phillip.rothenbeck c6a9daaffa new dir structure in vis dir 1 yıl önce
  phillip.rothenbeck e96794c310 implement model 1 yıl önce
  phillip.rothenbeck 8152e69c62 create I dataset 1 yıl önce
100 değiştirilmiş dosya ile 1987 ekleme ve 2053 silme
  1. 11 1
      .gitignore
  2. 32 6
      data.ipynb
  3. 0 0
      datasets/I_data.csv
  4. 0 54
      generate_presi_graphs.py
  5. 52 0
      germany_training.py
  6. 109 0
      plot_datasets.py
  7. 302 0
      plot_results.py
  8. 21 0
      plot_training_metrics.py
  9. 55 8
      src/dataset.py
  10. 131 31
      src/dinn.py
  11. 224 17
      src/plotter.py
  12. 57 43
      src/preprocessing/synthetic_data.py
  13. 88 138
      src/preprocessing/transform_data.py
  14. 9 15
      src/problem.py
  15. 52 1576
      state_data_dinn_sir.ipynb
  16. 118 0
      states_training.py
  17. 254 74
      synth_dinn_reduced_sir.ipynb
  18. 472 90
      synth_dinn_sir.ipynb
  19. BIN
      visualizations/Baden_Wuerttemberg.png
  20. BIN
      visualizations/Bayern.png
  21. BIN
      visualizations/Berlin.png
  22. BIN
      visualizations/Brandenburg.png
  23. BIN
      visualizations/Bremen.png
  24. BIN
      visualizations/Hamburg.png
  25. BIN
      visualizations/Hessen.png
  26. BIN
      visualizations/Mecklenburg_Vorpommern.png
  27. BIN
      visualizations/Niedersachsen.png
  28. BIN
      visualizations/Nordrhein_Westfalen.png
  29. BIN
      visualizations/RKI_SIR_1.png
  30. BIN
      visualizations/RKI_SIR_10.png
  31. BIN
      visualizations/RKI_SIR_3.png
  32. BIN
      visualizations/RKI_SIR_5.png
  33. BIN
      visualizations/Rheinland_Pfalz.png
  34. BIN
      visualizations/SIRD_synth.png
  35. BIN
      visualizations/SIR_RKI_3_alpha.png
  36. BIN
      visualizations/SIR_RKI_3_animation.gif
  37. BIN
      visualizations/SIR_RKI_3_beta.png
  38. BIN
      visualizations/SIR_RKI_3_loss.png
  39. BIN
      visualizations/SIR_RKI_5_alpha.png
  40. BIN
      visualizations/SIR_RKI_5_animation.gif
  41. BIN
      visualizations/SIR_RKI_5_beta.png
  42. BIN
      visualizations/SIR_RKI_5_loss.png
  43. BIN
      visualizations/SI_synth.png
  44. BIN
      visualizations/Saarland.png
  45. BIN
      visualizations/Sachsen.png
  46. BIN
      visualizations/Schleswig_Holstein.png
  47. BIN
      visualizations/Thueringen.png
  48. BIN
      visualizations/animations/Baden_Wuerttemberg_animation.gif
  49. 0 0
      visualizations/animations/Baden_Wuerttemberg_synth_sir_animation.gif
  50. BIN
      visualizations/animations/Bayern_animation.gif
  51. 0 0
      visualizations/animations/Bayern_synth_sir_animation.gif
  52. BIN
      visualizations/animations/Berlin_animation.gif
  53. 0 0
      visualizations/animations/Berlin_synth_sir_animation.gif
  54. BIN
      visualizations/animations/Brandenburg_animation.gif
  55. 0 0
      visualizations/animations/Brandenburg_synth_sir_animation.gif
  56. BIN
      visualizations/animations/Bremen_animation.gif
  57. 0 0
      visualizations/animations/Bremen_synth_sir_animation.gif
  58. BIN
      visualizations/animations/Germany_animation.gif
  59. BIN
      visualizations/animations/Hamburg_animation.gif
  60. 0 0
      visualizations/animations/Hamburg_synth_sir_animation.gif
  61. BIN
      visualizations/animations/Hessen_animation.gif
  62. 0 0
      visualizations/animations/Hessen_synth_sir_animation.gif
  63. BIN
      visualizations/animations/Mecklenburg_Vorpommern_animation.gif
  64. 0 0
      visualizations/animations/Mecklenburg_Vorpommern_synth_sir_animation.gif
  65. BIN
      visualizations/animations/Niedersachsen_animation.gif
  66. 0 0
      visualizations/animations/Niedersachsen_synth_sir_animation.gif
  67. BIN
      visualizations/animations/Nordrhein_Westfalen_animation.gif
  68. 0 0
      visualizations/animations/Nordrhein_Westfalen_synth_sir_animation.gif
  69. BIN
      visualizations/animations/Rheinland_Pfalz_animation.gif
  70. 0 0
      visualizations/animations/Rheinland_Pfalz_synth_sir_animation.gif
  71. 0 0
      visualizations/animations/SIR_RKI_1_animation.gif
  72. BIN
      visualizations/animations/Saarland_animation.gif
  73. 0 0
      visualizations/animations/Saarland_synth_sir_animation.gif
  74. 0 0
      visualizations/animations/Sachsen_Anhalt_synth_sir_animation.gif
  75. BIN
      visualizations/animations/Sachsen_animation.gif
  76. 0 0
      visualizations/animations/Sachsen_synth_sir_animation.gif
  77. BIN
      visualizations/animations/Schleswig_Holstein_animation.gif
  78. 0 0
      visualizations/animations/Schleswig_Holstein_synth_sir_animation.gif
  79. 0 0
      visualizations/animations/Thueringen_synth_sir_animation.gif
  80. BIN
      visualizations/animations/synth_sir_animation.gif
  81. BIN
      visualizations/base_params_synth.png
  82. BIN
      visualizations/high_alpha_synth.png
  83. BIN
      visualizations/high_beta_synth.png
  84. BIN
      visualizations/low_alpha_synth.png
  85. BIN
      visualizations/low_beta_synth.png
  86. BIN
      visualizations/png_img/Baden_Wuerttemberg.png
  87. BIN
      visualizations/png_img/Baden_Wuerttemberg_loss.png
  88. BIN
      visualizations/png_img/Bayern.png
  89. BIN
      visualizations/png_img/Bayern_loss.png
  90. BIN
      visualizations/png_img/Berlin.png
  91. BIN
      visualizations/png_img/Berlin_loss.png
  92. BIN
      visualizations/png_img/Brandenburg.png
  93. BIN
      visualizations/png_img/Brandenburg_loss.png
  94. BIN
      visualizations/png_img/Bremen.png
  95. BIN
      visualizations/png_img/Bremen_loss.png
  96. BIN
      visualizations/png_img/Germany_loss.png
  97. BIN
      visualizations/png_img/Hamburg.png
  98. BIN
      visualizations/png_img/Hamburg_loss.png
  99. BIN
      visualizations/png_img/Hessen.png
  100. BIN
      visualizations/png_img/Hessen_loss.png

+ 11 - 1
.gitignore

@@ -1,5 +1,15 @@
-# push no pycache
+# push no pycache and config files
 **__pycache__**
 **__pycache__**
+*.json
+
+# push no batch skripts
+*.sh
+
+# push no training result data
+*.csv
+*.pdf
+*.out
+*.gif
 
 
 # push no RKI data
 # push no RKI data
 *_RKI_*.csv
 *_RKI_*.csv

Dosya farkı çok büyük olduğundan ihmal edildi
+ 32 - 6
data.ipynb


Dosya farkı çok büyük olduğundan ihmal edildi
+ 0 - 0
datasets/I_data.csv


+ 0 - 54
generate_presi_graphs.py

@@ -1,54 +0,0 @@
-import pandas as pd
-import matplotlib.pyplot as plt
-import matplotlib.dates as mdates
-from matplotlib import rcParams
-
-FONT_COLOR = '#595959'
-SUSCEPTIBLE = '#6399f7'
-INFECTIOUS = '#f56262'
-REMOVED = '#83eb5e'
-
-# rki data
-
-rki_data_path = 'datasets/COVID-19-Todesfaelle_in_Deutschland/COVID-19-Todesfaelle_Deutschland.csv'
-
-rki_data = pd.read_csv(rki_data_path)
-rki_data['Berichtsdatum'] = pd.to_datetime(rki_data['Berichtsdatum'], errors='coerce')
-specific_dates = rki_data[rki_data['Berichtsdatum'].dt.is_quarter_start]['Berichtsdatum']
-
-rcParams['font.family'] = 'Comfortaa'
-rcParams['font.size'] = 12
-
-rcParams['text.color'] = FONT_COLOR
-rcParams['axes.labelcolor'] = FONT_COLOR
-rcParams['xtick.color'] = FONT_COLOR
-rcParams['ytick.color'] = FONT_COLOR
-
-slide3 = plt.figure(figsize=(12,6))
-ax = slide3.add_subplot(111, facecolor='#dddddd', axisbelow=True)
-ax.set_facecolor('xkcd:white')
-
-ax.plot(rki_data['Berichtsdatum'], rki_data['Faelle_gesamt'], label='infections', c=INFECTIOUS, lw=3)
-ax.plot(rki_data['Berichtsdatum'], rki_data['Todesfaelle_gesamt'], label='death cases', c=REMOVED, lw=3)
-
-plt.yscale('log')
-plt.ylabel('amount of poeple')
-plt.xlabel('time')
-plt.title('Accumulated cases (RKI Data)')
-ax.yaxis.set_tick_params(length=0)
-
-leg = plt.legend()
-
-ax.yaxis.set_tick_params(length=0, which='both')
-ax.xaxis.set_tick_params(length=0, which='both')
-ax.grid(which='major', c='black', lw=0.2, ls='-')
-
-for spine in ('top', 'right', 'bottom', 'left'):
-    ax.spines[spine].set_visible(False)
-
-plt.gca().set_xticks(specific_dates)
-plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
-
-plt.gcf().autofmt_xdate(rotation=45, ha='center')
-
-slide3.savefig('visualizations/slide3.png', transparent=True)

+ 52 - 0
germany_training.py

@@ -0,0 +1,52 @@
+import torch
+import numpy as np
+import csv
+
+from src.dataset import PandemicDataset, Norms
+from src.problem import ReducedSIRProblem
+from src.dinn import DINN, Scheduler, Activation
+
+ALPHA = [1/14, 1/5]
+NORM = [Norms.POPULATION, Norms.CONSTANT]
+
+ITERATIONS = 10
+
+for iteration in range(ITERATIONS):
+    for i, alpha in enumerate(ALPHA):
+        print(f'training for Germany, alpha: {alpha}, iter: {iteration}')
+
+        covid_data = np.genfromtxt(f'./datasets/I_RKI_Germany_1_{int(1/alpha)}.csv', delimiter=',')
+        dataset = PandemicDataset('Germany', 
+                                  ['I'], 
+                                  83100000, 
+                                  *covid_data, 
+                                  norm_name=NORM[i], 
+                                  C=10**6, 
+                                  use_scaled_time=True)
+        problem = ReducedSIRProblem(dataset, alpha)
+
+        dinn = DINN(2, 
+                    dataset, 
+                    [], 
+                    problem, 
+                    None, 
+                    state_variables=['R_t'], 
+                    hidden_size=100, 
+                    hidden_layers=4, 
+                    activation_layer=torch.nn.Tanh(),
+                    activation_output=Activation.POWER)
+
+        dinn.configure_training(1e-3, 
+                                25000, 
+                                scheduler_class=Scheduler.POLYNOMIAL, 
+                                lambda_obs=1e4,
+                                lambda_physics=1e-6, 
+                                verbose=True)
+        dinn.train(verbose=True, do_split_training=True, start_split=15000)
+
+        dinn.save_training_process(f'Germany_{i}_{iteration}')
+
+        r_t = dinn.get_output(1).detach().cpu().numpy()
+        with open(f'./results/Germany_{i}_{iteration}.csv', 'w', newline='') as csvfile:
+            writer = csv.writer(csvfile, delimiter=',')
+            writer.writerow(r_t)

+ 109 - 0
plot_datasets.py

@@ -0,0 +1,109 @@
+import numpy as np
+import pandas as pd
+from src.plotter import Plotter
+
+DS_DIR = './datasets/'
+SRC_DIR = './results/'
+SIM_DIR = './visualizations/'
+
+STATE_LOOKUP = {'Schleswig_Holstein' : 'Schleswig-Holstein',
+                'Hamburg' : 'Hamburg', 
+                'Niedersachsen' : 'Niedersachsen', 
+                'Bremen' : 'Bremen',
+                'Nordrhein_Westfalen' : 'North Rhine-Westphalia',
+                'Hessen' : 'Hessen',
+                'Rheinland_Pfalz' : 'Rhineland-Palatinate',
+                'Baden_Wuerttemberg' : 'Baden-Württemberg',
+                'Bayern' : 'Bavaria',
+                'Saarland' : 'Saarland',
+                'Berlin' : 'Berlin',
+                'Brandenburg' : 'Brandenburg',
+                'Mecklenburg_Vorpommern' : 'Mecklenburg-Western Pomerania',
+                'Sachsen' : 'Saxony',
+                'Sachsen_Anhalt' : 'Saxony-Anhalt',
+                'Thueringen' : 'Thuringia'}
+
+plotter = Plotter()
+
+data = []
+in_text_data = [np.genfromtxt(DS_DIR + f'SIR_data.csv', delimiter=',')[1:], 1]
+for state in STATE_LOOKUP.keys():
+    # print(f"plot {state}")
+    state_data_5 = np.genfromtxt(DS_DIR + f'I_RKI_{state}_1_5.csv', delimiter=',')[1]
+    state_data_14 = np.genfromtxt(DS_DIR + f'I_RKI_{state}_1_14.csv', delimiter=',')[1]
+    sir_data = np.genfromtxt(DS_DIR + f'SIR_RKI_{state}_1_14.csv', delimiter=',')[1:]
+    data.append(sir_data)
+    t = np.arange(0, 1200, 1)
+
+    if state in ["Schleswig_Holstein", "Berlin", "Thueringen"]:
+        in_text_data.append(sir_data)
+
+    plotter.plot(t, 
+                 [state_data_14, state_data_5], 
+                 [r'$\alpha=\frac{1}{14}$', r'$\alpha=\frac{1}{5}$'],
+                 f'{state}_datasets',
+                 f'{STATE_LOOKUP[state]}',
+                 (12,6),
+                 xlabel='time / days',
+                 ylabel='amount of people')
+    
+do_log=False
+plotter.cluster_plot(t,
+                     data[:6],
+                     [r'$S$', r'$I$', r'$R$'],
+                     (2, 3),
+                     (6,6),
+                     "state_sir_cluster_1",
+                     list(STATE_LOOKUP.values())[:6],
+                     xlabel='time / days',
+                     ylabel='amount of people',
+                     y_log_scale=do_log,
+                     add_y_space=0.05,
+                     number_of_legend_columns=3,
+                     same_axes=False,
+                     ylim=(0, 1.85e7))
+plotter.cluster_plot(t,
+                     data[7:],
+                     [r'$S$', r'$I$', r'$R$'],
+                     (3, 3),
+                     (6,6),
+                     "state_sir_cluster_2",
+                     list(STATE_LOOKUP.values())[7:],
+                     xlabel='time / days',
+                     ylabel='amount of people',
+                     y_log_scale=do_log,
+                     add_y_space=0.03,
+                     number_of_legend_columns=3,
+                     same_axes=False,
+                     ylim=(0, 1.85e7))
+
+germany_data = np.genfromtxt(DS_DIR + f'SIR_RKI_Germany_1_14.csv', delimiter=',')[1:]
+in_text_data[1] = germany_data
+
+plotter.plot(t,
+             germany_data,
+             [r'$S$', r'$I$', r'$R$'],
+             'germany_single_sir',
+             'Germany',
+             (6,6),
+             plot_legend=False,
+             xlabel='time / days',
+             ylabel='amount of people',)
+
+plotter.cluster_plot(t, 
+                     in_text_data,
+                     [r'$S$', r'$I$', r'$R$'],
+                     (2, 3),
+                     (6, 6),
+                     "in_text_SIR",
+                     ["synthetic SIR data", "Germany", 'Schleswig Holstein', 'Berlin', 'Thuringia'],
+                     xlabel='time / days',
+                     ylabel='amount of people',
+                     legend_loc=(0.51, 0.8),
+                     add_y_space=0,
+                     same_axes=False,
+                     free_axis=(0, 1),
+                     plot_all_labels=False)
+
+
+

+ 302 - 0
plot_results.py

@@ -0,0 +1,302 @@
+import numpy as np
+import pandas as pd
+from src.plotter import Plotter
+
+SRC_DIR = './results/'
+I_PRED_SRC_DIR = SRC_DIR + 'I_predictions/'
+SIM_DIR = './visualizations/'
+
+def get_error(y, y_ref):
+    err = []
+    for i in range(len(y)):
+        diff = y[i] - y_ref
+        err.append(np.linalg.norm(diff) / np.linalg.norm(y_ref))
+    return np.array(err).mean(axis=0)
+
+STATE_LOOKUP = {'Schleswig_Holstein' : (79.5,0.0849),
+                'Hamburg' : (84.5, 0.0948),
+                'Niedersachsen' : (77.6, 0.0774),
+                'Bremen' : (88.3,0.0933),
+                'Nordrhein_Westfalen' : (79.5,0.0777),
+                'Hessen' : (75.8,0.1017),
+                'Rheinland_Pfalz' : (75.6,0.0895),
+                'Baden_Wuerttemberg' : (74.5,0.0796),
+                'Bayern' : (75.1,0.0952),
+                'Saarland' : (82.4,0.1080),
+                'Berlin' : (78.1,0.0667),
+                'Brandenburg' : (68.1,0.0724),
+                'Mecklenburg_Vorpommern' : (74.7,0.0540),
+                'Sachsen' : (65.1,0.1109),
+                'Sachsen_Anhalt' : (74.1,0.0785),
+                'Thueringen' : (70.3,0.0837),
+                'Germany' : (76.4, 0.0804)}
+
+state_names = ['Schleswig-Holstein',
+                'Hamburg', 
+                'Lower Saxony', 
+                'Bremen',
+                'North Rhine-Westphalia',
+                'Hesse',
+                'Rhineland-Palatinate',
+                'Baden-Württemberg',
+                'Bavaria',
+                'Saarland',
+                'Berlin',
+                'Brandenburg',
+                'Mecklenburg-Western Pomerania',
+                'Saxony',
+                'Saxony-Anhalt',
+                'Thuringia', 
+                'Germany']
+
+plotter = Plotter(additional_colors=['yellow', 'cyan', 'magenta', ])
+
+# plot results for alpha and beta
+
+print("Visualizing Alpha and Beta results")
+
+# synth
+param_matrix = np.genfromtxt(SRC_DIR + f'synthetic_parameters.csv', delimiter=',')
+mean = param_matrix.mean(axis=0)
+std = param_matrix.std(axis=0)
+
+print("States Table form:")
+print('{0:.4f}'.format(1/3), "&", '{0:.4f}'.format(mean[0]), "&", '{0:.4f}'.format(std[0]), "&", '{0:.4f}'.format(1/2), "&", '{0:.4f}'.format(mean[1]), "&", '{0:.4f}'.format(std[1]), "\\\ ")
+
+plotter.scatter(np.arange(1, 6, 1), [param_matrix[:,0], param_matrix[:,1]], [r"$\alpha$", r"$\beta$"], (7,3.5), 'reproducability', '', true_values=[1/3, 1/2], xlabel='iteration')
+
+vaccination_ratios = []
+mean_std_parameters = {}
+for state in STATE_LOOKUP.keys():
+    state_matrix = np.genfromtxt(SRC_DIR + f'{state}_parameters.csv', delimiter=',')
+    mean = state_matrix.mean(axis=0)
+    std = state_matrix.std(axis=0)
+    mean_std_parameters.update({state : (mean, std)})
+    vaccination_ratios.append(STATE_LOOKUP[state][0])
+
+values = np.array(list(mean_std_parameters.values()))
+means = values[:,0]
+stds = values[:,1]
+alpha_means = means[:,0]
+beta_means = means[:,1]
+alpha_stds = stds[:,0]
+beta_stds = stds[:,1]
+
+print(f"Vaccination corr: {np.corrcoef(beta_means, vaccination_ratios)[0, 1]}")
+vaccination_ratios = vaccination_ratios[:-1]
+sn = np.array(state_names[:-1]).copy()
+sn[12] = "MWP"
+plotter.scatter(sn, 
+                [alpha_means[:-1], beta_means[:-1]], 
+                [r'$\alpha$', r'$\beta$', ],
+                (12, 6), 
+                'mean_std_alpha_beta_res',
+                '',
+                std=[alpha_stds[:-1], beta_stds[:-1]],
+                true_values=[alpha_means[-1], beta_means[-1]],
+                true_label='Germany',
+                xlabel_rotation=60,
+                plot_legend=True,
+                legend_loc="lower right")
+
+print("States Table form:")
+for i, state in enumerate(STATE_LOOKUP.keys()):
+    print(state_names[i], "&", '{0:.3f}'.format(alpha_means[i]), "{\\tiny $\\pm",'{0:.3f}'.format(alpha_stds[i]), "$}", "&", '{0:.3f}'.format(beta_means[i]), "{\\tiny $\\pm", '{0:.3f}'.format(beta_stds[i]), "$}", "&", STATE_LOOKUP[state][1], "&", '{0:.3f}'.format(beta_means[i]-beta_means[16]), "&", '{0:.1f}'.format(STATE_LOOKUP[state][0]), "\\\ ")
+
+print()
+# plot results for reproduction number
+
+# synth
+synth_iterations = []
+for i in range(10):   
+    synth_iterations.append(np.genfromtxt(SRC_DIR + f'synthetic_{i}.csv', delimiter=','))
+
+synth_matrix = np.array(synth_iterations)
+t = np.arange(0, len(synth_matrix[0]), 1)
+synth_r_t = np.zeros(150, dtype=np.float64)
+for i, time in enumerate(range(150)):
+    synth_r_t[i] = -np.tanh(time * 0.05 - 2) * 0.4 + 1.35
+
+print(f"Synthetic error R_t: {get_error(synth_matrix.mean(axis=0), synth_r_t)}")
+plotter.plot(t, 
+             [synth_matrix.mean(axis=0), synth_r_t],
+             [r'$\mathcal{R}_t$', r'true $\mathcal{R}_t$'], 
+             f"synthetic_R_t_statistics", 
+             r"Synthetic data $\mathcal{R}_t$", 
+             (9, 6),
+             fill_between=[synth_matrix.std(axis=0)],
+             xlabel="time / days")
+
+pred_synth = np.genfromtxt(I_PRED_SRC_DIR + f'synthetic_0_I_prediction.csv', delimiter=',')
+
+print(f"Synthetic error I: {get_error(pred_synth[2], pred_synth[1])}")
+
+plotter.plot(pred_synth[0], 
+             [pred_synth[2], pred_synth[1]],
+             [r'prediction $I$', r'true $I$'], 
+             f"synthetic_I_prediction", 
+             r"Synthetic data $I$ prediction", 
+             (9, 6),
+             xlabel="time / days",
+             ylabel='amount of people')
+
+EVENT_LOOKUP = {'start of vaccination' : 455,
+                'alpha variant' : 357,
+                'delta variant' : 473,
+                'omicron variant' : 663}
+
+
+ALPHA = [1 / 14, 1 / 5]
+
+
+cluster_counter = 1
+cluster_idx = 0
+
+in_text_r_t_mean = []
+in_text_r_t_std = []
+in_text_I = []
+in_text_I_std = []
+
+cluster_r_t_mean = []
+cluster_r_t_std = []
+cluster_I = []
+cluster_I_std = []
+cluster_states = []
+
+for k, state in enumerate(STATE_LOOKUP.keys()):
+
+    if state == "Thueringen":
+        l = 1
+    elif state == "Bremen":
+        l = 0
+    # data fetch arrays
+    r_t = []
+    pred_i = []
+    true_i = []
+    cluster_states.append(state_names[k])
+    cluster_r_t_mean.append([])
+    cluster_r_t_std.append([])
+    cluster_I.append([])
+    cluster_I_std.append([np.zeros(1200), np.zeros(1200)])
+    if state == "Thueringen" or state == "Bremen":
+        in_text_r_t_mean.append([])
+        in_text_r_t_std.append([])   
+        in_text_I.append([])
+        in_text_I_std.append([np.zeros(1200), np.zeros(1200)])
+    for i, alpha in enumerate(ALPHA):
+        iterations = []
+        predictions = []
+        true = []
+        for j in range(10): 
+            iterations.append(np.genfromtxt(SRC_DIR + f'{state}_{i}_{j}.csv', delimiter=','))
+            if (k >= 3 and j == 3) or j > 3:
+                data = np.genfromtxt(I_PRED_SRC_DIR + f'{state}_{i}_{j}_I_prediction.csv', delimiter=',')
+                predictions.append(data[2])
+                true = data[1]
+        iterations = np.array(iterations)
+        r_t.append(iterations)
+
+        predictions = np.array(predictions)
+        pred_i.append(predictions)
+        true_i.append(true)
+
+        cluster_r_t_mean[cluster_counter-1].append(iterations.mean(axis=0))
+        cluster_r_t_std[cluster_counter-1].append(iterations.std(axis=0))
+        if state == "Thueringen" or state == "Bremen":
+            in_text_r_t_mean[l].append(iterations.mean(axis=0))
+            in_text_r_t_std[l].append(iterations.std(axis=0))
+    if state == "Thueringen" or state == "Bremen":
+        in_text_I[l].append(true_i[0])
+        in_text_I[l].append(true_i[1])
+        in_text_I[l].append(pred_i[0].mean(axis=0))
+        in_text_I[l].append(pred_i[1].mean(axis=0))
+        in_text_I_std[l].append(pred_i[0].std(axis=0))
+        in_text_I_std[l].append(pred_i[1].std(axis=0))
+
+    cluster_I[cluster_counter-1].append(true_i[0])
+    cluster_I[cluster_counter-1].append(true_i[1])
+    cluster_I[cluster_counter-1].append(pred_i[0].mean(axis=0))
+    cluster_I[cluster_counter-1].append(pred_i[1].mean(axis=0))
+    cluster_I_std[cluster_counter-1].append(pred_i[0].std(axis=0))
+    cluster_I_std[cluster_counter-1].append(pred_i[1].std(axis=0))
+
+    # plot
+    print(f"{state_names[k]} & {'{0:.3f}'.format(get_error(pred_i[0], true_i[0]))} & {'{0:.3f}'.format(get_error(pred_i[1], true_i[1]))} & \phantom{{0}} & {(r_t[0] > 1).sum(axis=1).mean()} & {(r_t[1] > 1).sum(axis=1).mean()} & {'{0:.3f}'.format(r_t[0].max(axis=1).mean())} & {'{0:.3f}'.format(r_t[1].max(axis=1).mean())}\\\ ")
+    if len(cluster_states) == 4 and state != "Thueringen" or len(cluster_states) == 5 and state == "Germany":
+        t = np.arange(0, 1200, 1)
+        if len(cluster_states) == 5:
+            y_lim_exception = 4
+        else:
+            y_lim_exception = None
+        plotter.cluster_plot(t,
+                            cluster_r_t_mean,
+                            [r"$\alpha=\frac{1}{14}$", r"$\alpha=\frac{1}{5}$"],
+                            (len(cluster_states), 1),
+                            (9, 6),
+                            f'r_t_cluster_{cluster_idx}',
+                            [state + r"  $\mathcal{R}_t$" for state in cluster_states],
+                            fill_between=cluster_r_t_std,
+                            event_lookup=EVENT_LOOKUP,
+                            xlabel='time / days',
+                            ylim=(0.3, 2.0),
+                            legend_loc=(0.53, 0.992),
+                            number_of_legend_columns=3)
+        plotter.cluster_plot(t,
+                             cluster_I,
+                             [r"true $I$ $\alpha=\frac{1}{14}$", 
+                             r"true $I$ $\alpha=\frac{1}{5}$", 
+                             r"prediction $I$ $\alpha=\frac{1}{14}$", 
+                             r"prediction $I$ $\alpha=\frac{1}{5}$"],
+                             (len(cluster_states), 1),
+                             (9, 6),
+                             f'I_cluster_{cluster_idx}',
+                             [state + r" $I$ prediction" for state in cluster_states],
+                             fill_between=cluster_I_std,
+                             xlabel='time / days',
+                             ylabel='amount of people',
+                             same_axes=False,
+                             ylim=(0, 600000),
+                             legend_loc=(0.55, 0.992),
+                             number_of_legend_columns=2,
+                             y_lim_exception=y_lim_exception)
+        cluster_counter = 0
+        cluster_idx += 1
+        cluster_r_t_mean = []
+        cluster_r_t_std = []
+        cluster_I = []
+        cluster_I_std = []
+        cluster_states = []
+    cluster_counter += 1
+
+plotter.cluster_plot(t,
+                    in_text_r_t_mean,
+                    [r"$\alpha=\frac{1}{14}$", r"$\alpha=\frac{1}{5}$"],
+                    (2, 1),
+                    (9, 6),
+                    f'r_t_cluster_intext',
+                    [state + r"  $\mathcal{R}_t$" for state in ['Bremen', 'Thuringia']],
+                    fill_between=in_text_r_t_std,
+                    event_lookup=EVENT_LOOKUP,
+                    xlabel='time / days',
+                    ylim=(0.3, 2.0),
+                    legend_loc=(0.53, 0.999),
+                    add_y_space=0.08,
+                    number_of_legend_columns=3)
+plotter.cluster_plot(t,
+                     in_text_I,
+                     [r"true $I$ $\alpha=\frac{1}{14}$", 
+                     r"true $I$ $\alpha=\frac{1}{5}$", 
+                     r"prediction $I$ $\alpha=\frac{1}{14}$", 
+                     r"prediction $I$ $\alpha=\frac{1}{5}$"],
+                     (2, 1),
+                     (9, 6),
+                     f'I_cluster_intext',
+                     [state + r" $I$ prediction" for state in ['Bremen', 'Thuringia']],
+                     fill_between=in_text_I_std,
+                     xlabel='time / days',
+                     ylabel='amount of people',
+                     ylim=(0, 600000),
+                     legend_loc=(0.55, 0.999),
+                     add_y_space=0.08,
+                     number_of_legend_columns=2)
+

+ 21 - 0
plot_training_metrics.py

@@ -0,0 +1,21 @@
+import numpy as np
+from src.plotter import Plotter
+
+SRC_DIR = './results/'
+SIM_DIR = './visualizations/'
+MET_DIR = 'training_metrics/'
+
+plotter = Plotter()
+# plot synthetic loss
+physics_loss = np.genfromtxt(SRC_DIR + MET_DIR + f'Germany_1_0_physics_loss.csv', delimiter=',')
+obs_loss = np.genfromtxt(SRC_DIR + MET_DIR + f'Germany_1_0_obs_loss.csv', delimiter=',')
+loss = np.genfromtxt(SRC_DIR + MET_DIR + f'Germany_1_0_loss.csv', delimiter=',')
+
+t = np.arange(0, len(loss), 1)
+plotter.plot(t, 
+             [loss, physics_loss, obs_loss], 
+             ['loss', 'physics loss', 'data loss'], 
+             'Germany_1_loss',
+             'Loss', 
+             (6, 6),
+             y_log_scale=True)

+ 55 - 8
src/dataset.py

@@ -1,4 +1,10 @@
 import torch
 import torch
+from enum import Enum
+
+class Norms(Enum):
+    POPULATION=0
+    MIN_MAX=1
+    CONSTANT=2
 
 
 class PandemicDataset:
 class PandemicDataset:
     def __init__(self, 
     def __init__(self, 
@@ -6,7 +12,10 @@ class PandemicDataset:
                  group_names:list, 
                  group_names:list, 
                  N: int, 
                  N: int, 
                  t, 
                  t, 
-                 *groups):
+                 *groups, 
+                 norm_name=Norms.MIN_MAX,
+                 C = 10**5,
+                 use_scaled_time=False):
         """Class to hold all data for one training process.
         """Class to hold all data for one training process.
 
 
         Args:
         Args:
@@ -15,19 +24,39 @@ class PandemicDataset:
             t (np.array): Array of timesteps.
             t (np.array): Array of timesteps.
             *groups (np.array): Arrays of size data for each group for each timestep.
             *groups (np.array): Arrays of size data for each group for each timestep.
         """
         """
-
         if torch.cuda.is_available():
         if torch.cuda.is_available():
             self.device_name = 'cuda'
             self.device_name = 'cuda'
         else:
         else:
             self.device_name = 'cpu'
             self.device_name = 'cpu'
 
 
+        match norm_name:
+            case Norms.POPULATION:
+                self.__norm = self.__population_norm
+                self.__denorm = self.__population_denorm
+            case Norms.MIN_MAX:
+                self.__norm = self.__min_max_norm
+                self.__denorm = self.__min_max_denorm
+            case Norms.CONSTANT:
+                self.__norm = self.__constant_norm
+                self.__denorm = self.__constant_denorm
+            case _:
+                self.__norm = self.__min_max_norm
+                self.__denorm = self.__min_max_denorm
+
         self.name = name
         self.name = name
         self.N = N
         self.N = N
         self.t_init = t.min()
         self.t_init = t.min()
         self.t_final = t.max()
         self.t_final = t.max()
+        self.C = C
 
 
         self.t_raw = torch.tensor(t, requires_grad=True, device=self.device_name)
         self.t_raw = torch.tensor(t, requires_grad=True, device=self.device_name)
-        self.t_batch = self.t_raw.view(-1, 1).float()
+
+        self.t_scaled = ((self.t_raw - self.t_init) / (self.t_final - self.t_init)).detach().requires_grad_()
+        self.use_scaled_time = use_scaled_time
+        if use_scaled_time:
+            self.t_batch = self.t_scaled.view(-1, 1).float()
+        else:
+            self.t_batch = self.t_raw.view(-1, 1).float()
 
 
         self.__group_dict = {}
         self.__group_dict = {}
         for i, name in enumerate(group_names):
         for i, name in enumerate(group_names):
@@ -39,7 +68,7 @@ class PandemicDataset:
         
         
         self.__mins = [torch.min(group) for group in self.__groups]
         self.__mins = [torch.min(group) for group in self.__groups]
         self.__maxs = [torch.max(group) for group in self.__groups]
         self.__maxs = [torch.max(group) for group in self.__groups]
-        self.__norms = [(self.__groups[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(len(groups))]
+        self.__norms = self.__norm(self.__groups)
 
 
     @property
     @property
     def number_groups(self):
     def number_groups(self):
@@ -52,14 +81,32 @@ class PandemicDataset:
     @property
     @property
     def group_names(self):
     def group_names(self):
         return self.__group_names
         return self.__group_names
+    
+    def __population_norm(self, data):
+        return [(data[i] / self.N) for i in range(self.number_groups)]
+    
+    def __population_denorm(self, data):
+        return [(data[i] * self.N) for i in range(self.number_groups)]
+
+    def __min_max_norm(self, data):
+        return [(data[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(self.number_groups)]
+    
+    def __min_max_denorm(self, data):
+        return [(self.__mins[i] + (self.__maxs[i] - self.__mins[i]) * data[i]) for i in range(self.number_groups)]
+    
+    def __constant_norm(self, data):
+        return [(data[i] / self.C) for i in range(self.number_groups)]
+
+    def __constant_denorm(self, data):
+        return [(data[i] * self.C) for i in range(self.number_groups)]
 
 
     def get_normalized_data(self, data:list):
     def get_normalized_data(self, data:list):
         assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
         assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
-        return [(data[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(self.number_groups)]
+        return self.__norm(data)
     
     
-    def get_denormalized_data(self, normalized_data:list):
-        assert len(normalized_data) == self.number_groups, f'normalized_data parameter needs same length as there are groups in the dataset ({self.number_groups})'
-        return [(self.__mins[i] + (self.__maxs[i] - self.__mins[i]) * normalized_data[i]) for i in range(self.number_groups)]
+    def get_denormalized_data(self, data:list):
+        assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
+        return self.__denorm(data)
 
 
     def get_group(self, name:str):
     def get_group(self, name:str):
         return self.__groups[self.__group_dict[name]]
         return self.__groups[self.__group_dict[name]]

+ 131 - 31
src/dinn.py

@@ -1,10 +1,33 @@
 import torch
 import torch
+import csv
 import numpy as np
 import numpy as np
 
 
+from enum import Enum
+
 from .dataset import PandemicDataset
 from .dataset import PandemicDataset
 from .problem import PandemicProblem
 from .problem import PandemicProblem
 from .plotter import Plotter
 from .plotter import Plotter
 
 
+class Optimizer(Enum):
+    ADAM=0
+
+class Scheduler(Enum):
+    CYCLIC=0
+    CONSTANT=1
+    LINEAR=2
+    POLYNOMIAL=3
+
+class Activation(Enum):
+    LINEAR=0
+    POWER=1
+
+def linear(x):
+    return x
+        
+def power(x):
+    return torch.float_power(x, 2)
+
+
 class DINN:
 class DINN:
     class NN(torch.nn.Module):
     class NN(torch.nn.Module):
         def __init__(self, 
         def __init__(self, 
@@ -12,9 +35,12 @@ class DINN:
                      input_size: int,
                      input_size: int,
                      hidden_size: int,
                      hidden_size: int,
                      hidden_layers: int, 
                      hidden_layers: int, 
-                     activation_layer, 
+                     activation_layer,
                      t_init,
                      t_init,
-                     t_final) -> None:
+                     t_final,
+                     output_activation_function=Activation.LINEAR,
+                     use_glorot_initialization = False,
+                     use_t_scaled=True) -> None:
             """Neural Network
             """Neural Network
 
 
             Args:
             Args:
@@ -26,21 +52,39 @@ class DINN:
             """
             """
             super(DINN.NN, self).__init__()
             super(DINN.NN, self).__init__()
 
 
+            if output_activation_function == Activation.LINEAR:
+                self.out_activation = linear
+            elif output_activation_function == Activation.POWER:
+                self.out_activation = power
+            else:
+                print('Set output activation to default: linear')
+                self.out_activation = self.linear
+
             self.input = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size), activation_layer)
             self.input = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size), activation_layer)
             self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
             self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
             self.output = torch.nn.Linear(hidden_size, output_size)
             self.output = torch.nn.Linear(hidden_size, output_size)
+
+            if use_glorot_initialization:
+                torch.nn.init.xavier_uniform_(self.input[0].weight)
+                for i in range(hidden_layers):
+                    torch.nn.init.xavier_uniform_(self.hidden[i][0].weight)
+                torch.nn.init.xavier_uniform_(self.output.weight)
             
             
             self.__t_init = t_init
             self.__t_init = t_init
             self.__t_final = t_final
             self.__t_final = t_final
+            self.__use_t_scaled = use_t_scaled
 
 
         def forward(self, t):
         def forward(self, t):
             # normalize input
             # normalize input
-            t_scaled = (t - self.__t_init) / (self.__t_final - self.__t_init)
-            x = self.input(t_scaled)
+            if self.__use_t_scaled:
+                t_forward = (t - self.__t_init) / (self.__t_final - self.__t_init)
+            else:
+                t_forward = t
+            x = self.input(t_forward)
             x = self.hidden(x)
             x = self.hidden(x)
             x = self.output(x)
             x = self.output(x)
-            return x
-
+            return self.out_activation(x)
+    
     def __init__(self,
     def __init__(self,
                  output_size: int,
                  output_size: int,
                  data: PandemicDataset,
                  data: PandemicDataset,
@@ -52,7 +96,9 @@ class DINN:
                  input_size=1, 
                  input_size=1, 
                  hidden_size=20, 
                  hidden_size=20, 
                  hidden_layers=7, 
                  hidden_layers=7, 
-                 activation_layer=torch.nn.ReLU()) -> None:
+                 activation_layer=torch.nn.ReLU(),
+                 activation_output=Activation.LINEAR,
+                 use_glorot_initialization = False) -> None:
         """Desease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve Inverse Problems and find the 
         """Desease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve Inverse Problems and find the 
         parameters of a specific mathematical model.
         parameters of a specific mathematical model.
 
 
@@ -78,9 +124,12 @@ class DINN:
                              input_size, 
                              input_size, 
                              hidden_size, 
                              hidden_size, 
                              hidden_layers, 
                              hidden_layers, 
-                             activation_layer, 
+                             activation_layer,
                              data.t_init, 
                              data.t_init, 
-                             data.t_final)
+                             data.t_final,
+                             activation_output,
+                             use_glorot_initialization=use_glorot_initialization,
+                             use_t_scaled=data.use_scaled_time)
         self.model = self.model.to(self.device)
         self.model = self.model.to(self.device)
         self.data = data
         self.data = data
         self.parameter_regulator = parameter_regulator
         self.parameter_regulator = parameter_regulator
@@ -131,8 +180,21 @@ class DINN:
             list: list of regulated parameters
             list: list of regulated parameters
         """
         """
         return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]
         return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]
+
     
     
-    def configure_training(self, lr:float, epochs:int, optimizer_name='Adam', scheduler_name='CyclicLR', scheduler_factor = 1, verbose=False):
+    def get_output(self, index):
+        output = self.model(self.data.t_batch)
+        return output[:, index]
+    
+    def configure_training(self, 
+                           lr:float, 
+                           epochs:int, 
+                           optimizer_class=Optimizer.ADAM, 
+                           scheduler_class=Scheduler.CYCLIC, 
+                           scheduler_factor = 1, 
+                           lambda_obs = 1,
+                           lambda_physics = 1,
+                           verbose=False):
         """This method sets the optimizer, scheduler, learning rate and number of epochs for the following training process.
         """This method sets the optimizer, scheduler, learning rate and number of epochs for the following training process.
 
 
         Args:
         Args:
@@ -144,34 +206,38 @@ class DINN:
         """
         """
         parameter_list = list(self.model.parameters()) + list(self.parameters_tilda.values())
         parameter_list = list(self.model.parameters()) + list(self.parameters_tilda.values())
         self.epochs = epochs
         self.epochs = epochs
-        match optimizer_name:
-            case 'Adam':
+        self.lambda_obs = lambda_obs
+        self.lambda_physics = lambda_physics
+        match optimizer_class:
+            case Optimizer.ADAM:
                 self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
                 self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
             case _:
             case _:
                 self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
                 self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
                 if verbose:
                 if verbose:
                     print('---------------------------------')
                     print('---------------------------------')
-                    print(f' Entered unknown optimizer name: {optimizer_name}\n Defaulted to Adam.')
+                    print(f' Entered unknown optimizer name: {optimizer_class.name}\n Defaulted to ADAM.')
                     print('---------------------------------')
                     print('---------------------------------')
-                optimizer_name = 'Adam'
+                optimizer_class = Optimizer.ADAM
 
 
-        match scheduler_name:
-            case 'CyclicLR':
+        match scheduler_class:
+            case Scheduler.CYCLIC:
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
-            case 'LinearLR':
+            case Scheduler.CONSTANT:
+                self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1, total_iters=4)
+            case Scheduler.LINEAR:
                 self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs/scheduler_factor)
                 self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs/scheduler_factor)
-            case 'PolynomialLR':
+            case Scheduler.POLYNOMIAL:
                 self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs/scheduler_factor, power=1.0)
                 self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs/scheduler_factor, power=1.0)
             case _:
             case _:
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
                 if verbose:
                 if verbose:
                     print('---------------------------------')
                     print('---------------------------------')
-                    print(f' Entered unknown scheduler name: {scheduler_name}\n Defaulted to CyclicLR.')
+                    print(f' Entered unknown scheduler name: {scheduler_class.name}\n Defaulted to CYCLIC.')
                     print('---------------------------------')
                     print('---------------------------------')
-                scheduler_name = 'CyclicLR'
+                scheduler_class = Scheduler.CYCLIC
 
 
         if verbose:
         if verbose:
-            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_name}\nScheduler:\t{scheduler_name}\n')
+            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_class.name}\nScheduler:\t{scheduler_class.name}\n')
 
 
         self.__is_configured = True
         self.__is_configured = True
 
 
@@ -179,7 +245,9 @@ class DINN:
     def train(self, 
     def train(self, 
               create_animation=False,
               create_animation=False,
               animation_sample_rate=500,
               animation_sample_rate=500,
-              verbose=False):
+              verbose=False,
+              do_split_training=False,
+              start_split=10000):
         """Training routine for the DINN.
         """Training routine for the DINN.
 
 
         Args:
         Args:
@@ -201,20 +269,27 @@ class DINN:
             # get the prediction and the fitting residuals
             # get the prediction and the fitting residuals
             prediction = self.model(self.data.t_batch)
             prediction = self.model(self.data.t_batch)
             residuals = self.problem.residual(prediction, *self.get_regulated_param_list())
             residuals = self.problem.residual(prediction, *self.get_regulated_param_list())
-
             self.optimizer.zero_grad()
             self.optimizer.zero_grad()
 
 
             # calculate loss from the differential system
             # calculate loss from the differential system
             loss_physics = 0
             loss_physics = 0
             for residual in residuals:
             for residual in residuals:
                 loss_physics += torch.mean(torch.square(residual))
                 loss_physics += torch.mean(torch.square(residual))
+            loss_physics *= self.lambda_physics
 
 
             # calculate loss from the dataset
             # calculate loss from the dataset
             loss_obs = 0
             loss_obs = 0
             for i, group in enumerate(self.data.group_names):
             for i, group in enumerate(self.data.group_names):
                 loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
                 loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
+            loss_obs *= self.lambda_obs
             
             
-            loss = loss_obs + loss_physics
+            if do_split_training:
+                if epoch < start_split:
+                    loss = loss_obs
+                else:
+                    loss = loss_obs + loss_physics
+            else:
+                loss = loss_obs + loss_physics
 
 
             loss.backward()
             loss.backward()
             self.optimizer.step()
             self.optimizer.step()
@@ -289,7 +364,8 @@ class DINN:
                                    np.ones_like(epochs) * ground_truth[i]], 
                                    np.ones_like(epochs) * ground_truth[i]], 
                                    ['prediction', 'ground truth'], 
                                    ['prediction', 'ground truth'], 
                                    self.data.name + '_' + list(self.parameters_tilda.items())[i][0], 
                                    self.data.name + '_' + list(self.parameters_tilda.items())[i][0], 
-                                   list(self.parameters_tilda.items())[i][0], (6,6), 
+                                   list(self.parameters_tilda.items())[i][0], 
+                                   (6,6), 
                                    is_background=[0, 1], 
                                    is_background=[0, 1], 
                                    xlabel='epochs')
                                    xlabel='epochs')
             else:
             else:
@@ -300,18 +376,42 @@ class DINN:
                                   list(self.parameters_tilda.items())[i][0], (6,6), 
                                   list(self.parameters_tilda.items())[i][0], (6,6), 
                                   xlabel='epochs', 
                                   xlabel='epochs', 
                                   plot_legend=False)
                                   plot_legend=False)
+                
+    def save_training_process(self, title, save_predictions = True):
+        losses = {'loss' : self.losses,
+                  'obs_loss' : self.obs_losses,
+                  'physics_loss' : self.physics_losses}
+        for loss in losses.keys():
+            with open(f'./results/training_metrics/{title}_{loss}.csv', 'w', newline='') as csvfile:
+                writer = csv.writer(csvfile, delimiter=',')
+                writer.writerow(losses[loss])
+
+        for i, parameter in enumerate(self.parameters):
+            with open(f'./results/training_metrics/{title}_{list(self.parameters_tilda.items())[i][0]}.csv', 'w', newline='') as csvfile:
+                writer = csv.writer(csvfile, delimiter=',')
+                writer.writerow(parameter)
+        if save_predictions:
+            prediction = self.model(self.data.t_batch)
+            for i, group in enumerate(self.data.group_names):
+                t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0]).detach().cpu().numpy()
+                true = self.data.get_group(group).detach().cpu().numpy()
+                pred = self.data.get_denormalized_data([prediction[:, i]])[0].detach().cpu().numpy()
+                print(t.shape, true.shape)
+                with open(f'./results/I_predictions/{title}_I_prediction.csv', 'w', newline='') as csvfile:
+                    writer = csv.writer(csvfile, delimiter=',')
+                    writer.writerow(t)
+                    writer.writerow(true)
+                    writer.writerow(pred)
 
 
     def plot_state_variables(self):
     def plot_state_variables(self):
+        prediction = self.model(self.data.t_batch)
         for i in range(self.data.number_groups, self.data.number_groups+self.number_state_variables):
         for i in range(self.data.number_groups, self.data.number_groups+self.number_state_variables):
-            prediction = self.model(self.data.t_batch)
-            groups = [prediction[:, i] for i in range(self.data.number_groups)]
-            t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+            t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
             self.plotter.plot(t,
             self.plotter.plot(t,
-                              [prediction[:, i]] + groups,
-                              [self.__state_variables[i-self.data.number_groups]] + self.data.group_names,
+                              [prediction[:, i]],
+                              [self.__state_variables[i-self.data.number_groups]],
                               f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
                               f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
                               self.__state_variables[i-self.data.number_groups],
                               self.__state_variables[i-self.data.number_groups],
-                              is_background=[0, 1, 1],
                               figure_shape=(12, 6),
                               figure_shape=(12, 6),
                               plot_legend=True,
                               plot_legend=True,
                               xlabel='time / days')
                               xlabel='time / days')

+ 224 - 17
src/plotter.py

@@ -1,9 +1,12 @@
 import os
 import os
 import torch
 import torch
 import imageio
 import imageio
+import numpy as np
 import matplotlib.pyplot as plt
 import matplotlib.pyplot as plt
+import matplotlib.ticker as ticker 
 
 
 from matplotlib import rcParams
 from matplotlib import rcParams
+from itertools import cycle
 
 
 FRAME_DIR = 'visualizations/temp/'
 FRAME_DIR = 'visualizations/temp/'
 VISUALISATION_DIR = 'visualizations/'
 VISUALISATION_DIR = 'visualizations/'
@@ -13,7 +16,7 @@ INFECTIOUS_COLOR = '#f56262'
 REMOVED_COLOR = '#83eb5e'
 REMOVED_COLOR = '#83eb5e'
 
 
 class Plotter:
 class Plotter:
-    def __init__(self, additional_colors=[], font_size=12, font='Comfortaa', font_color='#595959') -> None:
+    def __init__(self, additional_colors=[], font_size=20, font='serif', font_color='#000000') -> None:
         """Plotter of scientific plots and animations, for dinn.py.
         """Plotter of scientific plots and animations, for dinn.py.
 
 
         Args:
         Args:
@@ -24,14 +27,26 @@ class Plotter:
         """
         """
         self.__colors = [SUSCEPTIBLE_COLOR, INFECTIOUS_COLOR, REMOVED_COLOR] + additional_colors
         self.__colors = [SUSCEPTIBLE_COLOR, INFECTIOUS_COLOR, REMOVED_COLOR] + additional_colors
 
 
+        self.__lines_styles = ['solid', 'dotted', 'dashdot', (5, (10, 3)), (0, (5, 1)), (0, (3, 5, 1, 5)), (0, (3, 1, 1, 1)), (0, (3, 5, 1, 5, 1, 5)), (0, (3, 1, 1, 1, 1, 1))]
+
+        self.__marker_styles = ['o', '^', 's']
+
+        rcParams['text.usetex'] = True
+        rcParams['text.color'] = font_color
+
         rcParams['font.family'] = font
         rcParams['font.family'] = font
         rcParams['font.size'] = font_size
         rcParams['font.size'] = font_size
 
 
-        rcParams['text.color'] = font_color
+        rcParams['pgf.texsystem'] = 'pdflatex'
+        rcParams['pgf.rcfonts'] = False
+
         rcParams['axes.labelcolor'] = font_color
         rcParams['axes.labelcolor'] = font_color
         rcParams['xtick.color'] = font_color
         rcParams['xtick.color'] = font_color
         rcParams['ytick.color'] = font_color
         rcParams['ytick.color'] = font_color
 
 
+        plt.rc("text", usetex=True)
+        plt.rc("text.latex", preamble=r"\usepackage{amssymb} \usepackage{wasysym}")
+
         self.__frames = []
         self.__frames = []
 
 
     def __generate_figure(self, shape=(4, 4)):
     def __generate_figure(self, shape=(4, 4)):
@@ -43,7 +58,7 @@ class Plotter:
         Returns:
         Returns:
             Figure: plt.Figure that was generated.
             Figure: plt.Figure that was generated.
         """
         """
-        fig = plt.figure(figsize=shape)
+        fig = plt.figure(figsize=shape, constrained_layout=True)
         return fig
         return fig
 
 
     def reset_animation(self):
     def reset_animation(self):
@@ -51,6 +66,7 @@ class Plotter:
         """
         """
         self.__frames = []
         self.__frames = []
     
     
+    #TODO comments
     def plot(self, 
     def plot(self, 
              x, 
              x, 
              y:list, 
              y:list, 
@@ -58,16 +74,19 @@ class Plotter:
              file_name:str,
              file_name:str,
              title:str, 
              title:str, 
              figure_shape:tuple, 
              figure_shape:tuple, 
+             event_lookup={},
              is_frame=False, 
              is_frame=False, 
              is_background=[], 
              is_background=[], 
+             fill_between=[],
              plot_legend=True, 
              plot_legend=True, 
              y_log_scale=False, 
              y_log_scale=False, 
              lw=3, 
              lw=3, 
              legend_loc='best', 
              legend_loc='best', 
              ylim=(None, None),
              ylim=(None, None),
+             number_xlabels = 5,
              xlabel='',
              xlabel='',
              ylabel='', 
              ylabel='', 
-             xlabel_rotation=0):
+             xlabel_rotation=None):
         """Plotting method.
         """Plotting method.
 
 
         Args:
         Args:
@@ -87,20 +106,26 @@ class Plotter:
             xlabel (str, optional): Label for the x axis. Defaults to ''.
             xlabel (str, optional): Label for the x axis. Defaults to ''.
             ylabel (str, optional): Label for the y axis. Defaults to ''.
             ylabel (str, optional): Label for the y axis. Defaults to ''.
         """
         """
-        assert len(y) == len(labels), "There must be the same amount of labels as there are plots."
+        assert len(y) == len(labels), f"There must be the same amount of labels as there are plots.\nNumber Plots: {len(y)}\nNumber Labels: {len(labels)}"
         assert len(is_background) == 0 or len(y) == len(is_background), "If given the back_foreground list must have the same length as labels has."
         assert len(is_background) == 0 or len(y) == len(is_background), "If given the back_foreground list must have the same length as labels has."
         fig = self.__generate_figure(shape=figure_shape)
         fig = self.__generate_figure(shape=figure_shape)
 
 
         ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
         ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
         ax.set_facecolor('xkcd:white')
         ax.set_facecolor('xkcd:white')
-        ax.yaxis.set_tick_params(length=0, which='both')
-        ax.xaxis.set_tick_params(length=0, which='both')
-        ax.grid(which='major', c='black', lw=0.2, ls='-')
+        #ax.yaxis.set_tick_params(length=0, which='both')
+        #ax.xaxis.set_tick_params(length=0, which='both')
+
+        #ax.grid(which='major', c='black', lw=0.2, ls='-')
         ax.set_title(title)
         ax.set_title(title)
 
 
-        for spine in ('top', 'right', 'bottom', 'left'):
-            ax.spines[spine].set_visible(False)
+        #for spine in ('top', 'right', 'bottom', 'left'):
+         #   ax.spines[spine].set_visible(False)
+
+        if torch.is_tensor(x):
+            x = x.cpu().detach().numpy()
 
 
+        linecycler = cycle(self.__lines_styles)
+        j = 0
         for i, array in enumerate(y):
         for i, array in enumerate(y):
             alpha = 1
             alpha = 1
             if len(is_background) != 0:
             if len(is_background) != 0:
@@ -111,12 +136,19 @@ class Plotter:
                 data = array.cpu().detach().numpy()
                 data = array.cpu().detach().numpy()
             else:
             else:
                 data = array
                 data = array
-            
-            if len(is_background) != 0 and alpha == 1:
-                ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, linestyle='dashed', c=self.__colors[i % len(self.__colors)])
-            else:
-                ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, c=self.__colors[i % len(self.__colors)])
 
 
+            space = int(len(x) / number_xlabels)
+            ax.xaxis.set_major_locator(ticker.MultipleLocator(space)) 
+        
+            ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, linestyle=next(linecycler), c=self.__colors[i % len(self.__colors)])
+            if i < len(fill_between):
+                ax.fill_between(x, data+fill_between[i], data-fill_between[i], facecolor=self.__colors[i % len(self.__colors)], alpha=0.5)
+            j = i
+
+        for event in event_lookup.keys():
+            j += 1
+            plt.axvline(x = event_lookup[event], color = self.__colors[j % len(self.__colors)], label = event, ls=next(linecycler))
+            
         if plot_legend:
         if plot_legend:
             plt.legend(loc=legend_loc)
             plt.legend(loc=legend_loc)
 
 
@@ -132,7 +164,8 @@ class Plotter:
         if ylabel != '':
         if ylabel != '':
             plt.ylabel(ylabel)
             plt.ylabel(ylabel)
 
 
-        plt.xticks(rotation=xlabel_rotation)
+        if xlabel_rotation != None:
+            plt.xticks(rotation=xlabel_rotation)
  
  
         if not os.path.exists(FRAME_DIR):
         if not os.path.exists(FRAME_DIR):
             os.makedirs(FRAME_DIR)
             os.makedirs(FRAME_DIR)
@@ -144,7 +177,181 @@ class Plotter:
             self.__frames.append(imageio.imread(frame_path))
             self.__frames.append(imageio.imread(frame_path))
             os.remove(frame_path)
             os.remove(frame_path)
         else:
         else:
-            plt.savefig(VISUALISATION_DIR + f'{file_name}.png')
+            plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
+
+        plt.close()
+
+    def cluster_plot(self, 
+                     x, 
+                     y:list, 
+                     labels, 
+                     shape, 
+                     plots_shape, 
+                     file_name:str, 
+                     titles:list, 
+                     number_xlabels=5,
+                     lw=3, 
+                     fill_between=[],
+                     event_lookup={},
+                     xlabel='',
+                     ylabel='',
+                     ylim=(None, None),
+                     y_lim_exception=None,
+                     y_log_scale=False, 
+                     legend_loc=(0.5,0.992),
+                     add_y_space=0.05,
+                     number_of_legend_columns=1,
+                     same_axes=True,
+                     free_axis=(None, None),
+                     plot_all_labels = True):
+        real_shape = (shape[1] * plots_shape[0], shape[0] * plots_shape[1])
+        fig, axes = plt.subplots(*shape, figsize=real_shape, sharex=same_axes, sharey=same_axes)
+        plot_idx = 0
+
+        if torch.is_tensor(x):
+            x = x.cpu().detach().numpy()
+
+        if 1 in shape:
+            for i in range(len(axes)):
+                linecycler = cycle(self.__lines_styles)
+                colorcycler = cycle(self.__colors)
+                for j, array in enumerate(y[i]):
+                    color = next(colorcycler)
+                    if torch.is_tensor(array):
+                        data = array.cpu().detach().numpy()
+                    else:
+                        data = array
+
+                    space = int(len(x) / number_xlabels)
+                    axes[i].xaxis.set_major_locator(ticker.MultipleLocator(space))
+                    axes[i].plot(x, 
+                                 data, 
+                                 linestyle=next(linecycler),
+                                 label=labels[j],
+                                 c=color,
+                                 lw=lw)
+                    axes[i].set_title(titles[i])
+                    if j < len(fill_between[i]):
+                        axes[i].fill_between(x, data+fill_between[i][j], data-fill_between[i][j], facecolor=color, alpha=0.5)
+                for event in event_lookup.keys():
+                    axes[i].axvline(x=event_lookup[event], 
+                                    color=next(colorcycler), 
+                                    label=event, 
+                                    ls=next(linecycler),
+                                    lw=lw)
+            
+                if ylim[0] != None and y_lim_exception != i:
+                    axes[i].set_ylim(ylim)
+                
+                if y_log_scale:
+                    plt.yscale('log')
+        else:
+            for i in range(shape[0]):
+                for j in range(shape[1]):
+                    if (i, j) == free_axis:
+                        axes[i, j].axis('off')
+                    else:
+                        linecycler = cycle(self.__lines_styles)
+                        colorcycler = cycle(self.__colors)
+                        if plot_idx < len(y):
+                            for k, array in enumerate(y[plot_idx]):
+                                if torch.is_tensor(array):
+                                    data = array.cpu().detach().numpy()
+                                else:
+                                    data = array
+                                space = int(len(x) / number_xlabels)
+                                axes[i, j].xaxis.set_major_locator(ticker.MultipleLocator(space))
+                                if len(x) > len(data):
+                                    c = len(data)
+                                else:
+                                    c = len(x)
+                                axes[i, j].plot(x[:c], 
+                                                data, 
+                                                label=labels[k], 
+                                                c=next(colorcycler), 
+                                                lw=lw, 
+                                                linestyle=next(linecycler))
+                            axes[i, j].set_title(titles[plot_idx])
+                            if ylim[0] != None:
+                                axes[i, j].set_ylim(ylim)
+                        plot_idx += 1
+                    if y_log_scale:
+                        plt.yscale('log')
+        if 1 in shape:
+            lines, labels = axes[0].get_legend_handles_labels()
+        else:
+            lines, labels = axes[0, 0].get_legend_handles_labels()
+        fig.legend(lines, labels, loc='upper center', ncol=number_of_legend_columns, bbox_to_anchor=legend_loc)
+
+        for ax in axes.flat:
+            ax.set(xlabel=xlabel, ylabel=ylabel)
+
+        # Hide x labels and tick labels for top plots and y ticks for right plots.
+        if plot_all_labels:
+            for ax in axes.flat:
+                ax.label_outer()
+
+        # Adjust layout to prevent overlap
+        plt.tight_layout(rect=[0, 0, 1, 1-add_y_space])
+        plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
+
+    def scatter(self, 
+                x, 
+                y:list, 
+                labels:list, 
+                figure_shape:tuple, 
+                file_name:str, 
+                title:str, 
+                std=[], 
+                true_values=[],
+                true_label='true',
+                plot_legend=True, 
+                legend_loc='best',
+                xlabel='',
+                ylabel='', 
+                xlabel_rotation=None):
+        assert len(y) == len(labels), f"There must be the same amount of labels as there are plots.\nNumber Plots: {len(y)}\nNumber Labels: {len(labels)}"
+        fig = self.__generate_figure(shape=figure_shape)
+
+        ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
+        ax.set_facecolor('xkcd:white')
+        ax.set_title(title)
+
+        if torch.is_tensor(x):
+            x = x.cpu().detach().numpy()
+
+        markercycler = cycle(self.__marker_styles)
+        for i, array in enumerate(y):
+            
+            if torch.is_tensor(array):
+                data = array.cpu().detach().numpy()
+            else:
+                data = array
+
+            if i >= len(std):
+                ax.scatter(x, data, label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
+            if i < len(std):
+                ax.errorbar(x, data, std[i], label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
+        
+        linecycler = cycle(self.__lines_styles)
+        for i, true_value in enumerate(true_values):
+            ax.plot(x, np.ones_like(x, dtype='float64') * true_value, label=f'{true_label} {labels[i]}',c=self.__colors[i % len(self.__colors)], ls=next(linecycler))
+            
+        if plot_legend:
+            plt.legend(loc=legend_loc)
+
+        if xlabel != '':
+            plt.xlabel(xlabel, )
+
+        if ylabel != '':
+            plt.ylabel(ylabel)
+
+        if xlabel_rotation != None:
+            plt.xticks(rotation=45, ha='right')
+
+        plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
+
+        plt.close()
 
 
     def animate(self, name: str):
     def animate(self, name: str):
         """Builds animation from images saved in self.frames. Then saves animation as gif.
         """Builds animation from images saved in self.frames. Then saves animation as gif.

+ 57 - 43
src/preprocessing/synthetic_data.py

@@ -29,7 +29,7 @@ class SyntheticDeseaseData:
         """
         """
         self.generated = True
         self.generated = True
 
 
-    def plot(self, labels: tuple, title:str):
+    def plot(self, labels: tuple, title:str, file_name:str):
         """Plot the data which was generated.
         """Plot the data which was generated.
 
 
         Args:
         Args:
@@ -38,31 +38,31 @@ class SyntheticDeseaseData:
         """
         """
         assert len(labels) == len(self.data), 'The number labels needs to be the same as the number of plots.'
         assert len(labels) == len(self.data), 'The number labels needs to be the same as the number of plots.'
         if self.generated:
         if self.generated:
-            self.plotter.plot(self.t, self.data, labels, title, title, (6, 6), xlabel='time / days', ylabel='amount of people')
+            self.plotter.plot(self.t, self.data, labels, file_name, title, (6, 6), xlabel='time / days', ylabel='amount of people')
         else: 
         else: 
-            print('Data has to be generated before plotting!') # Fabienne war hier
+            print('Data has to be generated before plotting!')
 
 
-class SI(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N=59e6, I_0=1, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
-        """This class is able to generate synthetic data of the SI groups for the reduced SIR model. This is done by utiling the SIR model.
+class SIR(SyntheticDeseaseData):
+    def __init__(self, plotter:Plotter, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
+        """This class is able to generate synthetic data for the SIR model.
 
 
         Args:
         Args:
             plotter (Plotter): Plotter object to plot dataset curves.
             plotter (Plotter): Plotter object to plot dataset curves.
             N (int, optional): Size of the population. Defaults to 59e6.
             N (int, optional): Size of the population. Defaults to 59e6.
             I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
             I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
+            R_0 (int, optional): Initial size of the removed group. Defaults to 0.
             simulation_time (int, optional): Real time for that the synthetic data is supposed to be generated in days. Defaults to 500.
             simulation_time (int, optional): Real time for that the synthetic data is supposed to be generated in days. Defaults to 500.
             time_points (int, optional): Number of time sample points. Defaults to 100.
             time_points (int, optional): Number of time sample points. Defaults to 100.
             alpha (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Removed'. Defaults to 0.191.
             alpha (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Removed'. Defaults to 0.191.
             beta (float, optional): Factor dictating how many people per timestep go from 'Susceptible' to 'Infectious'. Defaults to 0.05.
             beta (float, optional): Factor dictating how many people per timestep go from 'Susceptible' to 'Infectious'. Defaults to 0.05.
         """
         """
-
         self.N = N
         self.N = N
-        self.S_0 = N - I_0
+        self.S_0 = N - I_0 - R_0
         self.I_0 = I_0
         self.I_0 = I_0
+        self.R_0 = R_0
 
 
         self.alpha = alpha
         self.alpha = alpha
         self.beta = beta
         self.beta = beta
-
         super().__init__(simulation_time, time_points, plotter)
         super().__init__(simulation_time, time_points, plotter)
 
 
     def differential_eq(self, y, t, alpha, beta):
     def differential_eq(self, y, t, alpha, beta):
@@ -77,93 +77,107 @@ class SI(SyntheticDeseaseData):
         Returns:
         Returns:
             tuple: Change amount for each group.
             tuple: Change amount for each group.
         """
         """
-        S, I = y
-        dSdt = -self.beta * ((S * I) / self.N)
-        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I
-        return dSdt, dIdt
-    
+        S, I, _ = y
+        dSdt = -self.beta * ((S * I) / self.N) # -self.beta * S * I
+        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I # self.beta * S * I - self.alpha * I
+        dRdt = self.alpha * I
+        return dSdt, dIdt, dRdt
+
     def generate(self):
     def generate(self):
         """This funtion generates the data for this configuration of the SIR model.
         """This funtion generates the data for this configuration of the SIR model.
         """
         """
-        y_0 = self.S_0, self.I_0
+        y_0 = self.S_0, self.I_0, self.R_0
         self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
         self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
         super().generate()
         super().generate()
 
 
-    def plot(self, title=''):
+    def plot(self, title='', file_name='SIR_plot'):
         """Plot the data which was generated.
         """Plot the data which was generated.
         """
         """
-        super().plot(('Susceptible', 'Infectious'), title=title)
+        super().plot(('Susceptible', 'Infectious', 'Removed'), title=title, file_name=file_name)
 
 
     def save(self, name=''):
     def save(self, name=''):
         if self.generated:
         if self.generated:
             COVID_Data = np.asarray([self.t, *self.data]) 
             COVID_Data = np.asarray([self.t, *self.data]) 
 
 
-            np.savetxt('datasets/SI_data.csv', COVID_Data, delimiter=",")
+            np.savetxt('datasets/SIR_data.csv', COVID_Data, delimiter=",")
         else: 
         else: 
             print('Data has to be generated before plotting!')
             print('Data has to be generated before plotting!')
 
 
-
-class SIR(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
-        """This class is able to generate synthetic data for the SIR model.
+class I(SyntheticDeseaseData):
+    def __init__(self, plotter:Plotter, N:int, C:int, I_0=1, time_points=100, alpha=1/3) -> None:
+        """This class is able to generate synthetic data of the SI groups for the reduced SIR model. This is done by utiling the SIR model.
 
 
         Args:
         Args:
             plotter (Plotter): Plotter object to plot dataset curves.
             plotter (Plotter): Plotter object to plot dataset curves.
             N (int, optional): Size of the population. Defaults to 59e6.
             N (int, optional): Size of the population. Defaults to 59e6.
             I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
             I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
-            R_0 (int, optional): Initial size of the removed group. Defaults to 0.
             simulation_time (int, optional): Real time for that the synthetic data is supposed to be generated in days. Defaults to 500.
             simulation_time (int, optional): Real time for that the synthetic data is supposed to be generated in days. Defaults to 500.
             time_points (int, optional): Number of time sample points. Defaults to 100.
             time_points (int, optional): Number of time sample points. Defaults to 100.
             alpha (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Removed'. Defaults to 0.191.
             alpha (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Removed'. Defaults to 0.191.
             beta (float, optional): Factor dictating how many people per timestep go from 'Susceptible' to 'Infectious'. Defaults to 0.05.
             beta (float, optional): Factor dictating how many people per timestep go from 'Susceptible' to 'Infectious'. Defaults to 0.05.
         """
         """
         self.N = N
         self.N = N
-        self.S_0 = N - I_0 - R_0
+        self.C = C
         self.I_0 = I_0
         self.I_0 = I_0
-        self.R_0 = R_0
-
+ 
         self.alpha = alpha
         self.alpha = alpha
-        self.beta = beta
 
 
-        super().__init__(simulation_time, time_points, plotter)
+        self.t = np.linspace(0, 1, time_points)
+        self.t_save = np.linspace(1, time_points, time_points)
+        self.t_f = time_points
+        self.reproduction_value = []
+        self.data = None
+        self.generated = False
+        self.plotter = plotter
+        
+    def R_t(self, t):
+        descaled_t = t * self.t_f
+        # if descaled_t < threshold1:
+        return -np.tanh(descaled_t * 0.05 - 2) * 0.4 + 1.35
 
 
-    def differential_eq(self, y, t, alpha, beta):
+
+            
+    def differential_eq(self, I, t):
         """In this function implements the differential equation of the SIR model will be implemented.
         """In this function implements the differential equation of the SIR model will be implemented.
 
 
         Args:
         Args:
             y (tuple): Vector that holds the current state of the three groups.
             y (tuple): Vector that holds the current state of the three groups.
             t (_): not used
             t (_): not used
-            alpha (_): not used
-            beta (_): not used
 
 
         Returns:
         Returns:
             tuple: Change amount for each group.
             tuple: Change amount for each group.
         """
         """
-        S, I, R = y
-        dSdt = -self.beta * ((S * I) / self.N) # -self.beta * S * I
-        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I # self.beta * S * I - self.alpha * I
-        dRdt = self.alpha * I
-        return dSdt, dIdt, dRdt
+        dIdt = self.alpha * self.t_f * (self.R_t(t) - 1) * I
+        return dIdt
 
 
     def generate(self):
     def generate(self):
         """This funtion generates the data for this configuration of the SIR model.
         """This funtion generates the data for this configuration of the SIR model.
         """
         """
-        y_0 = self.S_0, self.I_0, self.R_0
-        self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
-        super().generate()
+        self.data = odeint(self.differential_eq, self.I_0/self.C, self.t).T
+        self.data = self.data[0] * self.C
+        self.t_counter = 0
+        self.generated =True
 
 
-    def plot(self, title=''):
+    def plot(self, title='', file_name=''):
         """Plot the data which was generated.
         """Plot the data which was generated.
         """
         """
-        super().plot(('Susceptible', 'Infectious', 'Removed'), title=title)
+        if self.generated:
+            t = np.linspace(0, len(self.t), len(self.t))
+            self.plotter.plot(t, [self.data], ['Infectious'], file_name, title, (6, 6), xlabel='time / days', ylabel='amount of people')
+            for time in self.t:
+                self.reproduction_value.append(self.R_t(time))
+            self.plotter.plot(t, [np.array(self.reproduction_value)], [r'$\mathcal{R}_t$'], file_name + '_r_t', title + r' $\mathcal{R}_t$', (6, 6), xlabel='time / days')
+        else: 
+            print('Data has to be generated before plotting!')
 
 
     def save(self, name=''):
     def save(self, name=''):
         if self.generated:
         if self.generated:
-            COVID_Data = np.asarray([self.t, *self.data]) 
+            COVID_Data = np.asarray([self.t_save, self.data]) 
 
 
-            np.savetxt('datasets/SIR_data.csv', COVID_Data, delimiter=",")
+            np.savetxt('datasets/I_data.csv', COVID_Data, delimiter=",")
         else: 
         else: 
             print('Data has to be generated before plotting!')
             print('Data has to be generated before plotting!')
+
         
         
 
 
 class SIDR(SyntheticDeseaseData):
 class SIDR(SyntheticDeseaseData):

+ 88 - 138
src/preprocessing/transform_data.py

@@ -3,7 +3,24 @@ import pandas as pd
 
 
 from src.plotter import Plotter
 from src.plotter import Plotter
 
 
-def transform_general_to_SIR(plotter:Plotter, dataset_path='datasets/COVID-19-Todesfaelle_in_Deutschland/', plot_name='', plot_title='', sample_rate=1, exclude=[], plot_size=(12,6), yscale_log=False, plot_legend=True):
+state_lookup = {'Schleswig Holstein' : (1, 2897000),
+                'Hamburg' : (2, 1841000), 
+                'Niedersachsen' : (3, 7982000), 
+                'Bremen' : (4, 569352),
+                'Nordrhein-Westfalen' : (5, 17930000),
+                'Hessen' : (6, 6266000),
+                'Rheinland-Pfalz' : (7, 4085000),
+                'Baden-Württemberg' : (8, 11070000),
+                'Bayern' : (9, 13080000),
+                'Saarland' : (10, 990509),
+                'Berlin' : (11, 3645000),
+                'Brandenburg' : (12, 2641000),
+                'Mecklenburg-Vorpommern' : (13, 1610000),
+                'Sachsen' : (14, 4078000),
+                'Sachsen-Anhalt' : (15, 2208000),
+                'Thüringen' : (16, 2143000)}
+
+def transform_data(plotter:Plotter, alpha=1/14, state_name='Germany', time_range=1200, plot_name='', plot_title='', sample_rate=1, model='SIR', plot_size=(12,6), yscale_log=False, plot_legend=True):
     """Function to generate the SIR split from the data in the COVID-19-Todesfaelle_in_Deutschland dataset.
     """Function to generate the SIR split from the data in the COVID-19-Todesfaelle_in_Deutschland dataset.
 
 
     Args:
     Args:
@@ -18,147 +35,80 @@ def transform_general_to_SIR(plotter:Plotter, dataset_path='datasets/COVID-19-To
         plot_legend (bool, optional): Controls if the legend is to be plotted. Defaults to True.
         plot_legend (bool, optional): Controls if the legend is to be plotted. Defaults to True.
     """
     """
     # read the data
     # read the data
-    df = pd.read_csv(dataset_path + 'COVID-19-Todesfaelle_Deutschland.csv')
-
-    df = df.drop(df.index[1200:])
-    
-    # population of germany at the end of 2019
-    N = 83100000
-    S, I, R = np.zeros(df.shape[0]), np.zeros(df.shape[0]), np.zeros(df.shape[0])
-
-    # S_0 = N - I_0
-    S[0] = N - df['Faelle_gesamt'][0]
-    # I_0 = overall cases at the day - overall death cases at the day
-    I[0] = df['Faelle_gesamt'][0] - df['Todesfaelle_gesamt'][0]
-    # R_0 = overall death cases at the day
-    R[0] = df['Todesfaelle_gesamt'][0]
-
-    # the recovery time is 14 days
-    recovery_queue = np.zeros(14)
-    
-    for day in range(1, df.shape[0]):
-        infections = df['Faelle_gesamt'][day] - df['Faelle_gesamt'][day-1]
-        deaths = df['Todesfaelle_neu'][day]
-        recoveries = recovery_queue[0]
-
-        S[day] = S[day-1] - infections
-        I[day] = I[day-1] + infections - deaths - recoveries
-        R[day] = R[day-1] + deaths + recoveries
-
-        # update recovery queue
-        if I[day] < 0:
-            recovery_queue[-1] -= I[day] 
-            I[day] = 0
-
-        recovery_queue[:-1] = recovery_queue[1:]
-        recovery_queue[-1] = infections
-
-    t = np.arange(0, df.shape[0], 1)
-    if plotter != None:
-        # plot graphs
-        plots = []
-        labels = []
-
-        if 'S' not in exclude:
-            plots.append(S)
-            labels.append('S')
-        
-        if 'I' not in exclude:
-            plots.append(I)
-            labels.append('I')
-
-        if 'R' not in exclude:
-            plots.append(R)
-            labels.append('R')
-
-        plotter.plot(t, plots, labels, plot_name, plot_title, plot_size, y_log_scale=yscale_log, plot_legend=plot_legend, xlabel='time / days', ylabel='amount of poeple')
-
-    COVID_Data = np.asarray([t[0::sample_rate], 
-                             S[0::sample_rate], 
-                             I[0::sample_rate], 
-                             R[0::sample_rate]]) 
-
-    np.savetxt(f"datasets/SIR_RKI_{sample_rate}.csv", COVID_Data, delimiter=",")
-
 
 
 
 
-def get_state_cases(county_id, state_id):
-    id = county_id // 1000
-    return id == state_id
-
-def state_based_data(plotter:Plotter, state_name:str, model='SIR', alpha=1/14, time_range=1200, sample_rate=1, dataset_path='datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv'):
-    """Transforms the RKI infection cases dataset to a SIR dataset.
-
-    Args:
-        plotter (Plotter): Plotter object to plot dataset curves.
-        state_name (str): Name of the state that is to be singled out in the new dataset.
-        time_range (int, optional): Number of days that will be looked at in the new dataset. Defaults to 1200.
-        sample_rate (int, optional): Sample rate used to sample the timepoints. Defaults to 1.
-        dataset_path (str, optional): Path to the CSV file, where the data is stored. Defaults to 'datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv'.
-    """
-    df = pd.read_csv(dataset_path)
-
-    state_lookup = {'Schleswig Holstein' : (1, 2897000),
-                    'Hamburg' : (2, 1841000), 
-                    'Niedersachsen' : (3, 7982000), 
-                    'Bremen' : (4, 569352),
-                    'Nordrhein-Westfalen' : (5, 17930000),
-                    'Hessen' : (6, 6266000),
-                    'Rheinland-Pfalz' : (7, 4085000),
-                    'Baden-Württemberg' : (8, 11070000),
-                    'Bayern' : (9, 13080000),
-                    'Saarland' : (10, 990509),
-                    'Berlin' : (11, 3645000),
-                    'Brandenburg' : (12, 2641000),
-                    'Mecklenburg-Vorpommern' : (13, 1610000),
-                    'Sachsen' : (14, 4078000),
-                    'Sachsen-Anhalt' : (15, 2208000),
-                    'Thüringen' : (16, 2143000)}
-    state_ID, N = state_lookup[state_name]
-
-    # single out a state
-    state_IDs = df['IdLandkreis'] // 1000
-    state_df = df.loc[state_IDs == state_ID]
-
-    # sort entries by state
-    state_df = state_df.sort_values('Refdatum')
-    state_df = state_df.reset_index(drop=True)
-
-
-    # collect cases    
     infections = np.zeros(time_range)
     infections = np.zeros(time_range)
-    dead = np.zeros(time_range)
-    recovered = np.zeros(time_range)
-    entry_idx = 0
-    day = 0
-    date = state_df['Refdatum'][entry_idx]
-    # check for each date all entries
-    while day < time_range:
-        # use the date sorted characteristic and take all entries with current date
-        while state_df['Refdatum'][entry_idx] == date:
-            # TODO use further parameters
-            infections[day] += state_df['AnzahlFall'][entry_idx]
-            dead[day] += state_df['AnzahlTodesfall'][entry_idx]
-            recovered[day] += state_df['AnzahlGenesen'][entry_idx]
-            entry_idx += 1
-        # move day index by difference between the current and next date
-        day += (pd.to_datetime(state_df['Refdatum'][entry_idx])-pd.to_datetime(date)).days
-        date = state_df['Refdatum'][entry_idx]
-
-    S = np.zeros(time_range)
-    I = np.zeros(time_range)
-    R = np.zeros(time_range)
-
+    deaths = np.zeros(time_range)
+    recoveries = np.zeros(time_range)
+    if state_name == 'Germany':
+        df = pd.read_csv('datasets/COVID-19-Todesfaelle_in_Deutschland/COVID-19-Todesfaelle_Deutschland.csv')
+        N = 83100000
+        infections[0] = df['Faelle_gesamt'][0]
+        deaths[0] = df['Todesfaelle_neu'][0]
+
+        recovery_queue = np.zeros(14)
+        for i in range(1, time_range):
+            infections[i] = df['Faelle_gesamt'][i] - df['Faelle_gesamt'][i-1]
+            deaths[i] = df['Todesfaelle_neu'][i]
+            recoveries[i] = recovery_queue[0]
+
+            recovery_queue[:-1] = recovery_queue[1:]
+            recovery_queue[-1] = infections[i]
+    else:
+        df = pd.read_csv('datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv')
+        state_ID, N = state_lookup[state_name]
+
+        # single out a state
+        state_IDs = df['IdLandkreis'] // 1000
+        df = df.loc[state_IDs == state_ID]
+
+        # sort entries by state
+        df = df.sort_values('Refdatum')
+        df = df.reset_index(drop=True)
+
+        # collect cases    
+        entry_idx = 0
+        day = 0
+        date = df['Refdatum'][entry_idx]
+        # check for each date all entries
+        while day < time_range:
+            # use the date sorted characteristic and take all entries with current date
+            while df['Refdatum'][entry_idx] == date:
+                infections[day] += df['AnzahlFall'][entry_idx]
+                deaths[day] += df['AnzahlTodesfall'][entry_idx]
+                entry_idx += 1
+            # move day index by difference between the current and next date
+            day += (pd.to_datetime(df['Refdatum'][entry_idx])-pd.to_datetime(date)).days
+            date = df['Refdatum'][entry_idx]
+
+        recovery_queue = np.zeros(14)
+        week_counter = 2
+        for i in range(1, time_range):
+            recoveries[i] = recovery_queue[0]
+
+            recovery_queue[:-1] = recovery_queue[1:]
+            recovery_queue[-1] = infections[i]
+            week_counter -= 1
+        
+    df = df.drop(df.index[time_range:])
+    S, I, R = np.zeros(df.shape[0]), np.zeros(df.shape[0]), np.zeros(df.shape[0])
     # generate groups
     # generate groups
     S[0] = N - infections[0]
     S[0] = N - infections[0]
     I[0] = infections[0]
     I[0] = infections[0]
     R[0] = 0
     R[0] = 0
-
-    for day in range(1, time_range):
-        S[day] = S[day-1] - infections[day]
-        I[day] = I[day-1] + infections[day] - I[day-1] * alpha
-        R[day] = R[day-1] + I[day-1] * alpha
-
+    if model == 'I':
+        for day in range(1, time_range):
+            S[day] = S[day-1] - infections[day]
+            I[day] = I[day-1] + infections[day] - I[day-1] * alpha
+            R[day] = R[day-1] + I[day-1] * alpha
+    else:
+        for day in range(1, time_range):
+            S[day] = S[day-1] - infections[day]
+            I[day] = I[day-1] + infections[day] - deaths[day] - recoveries[day]
+            R[day] = R[day-1] + deaths[day] + recoveries[day]
+            if I[day] < 0:
+                I[day] = 0
+    
     t = np.arange(0, time_range, 1)
     t = np.arange(0, time_range, 1)
 
 
     # select, which group is to be outputted
     # select, which group is to be outputted
@@ -175,12 +125,12 @@ def state_based_data(plotter:Plotter, state_name:str, model='SIR', alpha=1/14, t
     plotter.plot(t, 
     plotter.plot(t, 
                  groups, 
                  groups, 
                  [*model], 
                  [*model], 
-                 state_name.replace(' ', '_').replace('-', '_').replace('ü','ue'), 
-                 state_name +' SI', 
+                 state_name.replace(' ', '_').replace('-', '_').replace('ü','ue') + f"_{model}" + f"_{int(1/alpha)}", 
+                 state_name, 
                  (6,6), 
                  (6,6), 
                  xlabel='time / days', 
                  xlabel='time / days', 
                  ylabel='amount of people')
                  ylabel='amount of people')
 
 
     COVID_Data = np.asarray([t[0::sample_rate]] + [group[0::sample_rate] for group in groups]) 
     COVID_Data = np.asarray([t[0::sample_rate]] + [group[0::sample_rate] for group in groups]) 
 
 
-    np.savetxt(f"datasets/{model}_RKI_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}.csv", COVID_Data, delimiter=",")
+    np.savetxt(f"datasets/{model}_RKI_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}_{int(1/alpha)}.csv", COVID_Data, delimiter=",")

+ 9 - 15
src/problem.py

@@ -58,24 +58,18 @@ class ReducedSIRProblem(PandemicProblem):
         super().__init__(data)
         super().__init__(data)
         self.alpha = alpha
         self.alpha = alpha
 
 
-    def residual(self, SI_pred):
+    def residual(self, I_pred):
         super().residual()
         super().residual()
-        SI_pred.backward(self._gradients[0], retain_graph=True)
-        dSdt = self._data.t_raw.grad.clone()
-        self._data.t_raw.grad.zero_()
-
-        SI_pred.backward(self._gradients[1], retain_graph=True)
-        dIdt = self._data.t_raw.grad.clone()
-        self._data.t_raw.grad.zero_()
 
 
-        _, I = self._data.get_denormalized_data([SI_pred[:, 0], SI_pred[:, 1]])
-        R_t = SI_pred[:, 2]
-        # I = SI_pred[:, 1]
+        I_pred.backward(self._gradients[0], retain_graph=True)
+        dIdt = self._data.t_scaled.grad.clone()
+        self._data.t_scaled.grad.zero_()
 
 
-        S_residual = dSdt - (-self.alpha * R_t * I)
-        I_residual = dIdt - (self.alpha * (R_t - 1) * I)
+        I = I_pred[:, 0]
+        R_t = I_pred[:, 1]
 
 
-        # print(f'\nTrue:\tI_min: {I.min()}, I_max: {I.max()}\nNorm:\tI_min: {SI_pred[:, 1].min()}, I_max: {SI_pred[:, 1].max()}\nResidual:\t{torch.mean(torch.square(I_residual))}')
+        # dIdt = torch.autograd.grad(I, self._data.t_scaled, torch.ones_like(I), create_graph=True)[0]
 
 
-        return S_residual, I_residual
+        I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
+        return I_residual
 
 

Dosya farkı çok büyük olduğundan ihmal edildi
+ 52 - 1576
state_data_dinn_sir.ipynb


+ 118 - 0
states_training.py

@@ -0,0 +1,118 @@
+import torch
+import numpy as np
+import csv
+import sys
+
+from src.dataset import PandemicDataset, Norms
+from src.problem import ReducedSIRProblem
+from src.dinn import DINN, Scheduler, Activation
+
+ALPHA = [1/14, 1/5]
+DO_STATES = True
+DO_SYNTHETIC = False
+
+ITERATIONS = 13
+
+state_starting_index = 0
+
+if "1" in sys.argv:
+     state_starting_index = 8
+
+STATE_LOOKUP = {'Schleswig_Holstein' : 2897000,
+                'Hamburg' : 1841000, 
+                'Niedersachsen' : 7982000, 
+                'Bremen' : 569352,
+                'Nordrhein_Westfalen' : 17930000,
+                'Hessen' : 6266000,
+                'Rheinland_Pfalz' : 4085000,
+                'Baden_Wuerttemberg' : 11070000,
+                'Bayern' : 13080000,
+                'Saarland' : 990509,
+                'Berlin' : 3645000,
+                'Brandenburg' : 2641000,
+                'Mecklenburg_Vorpommern' : 1610000,
+                'Sachsen' : 4078000,
+                'Sachsen_Anhalt' : 2208000,
+                'Thueringen' : 2143000}
+
+if DO_SYNTHETIC:
+    alpha = 1/3
+    covid_data = np.genfromtxt(f'./datasets/I_data.csv', delimiter=',')
+    for i in range(ITERATIONS):
+        dataset = PandemicDataset('Synthetic I', 
+                                    ['I'], 
+                                    7.6e6, 
+                                    *covid_data, 
+                                    norm_name=Norms.CONSTANT, 
+                                    use_scaled_time=True)
+
+        problem = ReducedSIRProblem(dataset, alpha)
+        dinn = DINN(2, 
+                    dataset, 
+                    [], 
+                    problem, 
+                    None, 
+                    state_variables=['R_t'], 
+                    hidden_size=100, 
+                    hidden_layers=4, 
+                    activation_layer=torch.nn.Tanh(),
+                    activation_output=Activation.POWER)
+
+        dinn.configure_training(1e-3, 
+                                20000, 
+                                scheduler_class=Scheduler.POLYNOMIAL, 
+                                lambda_physics=1e-6,
+                                verbose=True)
+        dinn.train(verbose=True, do_split_training=True)
+
+        dinn.save_training_process(f'synthetic_{i}')
+        #r_t = dinn.get_output(1).detach().cpu().numpy()
+
+        #with open(f'./results/synthetic_{i}.csv', 'w', newline='') as csvfile:
+                #writer = csv.writer(csvfile, delimiter=',')
+                #writer.writerow(r_t)
+
+for iteration in range(ITERATIONS):
+    if iteration <= 2:
+         print('skip first three iteration, as it was already done')
+         continue
+    if DO_STATES:
+        for state_idx in range(state_starting_index, state_starting_index + 8):
+            state = list(STATE_LOOKUP.keys())[state_idx]
+            exclude = ['Schleswig_Holstein', 'Hamburg', 'Niedersachsen']
+            if iteration == 3 and state in exclude:
+                print(f'skip in {state} third iteration, as it was already done')
+                continue
+            for i, alpha in enumerate(ALPHA):
+                print(f'training for {state} ({state_idx}), alpha: {alpha}, iter: {iteration}')
+
+                covid_data = np.genfromtxt(f'./datasets/I_RKI_{state}_1_{int(1/alpha)}.csv', delimiter=',')
+                dataset = PandemicDataset(state, ['I'], STATE_LOOKUP[state], *covid_data, norm_name=Norms.CONSTANT, use_scaled_time=True)
+
+                problem = ReducedSIRProblem(dataset, alpha)
+
+                dinn = DINN(2, 
+                            dataset, 
+                            [], 
+                            problem, 
+                            None, 
+                            state_variables=['R_t'], 
+                            hidden_size=100, 
+                            hidden_layers=4, 
+                            activation_layer=torch.nn.Tanh(),
+                            activation_output=Activation.POWER)
+
+                dinn.configure_training(1e-3, 
+                                        25000, 
+                                        scheduler_class=Scheduler.POLYNOMIAL,
+                                        lambda_obs=1e2, 
+                                        lambda_physics=1e-6, 
+                                        verbose=True)
+                dinn.train(verbose=True, do_split_training=True)
+
+                dinn.save_training_process(f'{state}_{i}_{iteration}')
+
+                r_t = dinn.get_output(1).detach().cpu().numpy()
+                with open(f'./results/{state}_{i}_{iteration}.csv', 'w', newline='') as csvfile:
+                    writer = csv.writer(csvfile, delimiter=',')
+                    writer.writerow(r_t)

Dosya farkı çok büyük olduğundan ihmal edildi
+ 254 - 74
synth_dinn_reduced_sir.ipynb


Dosya farkı çok büyük olduğundan ihmal edildi
+ 472 - 90
synth_dinn_sir.ipynb


BIN
visualizations/Baden_Wuerttemberg.png


BIN
visualizations/Bayern.png


BIN
visualizations/Berlin.png


BIN
visualizations/Brandenburg.png


BIN
visualizations/Bremen.png


BIN
visualizations/Hamburg.png


BIN
visualizations/Hessen.png


BIN
visualizations/Mecklenburg_Vorpommern.png


BIN
visualizations/Niedersachsen.png


BIN
visualizations/Nordrhein_Westfalen.png


BIN
visualizations/RKI_SIR_1.png


BIN
visualizations/RKI_SIR_10.png


BIN
visualizations/RKI_SIR_3.png


BIN
visualizations/RKI_SIR_5.png


BIN
visualizations/Rheinland_Pfalz.png


BIN
visualizations/SIRD_synth.png


BIN
visualizations/SIR_RKI_3_alpha.png


BIN
visualizations/SIR_RKI_3_animation.gif


BIN
visualizations/SIR_RKI_3_beta.png


BIN
visualizations/SIR_RKI_3_loss.png


BIN
visualizations/SIR_RKI_5_alpha.png


BIN
visualizations/SIR_RKI_5_animation.gif


BIN
visualizations/SIR_RKI_5_beta.png


BIN
visualizations/SIR_RKI_5_loss.png


BIN
visualizations/SI_synth.png


BIN
visualizations/Saarland.png


BIN
visualizations/Sachsen.png


BIN
visualizations/Schleswig_Holstein.png


BIN
visualizations/Thueringen.png


BIN
visualizations/animations/Baden_Wuerttemberg_animation.gif


+ 0 - 0
visualizations/Baden_Wuerttemberg_synth_sir_animation.gif → visualizations/animations/Baden_Wuerttemberg_synth_sir_animation.gif


BIN
visualizations/animations/Bayern_animation.gif


+ 0 - 0
visualizations/Bayern_synth_sir_animation.gif → visualizations/animations/Bayern_synth_sir_animation.gif


BIN
visualizations/animations/Berlin_animation.gif


+ 0 - 0
visualizations/Berlin_synth_sir_animation.gif → visualizations/animations/Berlin_synth_sir_animation.gif


BIN
visualizations/animations/Brandenburg_animation.gif


+ 0 - 0
visualizations/Brandenburg_synth_sir_animation.gif → visualizations/animations/Brandenburg_synth_sir_animation.gif


BIN
visualizations/animations/Bremen_animation.gif


+ 0 - 0
visualizations/Bremen_synth_sir_animation.gif → visualizations/animations/Bremen_synth_sir_animation.gif


BIN
visualizations/animations/Germany_animation.gif


BIN
visualizations/animations/Hamburg_animation.gif


+ 0 - 0
visualizations/Hamburg_synth_sir_animation.gif → visualizations/animations/Hamburg_synth_sir_animation.gif


BIN
visualizations/animations/Hessen_animation.gif


+ 0 - 0
visualizations/Hessen_synth_sir_animation.gif → visualizations/animations/Hessen_synth_sir_animation.gif


BIN
visualizations/animations/Mecklenburg_Vorpommern_animation.gif


+ 0 - 0
visualizations/Mecklenburg_Vorpommern_synth_sir_animation.gif → visualizations/animations/Mecklenburg_Vorpommern_synth_sir_animation.gif


BIN
visualizations/animations/Niedersachsen_animation.gif


+ 0 - 0
visualizations/Niedersachsen_synth_sir_animation.gif → visualizations/animations/Niedersachsen_synth_sir_animation.gif


BIN
visualizations/animations/Nordrhein_Westfalen_animation.gif


+ 0 - 0
visualizations/Nordrhein_Westfalen_synth_sir_animation.gif → visualizations/animations/Nordrhein_Westfalen_synth_sir_animation.gif


BIN
visualizations/animations/Rheinland_Pfalz_animation.gif


+ 0 - 0
visualizations/Rheinland_Pfalz_synth_sir_animation.gif → visualizations/animations/Rheinland_Pfalz_synth_sir_animation.gif


+ 0 - 0
visualizations/SIR_RKI_1_animation.gif → visualizations/animations/SIR_RKI_1_animation.gif


BIN
visualizations/animations/Saarland_animation.gif


+ 0 - 0
visualizations/Saarland_synth_sir_animation.gif → visualizations/animations/Saarland_synth_sir_animation.gif


+ 0 - 0
visualizations/Sachsen_Anhalt_synth_sir_animation.gif → visualizations/animations/Sachsen_Anhalt_synth_sir_animation.gif


BIN
visualizations/animations/Sachsen_animation.gif


+ 0 - 0
visualizations/Sachsen_synth_sir_animation.gif → visualizations/animations/Sachsen_synth_sir_animation.gif


BIN
visualizations/animations/Schleswig_Holstein_animation.gif


+ 0 - 0
visualizations/Schleswig_Holstein_synth_sir_animation.gif → visualizations/animations/Schleswig_Holstein_synth_sir_animation.gif


+ 0 - 0
visualizations/Thueringen_synth_sir_animation.gif → visualizations/animations/Thueringen_synth_sir_animation.gif


BIN
visualizations/animations/synth_sir_animation.gif


BIN
visualizations/base_params_synth.png


BIN
visualizations/high_alpha_synth.png


BIN
visualizations/high_beta_synth.png


BIN
visualizations/low_alpha_synth.png


BIN
visualizations/low_beta_synth.png


BIN
visualizations/png_img/Baden_Wuerttemberg.png


BIN
visualizations/png_img/Baden_Wuerttemberg_loss.png


BIN
visualizations/png_img/Bayern.png


BIN
visualizations/png_img/Bayern_loss.png


BIN
visualizations/png_img/Berlin.png


BIN
visualizations/png_img/Berlin_loss.png


BIN
visualizations/png_img/Brandenburg.png


BIN
visualizations/png_img/Brandenburg_loss.png


BIN
visualizations/png_img/Bremen.png


BIN
visualizations/png_img/Bremen_loss.png


BIN
visualizations/png_img/Germany_loss.png


BIN
visualizations/png_img/Hamburg.png


BIN
visualizations/png_img/Hamburg_loss.png


BIN
visualizations/png_img/Hessen.png


BIN
visualizations/png_img/Hessen_loss.png


Bu fark içinde çok fazla dosya değişikliği olduğu için bazı dosyalar gösterilmiyor