11 Commits 53d929930a ... af24764e67

Author SHA1 Message Date
  phillip.rothenbeck af24764e67 add png, pdf, and gif 8 months ago
  phillip.rothenbeck 3b05d7d641 add generalized data transformation algorithm 8 months ago
  phillip.rothenbeck 7555f1b41e clean up 8 months ago
  phillip.rothenbeck 5f34dd8418 clean up reduced 8 months ago
  phillip.rothenbeck 0a7b829650 add paper layout + scatter function 8 months ago
  phillip.rothenbeck c38d74dc4c add scaling, norm, optimizer and scheduler choosing 8 months ago
  phillip.rothenbeck b724dac3ad add norms and scaling 8 months ago
  phillip.rothenbeck 3cd104c9ae training pipelines 8 months ago
  phillip.rothenbeck eefda9b551 plot scripts for thesis 8 months ago
  phillip.rothenbeck c005bdab3e separately train R_t for Germany 8 months ago
  phillip.rothenbeck 483b1f70b3 preprocess all data in one notebook 8 months ago

+ 8 - 1
.gitignore

@@ -1,8 +1,15 @@
-# push no pycache
+# push no pycache and config files
 **__pycache__**
+*.json
+
+# push no batch scripts
+*.sh
 
 # push no training result data
 *.csv
+*.pdf
+*.out
+*.gif
 
 # push no RKI data
 *_RKI_*.csv

File diff not shown because the file is too large
+ 7 - 11
data.ipynb


+ 52 - 0
germany_training.py

@@ -0,0 +1,52 @@
+import torch
+import numpy as np
+import csv
+
+from src.dataset import PandemicDataset, Norms
+from src.problem import ReducedSIRProblem
+from src.dinn import DINN, Scheduler, Activation
+
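+# removal rates alpha = 1/D for assumed infectious periods of D = 14 and D = 5 days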
+ALPHA = [1/14, 1/5]
+NORM = [Norms.POPULATION, Norms.CONSTANT]
+
+ITERATIONS = 10
+
+for iteration in range(ITERATIONS):
+    for i, alpha in enumerate(ALPHA):
+        print(f'training for Germany, alpha: {alpha}, iter: {iteration}')
+
+        covid_data = np.genfromtxt(f'./datasets/I_RKI_Germany_1_{int(1/alpha)}.csv', delimiter=',')
+        dataset = PandemicDataset('Germany', 
+                                  ['I'], 
+                                  83100000, 
+                                  *covid_data, 
+                                  norm_name=NORM[i], 
+                                  C=10**6, 
+                                  use_scaled_time=True)
+        problem = ReducedSIRProblem(dataset, alpha)
+
+        dinn = DINN(2, 
+                    dataset, 
+                    [], 
+                    problem, 
+                    None, 
+                    state_variables=['R_t'], 
+                    hidden_size=100, 
+                    hidden_layers=4, 
+                    activation_layer=torch.nn.Tanh(),
+                    activation_output=Activation.POWER)
+
+        dinn.configure_training(1e-3, 
+                                25000, 
+                                scheduler_class=Scheduler.POLYNOMIAL, 
+                                lambda_obs=1e4,
+                                lambda_physics=1e-6, 
+                                verbose=True)
+        dinn.train(verbose=True, do_split_training=True, start_split=15000)
+
+        dinn.save_training_process(f'Germany_{i}_{iteration}')
+
+        r_t = dinn.get_output(1).detach().cpu().numpy()
+        with open(f'./results/Germany_{i}_{iteration}.csv', 'w', newline='') as csvfile:
+            writer = csv.writer(csvfile, delimiter=',')
+            writer.writerow(r_t)
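
germany_training.py writes each R_t series as a single CSV row. A minimal sketch (not part of the diff; the file name follows the `Germany_{i}_{iteration}` pattern used in the loop above) of reading one back:

    import numpy as np

    # Reads one of the single-row CSVs written by germany_training.py (illustrative file name).
    r_t = np.genfromtxt('./results/Germany_0_0.csv', delimiter=',')
    print(f'{r_t.shape[0]} days, mean R_t {r_t.mean():.3f}, max R_t {r_t.max():.3f}')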

+ 109 - 0
plot_datasets.py

@@ -0,0 +1,109 @@
+import numpy as np
+import pandas as pd
+from src.plotter import Plotter
+
+DS_DIR = './datasets/'
+SRC_DIR = './results/'
+SIM_DIR = './visualizations/'
+
+STATE_LOOKUP = {'Schleswig_Holstein' : 'Schleswig-Holstein',
+                'Hamburg' : 'Hamburg', 
+                'Niedersachsen' : 'Niedersachsen', 
+                'Bremen' : 'Bremen',
+                'Nordrhein_Westfalen' : 'North Rhine-Westphalia',
+                'Hessen' : 'Hessen',
+                'Rheinland_Pfalz' : 'Rhineland-Palatinate',
+                'Baden_Wuerttemberg' : 'Baden-Württemberg',
+                'Bayern' : 'Bavaria',
+                'Saarland' : 'Saarland',
+                'Berlin' : 'Berlin',
+                'Brandenburg' : 'Brandenburg',
+                'Mecklenburg_Vorpommern' : 'Mecklenburg-Western Pomerania',
+                'Sachsen' : 'Saxony',
+                'Sachsen_Anhalt' : 'Saxony-Anhalt',
+                'Thueringen' : 'Thuringia'}
+
+plotter = Plotter()
+
+data = []
+in_text_data = [np.genfromtxt(DS_DIR + f'SIR_data.csv', delimiter=',')[1:], 1]
+for state in STATE_LOOKUP.keys():
+    # print(f"plot {state}")
+    state_data_5 = np.genfromtxt(DS_DIR + f'I_RKI_{state}_1_5.csv', delimiter=',')[1]
+    state_data_14 = np.genfromtxt(DS_DIR + f'I_RKI_{state}_1_14.csv', delimiter=',')[1]
+    sir_data = np.genfromtxt(DS_DIR + f'SIR_RKI_{state}_1_14.csv', delimiter=',')[1:]
+    data.append(sir_data)
+    t = np.arange(0, 1200, 1)
+
+    if state in ["Schleswig_Holstein", "Berlin", "Thueringen"]:
+        in_text_data.append(sir_data)
+
+    plotter.plot(t, 
+                 [state_data_14, state_data_5], 
+                 [r'$\alpha=\frac{1}{14}$', r'$\alpha=\frac{1}{5}$'],
+                 f'{state}_datasets',
+                 f'{STATE_LOOKUP[state]}',
+                 (12,6),
+                 xlabel='time / days',
+                 ylabel='amount of people')
+    
+do_log=False
+plotter.cluster_plot(t,
+                     data[:6],
+                     [r'$S$', r'$I$', r'$R$'],
+                     (2, 3),
+                     (6,6),
+                     "state_sir_cluster_1",
+                     list(STATE_LOOKUP.values())[:6],
+                     xlabel='time / days',
+                     ylabel='amount of people',
+                     y_log_scale=do_log,
+                     add_y_space=0.05,
+                     number_of_legend_columns=3,
+                     same_axes=False,
+                     ylim=(0, 1.85e7))
+plotter.cluster_plot(t,
+                     data[7:],
+                     [r'$S$', r'$I$', r'$R$'],
+                     (3, 3),
+                     (6,6),
+                     "state_sir_cluster_2",
+                     list(STATE_LOOKUP.values())[7:],
+                     xlabel='time / days',
+                     ylabel='amount of people',
+                     y_log_scale=do_log,
+                     add_y_space=0.03,
+                     number_of_legend_columns=3,
+                     same_axes=False,
+                     ylim=(0, 1.85e7))
+
+germany_data = np.genfromtxt(DS_DIR + f'SIR_RKI_Germany_1_14.csv', delimiter=',')[1:]
+in_text_data[1] = germany_data
+
+plotter.plot(t,
+             germany_data,
+             [r'$S$', r'$I$', r'$R$'],
+             'germany_single_sir',
+             'Germany',
+             (6,6),
+             plot_legend=False,
+             xlabel='time / days',
+             ylabel='amount of people',)
+
+plotter.cluster_plot(t, 
+                     in_text_data,
+                     [r'$S$', r'$I$', r'$R$'],
+                     (2, 3),
+                     (6, 6),
+                     "in_text_SIR",
+                     ["synthetic SIR data", "Germany", 'Schleswig Holstein', 'Berlin', 'Thuringia'],
+                     xlabel='time / days',
+                     ylabel='amount of people',
+                     legend_loc=(0.51, 0.8),
+                     add_y_space=0,
+                     same_axes=False,
+                     free_axis=(0, 1),
+                     plot_all_labels=False)
+
+
+

+ 302 - 0
plot_results.py

@@ -0,0 +1,302 @@
+import numpy as np
+import pandas as pd
+from src.plotter import Plotter
+
+SRC_DIR = './results/'
+I_PRED_SRC_DIR = SRC_DIR + 'I_predictions/'
+SIM_DIR = './visualizations/'
+
+def get_error(y, y_ref):
+    err = []
+    for i in range(len(y)):
+        diff = y[i] - y_ref
+        err.append(np.linalg.norm(diff) / np.linalg.norm(y_ref))
+    return np.array(err).mean(axis=0)
+
+STATE_LOOKUP = {'Schleswig_Holstein' : (79.5,0.0849),
+                'Hamburg' : (84.5, 0.0948),
+                'Niedersachsen' : (77.6, 0.0774),
+                'Bremen' : (88.3,0.0933),
+                'Nordrhein_Westfalen' : (79.5,0.0777),
+                'Hessen' : (75.8,0.1017),
+                'Rheinland_Pfalz' : (75.6,0.0895),
+                'Baden_Wuerttemberg' : (74.5,0.0796),
+                'Bayern' : (75.1,0.0952),
+                'Saarland' : (82.4,0.1080),
+                'Berlin' : (78.1,0.0667),
+                'Brandenburg' : (68.1,0.0724),
+                'Mecklenburg_Vorpommern' : (74.7,0.0540),
+                'Sachsen' : (65.1,0.1109),
+                'Sachsen_Anhalt' : (74.1,0.0785),
+                'Thueringen' : (70.3,0.0837),
+                'Germany' : (76.4, 0.0804)}
+
+state_names = ['Schleswig-Holstein',
+                'Hamburg', 
+                'Lower Saxony', 
+                'Bremen',
+                'North Rhine-Westphalia',
+                'Hesse',
+                'Rhineland-Palatinate',
+                'Baden-Württemberg',
+                'Bavaria',
+                'Saarland',
+                'Berlin',
+                'Brandenburg',
+                'Mecklenburg-Western Pomerania',
+                'Saxony',
+                'Saxony-Anhalt',
+                'Thuringia', 
+                'Germany']
+
+plotter = Plotter(additional_colors=['yellow', 'cyan', 'magenta', ])
+
+# plot results for alpha and beta
+
+print("Visualizing Alpha and Beta results")
+
+# synth
+param_matrix = np.genfromtxt(SRC_DIR + f'synthetic_parameters.csv', delimiter=',')
+mean = param_matrix.mean(axis=0)
+std = param_matrix.std(axis=0)
+
+print("States Table form:")
+print('{0:.4f}'.format(1/3), "&", '{0:.4f}'.format(mean[0]), "&", '{0:.4f}'.format(std[0]), "&", '{0:.4f}'.format(1/2), "&", '{0:.4f}'.format(mean[1]), "&", '{0:.4f}'.format(std[1]), "\\\ ")
+
+plotter.scatter(np.arange(1, 6, 1), [param_matrix[:,0], param_matrix[:,1]], [r"$\alpha$", r"$\beta$"], (7,3.5), 'reproducability', '', true_values=[1/3, 1/2], xlabel='iteration')
+
+vaccination_ratios = []
+mean_std_parameters = {}
+for state in STATE_LOOKUP.keys():
+    state_matrix = np.genfromtxt(SRC_DIR + f'{state}_parameters.csv', delimiter=',')
+    mean = state_matrix.mean(axis=0)
+    std = state_matrix.std(axis=0)
+    mean_std_parameters.update({state : (mean, std)})
+    vaccination_ratios.append(STATE_LOOKUP[state][0])
+
+values = np.array(list(mean_std_parameters.values()))
+means = values[:,0]
+stds = values[:,1]
+alpha_means = means[:,0]
+beta_means = means[:,1]
+alpha_stds = stds[:,0]
+beta_stds = stds[:,1]
+
+print(f"Vaccination corr: {np.corrcoef(beta_means, vaccination_ratios)[0, 1]}")
+vaccination_ratios = vaccination_ratios[:-1]
+sn = np.array(state_names[:-1]).copy()
+sn[12] = "MWP"
+plotter.scatter(sn, 
+                [alpha_means[:-1], beta_means[:-1]], 
+                [r'$\alpha$', r'$\beta$', ],
+                (12, 6), 
+                'mean_std_alpha_beta_res',
+                '',
+                std=[alpha_stds[:-1], beta_stds[:-1]],
+                true_values=[alpha_means[-1], beta_means[-1]],
+                true_label='Germany',
+                xlabel_rotation=60,
+                plot_legend=True,
+                legend_loc="lower right")
+
+print("States Table form:")
+for i, state in enumerate(STATE_LOOKUP.keys()):
+    print(state_names[i], "&", '{0:.3f}'.format(alpha_means[i]), "{\\tiny $\\pm",'{0:.3f}'.format(alpha_stds[i]), "$}", "&", '{0:.3f}'.format(beta_means[i]), "{\\tiny $\\pm", '{0:.3f}'.format(beta_stds[i]), "$}", "&", STATE_LOOKUP[state][1], "&", '{0:.3f}'.format(beta_means[i]-beta_means[16]), "&", '{0:.1f}'.format(STATE_LOOKUP[state][0]), "\\\ ")
+
+print()
+# plot results for reproduction number
+
+# synth
+synth_iterations = []
+for i in range(10):   
+    synth_iterations.append(np.genfromtxt(SRC_DIR + f'synthetic_{i}.csv', delimiter=','))
+
+synth_matrix = np.array(synth_iterations)
+t = np.arange(0, len(synth_matrix[0]), 1)
+synth_r_t = np.zeros(150, dtype=np.float64)
+for i, time in enumerate(range(150)):
+    synth_r_t[i] = -np.tanh(time * 0.05 - 2) * 0.4 + 1.35
+
+print(f"Synthetic error R_t: {get_error(synth_matrix.mean(axis=0), synth_r_t)}")
+plotter.plot(t, 
+             [synth_matrix.mean(axis=0), synth_r_t],
+             [r'$\mathcal{R}_t$', r'true $\mathcal{R}_t$'], 
+             f"synthetic_R_t_statistics", 
+             r"Synthetic data $\mathcal{R}_t$", 
+             (9, 6),
+             fill_between=[synth_matrix.std(axis=0)],
+             xlabel="time / days")
+
+pred_synth = np.genfromtxt(I_PRED_SRC_DIR + f'synthetic_0_I_prediction.csv', delimiter=',')
+
+print(f"Synthetic error I: {get_error(pred_synth[2], pred_synth[1])}")
+
+plotter.plot(pred_synth[0], 
+             [pred_synth[2], pred_synth[1]],
+             [r'prediction $I$', r'true $I$'], 
+             f"synthetic_I_prediction", 
+             r"Synthetic data $I$ prediction", 
+             (9, 6),
+             xlabel="time / days",
+             ylabel='amount of people')
+
+EVENT_LOOKUP = {'start of vaccination' : 455,
+                'alpha variant' : 357,
+                'delta variant' : 473,
+                'omicron variant' : 663}
+
+
+ALPHA = [1 / 14, 1 / 5]
+
+
+cluster_counter = 1
+cluster_idx = 0
+
+in_text_r_t_mean = []
+in_text_r_t_std = []
+in_text_I = []
+in_text_I_std = []
+
+cluster_r_t_mean = []
+cluster_r_t_std = []
+cluster_I = []
+cluster_I_std = []
+cluster_states = []
+
+for k, state in enumerate(STATE_LOOKUP.keys()):
+
+    if state == "Thueringen":
+        l = 1
+    elif state == "Bremen":
+        l = 0
+    # data fetch arrays
+    r_t = []
+    pred_i = []
+    true_i = []
+    cluster_states.append(state_names[k])
+    cluster_r_t_mean.append([])
+    cluster_r_t_std.append([])
+    cluster_I.append([])
+    cluster_I_std.append([np.zeros(1200), np.zeros(1200)])
+    if state == "Thueringen" or state == "Bremen":
+        in_text_r_t_mean.append([])
+        in_text_r_t_std.append([])   
+        in_text_I.append([])
+        in_text_I_std.append([np.zeros(1200), np.zeros(1200)])
+    for i, alpha in enumerate(ALPHA):
+        iterations = []
+        predictions = []
+        true = []
+        for j in range(10): 
+            iterations.append(np.genfromtxt(SRC_DIR + f'{state}_{i}_{j}.csv', delimiter=','))
+            if (k >= 3 and j == 3) or j > 3:
+                data = np.genfromtxt(I_PRED_SRC_DIR + f'{state}_{i}_{j}_I_prediction.csv', delimiter=',')
+                predictions.append(data[2])
+                true = data[1]
+        iterations = np.array(iterations)
+        r_t.append(iterations)
+
+        predictions = np.array(predictions)
+        pred_i.append(predictions)
+        true_i.append(true)
+
+        cluster_r_t_mean[cluster_counter-1].append(iterations.mean(axis=0))
+        cluster_r_t_std[cluster_counter-1].append(iterations.std(axis=0))
+        if state == "Thueringen" or state == "Bremen":
+            in_text_r_t_mean[l].append(iterations.mean(axis=0))
+            in_text_r_t_std[l].append(iterations.std(axis=0))
+    if state == "Thueringen" or state == "Bremen":
+        in_text_I[l].append(true_i[0])
+        in_text_I[l].append(true_i[1])
+        in_text_I[l].append(pred_i[0].mean(axis=0))
+        in_text_I[l].append(pred_i[1].mean(axis=0))
+        in_text_I_std[l].append(pred_i[0].std(axis=0))
+        in_text_I_std[l].append(pred_i[1].std(axis=0))
+
+    cluster_I[cluster_counter-1].append(true_i[0])
+    cluster_I[cluster_counter-1].append(true_i[1])
+    cluster_I[cluster_counter-1].append(pred_i[0].mean(axis=0))
+    cluster_I[cluster_counter-1].append(pred_i[1].mean(axis=0))
+    cluster_I_std[cluster_counter-1].append(pred_i[0].std(axis=0))
+    cluster_I_std[cluster_counter-1].append(pred_i[1].std(axis=0))
+
+    # plot
+    print(f"{state_names[k]} & {'{0:.3f}'.format(get_error(pred_i[0], true_i[0]))} & {'{0:.3f}'.format(get_error(pred_i[1], true_i[1]))} & \phantom{{0}} & {(r_t[0] > 1).sum(axis=1).mean()} & {(r_t[1] > 1).sum(axis=1).mean()} & {'{0:.3f}'.format(r_t[0].max(axis=1).mean())} & {'{0:.3f}'.format(r_t[1].max(axis=1).mean())}\\\ ")
+    if len(cluster_states) == 4 and state != "Thueringen" or len(cluster_states) == 5 and state == "Germany":
+        t = np.arange(0, 1200, 1)
+        if len(cluster_states) == 5:
+            y_lim_exception = 4
+        else:
+            y_lim_exception = None
+        plotter.cluster_plot(t,
+                            cluster_r_t_mean,
+                            [r"$\alpha=\frac{1}{14}$", r"$\alpha=\frac{1}{5}$"],
+                            (len(cluster_states), 1),
+                            (9, 6),
+                            f'r_t_cluster_{cluster_idx}',
+                            [state + r"  $\mathcal{R}_t$" for state in cluster_states],
+                            fill_between=cluster_r_t_std,
+                            event_lookup=EVENT_LOOKUP,
+                            xlabel='time / days',
+                            ylim=(0.3, 2.0),
+                            legend_loc=(0.53, 0.992),
+                            number_of_legend_columns=3)
+        plotter.cluster_plot(t,
+                             cluster_I,
+                             [r"true $I$ $\alpha=\frac{1}{14}$", 
+                             r"true $I$ $\alpha=\frac{1}{5}$", 
+                             r"prediction $I$ $\alpha=\frac{1}{14}$", 
+                             r"prediction $I$ $\alpha=\frac{1}{5}$"],
+                             (len(cluster_states), 1),
+                             (9, 6),
+                             f'I_cluster_{cluster_idx}',
+                             [state + r" $I$ prediction" for state in cluster_states],
+                             fill_between=cluster_I_std,
+                             xlabel='time / days',
+                             ylabel='amount of people',
+                             same_axes=False,
+                             ylim=(0, 600000),
+                             legend_loc=(0.55, 0.992),
+                             number_of_legend_columns=2,
+                             y_lim_exception=y_lim_exception)
+        cluster_counter = 0
+        cluster_idx += 1
+        cluster_r_t_mean = []
+        cluster_r_t_std = []
+        cluster_I = []
+        cluster_I_std = []
+        cluster_states = []
+    cluster_counter += 1
+
+plotter.cluster_plot(t,
+                    in_text_r_t_mean,
+                    [r"$\alpha=\frac{1}{14}$", r"$\alpha=\frac{1}{5}$"],
+                    (2, 1),
+                    (9, 6),
+                    f'r_t_cluster_intext',
+                    [state + r"  $\mathcal{R}_t$" for state in ['Bremen', 'Thuringia']],
+                    fill_between=in_text_r_t_std,
+                    event_lookup=EVENT_LOOKUP,
+                    xlabel='time / days',
+                    ylim=(0.3, 2.0),
+                    legend_loc=(0.53, 0.999),
+                    add_y_space=0.08,
+                    number_of_legend_columns=3)
+plotter.cluster_plot(t,
+                     in_text_I,
+                     [r"true $I$ $\alpha=\frac{1}{14}$", 
+                     r"true $I$ $\alpha=\frac{1}{5}$", 
+                     r"prediction $I$ $\alpha=\frac{1}{14}$", 
+                     r"prediction $I$ $\alpha=\frac{1}{5}$"],
+                     (2, 1),
+                     (9, 6),
+                     f'I_cluster_intext',
+                     [state + r" $I$ prediction" for state in ['Bremen', 'Thuringia']],
+                     fill_between=in_text_I_std,
+                     xlabel='time / days',
+                     ylabel='amount of people',
+                     ylim=(0, 600000),
+                     legend_loc=(0.55, 0.999),
+                     add_y_space=0.08,
+                     number_of_legend_columns=2)
+
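
get_error above averages the relative L2 error over a set of runs. A worked toy example (invented values, not part of the diff):

    import numpy as np

    y_ref = np.array([1.0, 2.0, 3.0])
    runs = np.array([[1.1, 2.0, 3.0],
                     [0.9, 2.1, 2.9]])
    # Per-run relative L2 error, then the mean over runs -- the same
    # quantity get_error(runs, y_ref) returns.
    errs = [np.linalg.norm(run - y_ref) / np.linalg.norm(y_ref) for run in runs]
    print(np.mean(errs))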

+ 21 - 0
plot_training_metrics.py

@@ -0,0 +1,21 @@
+import numpy as np
+from src.plotter import Plotter
+
+SRC_DIR = './results/'
+SIM_DIR = './visualizations/'
+MET_DIR = 'training_metrics/'
+
+plotter = Plotter()
+# plot Germany training loss
+physics_loss = np.genfromtxt(SRC_DIR + MET_DIR + f'Germany_1_0_physics_loss.csv', delimiter=',')
+obs_loss = np.genfromtxt(SRC_DIR + MET_DIR + f'Germany_1_0_obs_loss.csv', delimiter=',')
+loss = np.genfromtxt(SRC_DIR + MET_DIR + f'Germany_1_0_loss.csv', delimiter=',')
+
+t = np.arange(0, len(loss), 1)
+plotter.plot(t, 
+             [loss, physics_loss, obs_loss], 
+             ['loss', 'physics loss', 'data loss'], 
+             'Germany_1_loss',
+             'Loss', 
+             (6, 6),
+             y_log_scale=True)

+ 55 - 8
src/dataset.py

@@ -1,4 +1,10 @@
 import torch
+from enum import Enum
+
+class Norms(Enum):
+    POPULATION=0
+    MIN_MAX=1
+    CONSTANT=2
 
 class PandemicDataset:
     def __init__(self, 
@@ -6,7 +12,10 @@ class PandemicDataset:
                  group_names:list, 
                  N: int, 
                  t, 
-                 *groups):
+                 *groups, 
+                 norm_name=Norms.MIN_MAX,
+                 C = 10**5,
+                 use_scaled_time=False):
         """Class to hold all data for one training process.
 
         Args:
@@ -15,19 +24,39 @@ class PandemicDataset:
             t (np.array): Array of timesteps.
             *groups (np.array): Arrays of size data for each group for each timestep.
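+            norm_name (Norms, optional): Normalization strategy (POPULATION, MIN_MAX or CONSTANT). Defaults to Norms.MIN_MAX.
+            C (int, optional): Divisor used by the CONSTANT norm. Defaults to 10**5.
+            use_scaled_time (bool, optional): If True, time is scaled to [0, 1] before being fed to the network. Defaults to False.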
         """
-
         if torch.cuda.is_available():
             self.device_name = 'cuda'
         else:
             self.device_name = 'cpu'
 
+        match norm_name:
+            case Norms.POPULATION:
+                self.__norm = self.__population_norm
+                self.__denorm = self.__population_denorm
+            case Norms.MIN_MAX:
+                self.__norm = self.__min_max_norm
+                self.__denorm = self.__min_max_denorm
+            case Norms.CONSTANT:
+                self.__norm = self.__constant_norm
+                self.__denorm = self.__constant_denorm
+            case _:
+                self.__norm = self.__min_max_norm
+                self.__denorm = self.__min_max_denorm
+
         self.name = name
         self.N = N
         self.t_init = t.min()
         self.t_final = t.max()
+        self.C = C
 
         self.t_raw = torch.tensor(t, requires_grad=True, device=self.device_name)
-        self.t_batch = self.t_raw.view(-1, 1).float()
+
+        self.t_scaled = ((self.t_raw - self.t_init) / (self.t_final - self.t_init)).detach().requires_grad_()
+        self.use_scaled_time = use_scaled_time
+        if use_scaled_time:
+            self.t_batch = self.t_scaled.view(-1, 1).float()
+        else:
+            self.t_batch = self.t_raw.view(-1, 1).float()
 
         self.__group_dict = {}
         for i, name in enumerate(group_names):
@@ -39,7 +68,7 @@ class PandemicDataset:
         
         self.__mins = [torch.min(group) for group in self.__groups]
         self.__maxs = [torch.max(group) for group in self.__groups]
-        self.__norms = [(self.__groups[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(len(groups))]
+        self.__norms = self.__norm(self.__groups)
 
     @property
     def number_groups(self):
@@ -52,14 +81,32 @@ class PandemicDataset:
     @property
     def group_names(self):
         return self.__group_names
+    
+    def __population_norm(self, data):
+        return [(data[i] / self.N) for i in range(self.number_groups)]
+    
+    def __population_denorm(self, data):
+        return [(data[i] * self.N) for i in range(self.number_groups)]
+
+    def __min_max_norm(self, data):
+        return [(data[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(self.number_groups)]
+    
+    def __min_max_denorm(self, data):
+        return [(self.__mins[i] + (self.__maxs[i] - self.__mins[i]) * data[i]) for i in range(self.number_groups)]
+    
+    def __constant_norm(self, data):
+        return [(data[i] / self.C) for i in range(self.number_groups)]
+
+    def __constant_denorm(self, data):
+        return [(data[i] * self.C) for i in range(self.number_groups)]
 
     def get_normalized_data(self, data:list):
         assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
-        return [(data[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(self.number_groups)]
+        return self.__norm(data)
     
-    def get_denormalized_data(self, normalized_data:list):
-        assert len(normalized_data) == self.number_groups, f'normalized_data parameter needs same length as there are groups in the dataset ({self.number_groups})'
-        return [(self.__mins[i] + (self.__maxs[i] - self.__mins[i]) * normalized_data[i]) for i in range(self.number_groups)]
+    def get_denormalized_data(self, data:list):
+        assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
+        return self.__denorm(data)
 
     def get_group(self, name:str):
         return self.__groups[self.__group_dict[name]]
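
A minimal usage sketch of the new norm selection (illustrative data; assumes the elided parts of the constructor still accept NumPy arrays, as in germany_training.py):

    import numpy as np
    from src.dataset import PandemicDataset, Norms

    t = np.arange(0.0, 10.0)
    infections = np.linspace(100.0, 1000.0, 10)
    # CONSTANT divides every group by C; POPULATION divides by N;
    # MIN_MAX reproduces the previous behaviour.
    ds = PandemicDataset('demo', ['I'], 83_100_000, t, infections,
                         norm_name=Norms.CONSTANT, C=10**6, use_scaled_time=True)
    normed = ds.get_normalized_data([ds.get_group('I')])
    restored = ds.get_denormalized_data(normed)  # round-trips back to the raw counts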

+ 129 - 32
src/dinn.py

@@ -1,10 +1,33 @@
 import torch
+import csv
 import numpy as np
 
+from enum import Enum
+
 from .dataset import PandemicDataset
 from .problem import PandemicProblem
 from .plotter import Plotter
 
+class Optimizer(Enum):
+    ADAM=0
+
+class Scheduler(Enum):
+    CYCLIC=0
+    CONSTANT=1
+    LINEAR=2
+    POLYNOMIAL=3
+
+class Activation(Enum):
+    LINEAR=0
+    POWER=1
+
+def linear(x):
+    return x
+        
+def power(x):
+    return torch.float_power(x, 2)
+
+
 class DINN:
     class NN(torch.nn.Module):
         def __init__(self, 
@@ -12,9 +35,12 @@ class DINN:
                      input_size: int,
                      hidden_size: int,
                      hidden_layers: int, 
-                     activation_layer, 
+                     activation_layer,
                      t_init,
-                     t_final) -> None:
+                     t_final,
+                     output_activation_function=Activation.LINEAR,
+                     use_glorot_initialization = False,
+                     use_t_scaled=True) -> None:
             """Neural Network
 
             Args:
@@ -26,21 +52,39 @@ class DINN:
             """
             super(DINN.NN, self).__init__()
 
+            if output_activation_function == Activation.LINEAR:
+                self.out_activation = linear
+            elif output_activation_function == Activation.POWER:
+                self.out_activation = power
+            else:
+                print('Set output activation to default: linear')
+                self.out_activation = linear
+
             self.input = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size), activation_layer)
             self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
             self.output = torch.nn.Linear(hidden_size, output_size)
+
+            if use_glorot_initialization:
+                torch.nn.init.xavier_uniform_(self.input[0].weight)
+                for i in range(hidden_layers):
+                    torch.nn.init.xavier_uniform_(self.hidden[i][0].weight)
+                torch.nn.init.xavier_uniform_(self.output.weight)
             
             self.__t_init = t_init
             self.__t_final = t_final
+            self.__use_t_scaled = use_t_scaled
 
         def forward(self, t):
             # normalize input
-            t_scaled = (t - self.__t_init) / (self.__t_final - self.__t_init)
-            x = self.input(t_scaled)
+            if self.__use_t_scaled:
+                t_forward = (t - self.__t_init) / (self.__t_final - self.__t_init)
+            else:
+                t_forward = t
+            x = self.input(t_forward)
             x = self.hidden(x)
             x = self.output(x)
-            return x
-
+            return self.out_activation(x)
+    
     def __init__(self,
                  output_size: int,
                  data: PandemicDataset,
@@ -52,7 +96,9 @@ class DINN:
                  input_size=1, 
                  hidden_size=20, 
                  hidden_layers=7, 
-                 activation_layer=torch.nn.ReLU()) -> None:
+                 activation_layer=torch.nn.ReLU(),
+                 activation_output=Activation.LINEAR,
+                 use_glorot_initialization = False) -> None:
        """Disease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve inverse problems and find the 
         parameters of a specific mathematical model.
 
@@ -78,9 +124,12 @@ class DINN:
                              input_size, 
                              hidden_size, 
                              hidden_layers, 
-                             activation_layer, 
+                             activation_layer,
                              data.t_init, 
-                             data.t_final)
+                             data.t_final,
+                             activation_output,
+                             use_glorot_initialization=use_glorot_initialization,
+                             use_t_scaled=data.use_scaled_time)
         self.model = self.model.to(self.device)
         self.data = data
         self.parameter_regulator = parameter_regulator
@@ -131,8 +180,21 @@ class DINN:
             list: list of regulated parameters
         """
         return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]
+
     
-    def configure_training(self, lr:float, epochs:int, optimizer_name='Adam', scheduler_name='CyclicLR', scheduler_factor = 1, verbose=False):
+    def get_output(self, index):
+        output = self.model(self.data.t_batch)
+        return output[:, index]
+    
+    def configure_training(self, 
+                           lr:float, 
+                           epochs:int, 
+                           optimizer_class=Optimizer.ADAM, 
+                           scheduler_class=Scheduler.CYCLIC, 
+                           scheduler_factor = 1, 
+                           lambda_obs = 1,
+                           lambda_physics = 1,
+                           verbose=False):
         """This method sets the optimizer, scheduler, learning rate and number of epochs for the following training process.
 
         Args:
@@ -144,36 +206,38 @@ class DINN:
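+            lambda_obs (float, optional): Weight applied to the observation loss. Defaults to 1.
+            lambda_physics (float, optional): Weight applied to the physics residual loss. Defaults to 1.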
         """
         parameter_list = list(self.model.parameters()) + list(self.parameters_tilda.values())
         self.epochs = epochs
-        match optimizer_name:
-            case 'Adam':
+        self.lambda_obs = lambda_obs
+        self.lambda_physics = lambda_physics
+        match optimizer_class:
+            case Optimizer.ADAM:
                 self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
             case _:
                 self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
                 if verbose:
                     print('---------------------------------')
-                    print(f' Entered unknown optimizer name: {optimizer_name}\n Defaulted to Adam.')
+                    print(f' Entered unknown optimizer name: {optimizer_class.name}\n Defaulted to ADAM.')
                     print('---------------------------------')
-                optimizer_name = 'Adam'
+                optimizer_class = Optimizer.ADAM
 
-        match scheduler_name:
-            case 'CyclicLR':
+        match scheduler_class:
+            case Scheduler.CYCLIC:
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
-            case 'ConstantLR':
+            case Scheduler.CONSTANT:
                 self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1, total_iters=4)
-            case 'LinearLR':
+            case Scheduler.LINEAR:
                 self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs/scheduler_factor)
-            case 'PolynomialLR':
+            case Scheduler.POLYNOMIAL:
                 self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs/scheduler_factor, power=1.0)
             case _:
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
                 if verbose:
                     print('---------------------------------')
-                    print(f' Entered unknown scheduler name: {scheduler_name}\n Defaulted to CyclicLR.')
+                    print(f' Entered unknown scheduler name: {scheduler_class.name}\n Defaulted to CYCLIC.')
                     print('---------------------------------')
-                scheduler_name = 'CyclicLR'
+                scheduler_class = Scheduler.CYCLIC
 
         if verbose:
-            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_name}\nScheduler:\t{scheduler_name}\n')
+            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_class.name}\nScheduler:\t{scheduler_class.name}\n')
 
         self.__is_configured = True
 
@@ -181,7 +245,9 @@ class DINN:
     def train(self, 
               create_animation=False,
               animation_sample_rate=500,
-              verbose=False):
+              verbose=False,
+              do_split_training=False,
+              start_split=10000):
         """Training routine for the DINN.
 
         Args:
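+            do_split_training (bool, optional): If True, optimize the observation loss alone until start_split, then add the physics loss. Defaults to False.
+            start_split (int, optional): Epoch at which the physics residual is switched on. Defaults to 10000.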
@@ -203,20 +269,27 @@ class DINN:
             # get the prediction and the fitting residuals
             prediction = self.model(self.data.t_batch)
             residuals = self.problem.residual(prediction, *self.get_regulated_param_list())
-
             self.optimizer.zero_grad()
 
             # calculate loss from the differential system
             loss_physics = 0
             for residual in residuals:
                 loss_physics += torch.mean(torch.square(residual))
+            loss_physics *= self.lambda_physics
 
             # calculate loss from the dataset
             loss_obs = 0
             for i, group in enumerate(self.data.group_names):
                 loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
+            loss_obs *= self.lambda_obs
             
-            loss = loss_obs + loss_physics
+            if do_split_training:
+                if epoch < start_split:
+                    loss = loss_obs
+                else:
+                    loss = loss_obs + loss_physics
+            else:
+                loss = loss_obs + loss_physics
 
             loss.backward()
             self.optimizer.step()
@@ -291,7 +364,8 @@ class DINN:
                                    np.ones_like(epochs) * ground_truth[i]], 
                                    ['prediction', 'ground truth'], 
                                    self.data.name + '_' + list(self.parameters_tilda.items())[i][0], 
-                                   list(self.parameters_tilda.items())[i][0], (6,6), 
+                                   list(self.parameters_tilda.items())[i][0], 
+                                   (6,6), 
                                    is_background=[0, 1], 
                                    xlabel='epochs')
             else:
@@ -302,19 +376,42 @@ class DINN:
                                   list(self.parameters_tilda.items())[i][0], (6,6), 
                                   xlabel='epochs', 
                                   plot_legend=False)
+                
+    def save_training_process(self, title, save_predictions = True):
+        losses = {'loss' : self.losses,
+                  'obs_loss' : self.obs_losses,
+                  'physics_loss' : self.physics_losses}
+        for loss in losses.keys():
+            with open(f'./results/training_metrics/{title}_{loss}.csv', 'w', newline='') as csvfile:
+                writer = csv.writer(csvfile, delimiter=',')
+                writer.writerow(losses[loss])
+
+        for i, parameter in enumerate(self.parameters):
+            with open(f'./results/training_metrics/{title}_{list(self.parameters_tilda.items())[i][0]}.csv', 'w', newline='') as csvfile:
+                writer = csv.writer(csvfile, delimiter=',')
+                writer.writerow(parameter)
+        if save_predictions:
+            prediction = self.model(self.data.t_batch)
+            for i, group in enumerate(self.data.group_names):
+                t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0]).detach().cpu().numpy()
+                true = self.data.get_group(group).detach().cpu().numpy()
+                pred = self.data.get_denormalized_data([prediction[:, i]])[0].detach().cpu().numpy()
+                print(t.shape, true.shape)
+                with open(f'./results/I_predictions/{title}_I_prediction.csv', 'w', newline='') as csvfile:
+                    writer = csv.writer(csvfile, delimiter=',')
+                    writer.writerow(t)
+                    writer.writerow(true)
+                    writer.writerow(pred)
 
     def plot_state_variables(self):
         prediction = self.model(self.data.t_batch)
-        groups = [prediction[:, i] for i in range(self.data.number_groups)]
-        fore_background = [0] + [1 for _ in groups]
         for i in range(self.data.number_groups, self.data.number_groups+self.number_state_variables):
-            t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+            t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
             self.plotter.plot(t,
-                              [prediction[:, i]] + groups,
-                              [self.__state_variables[i-self.data.number_groups]] + self.data.group_names,
+                              [prediction[:, i]],
+                              [self.__state_variables[i-self.data.number_groups]],
                               f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
                               self.__state_variables[i-self.data.number_groups],
-                              is_background=fore_background,
                               figure_shape=(12, 6),
                               plot_legend=True,
                               xlabel='time / days')
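
For reference, a condensed restatement (not part of the diff) of how the new lambda weights and split training combine inside DINN.train:

    def combined_loss(epoch, loss_obs, loss_physics,
                      lambda_obs=1.0, lambda_physics=1.0,
                      do_split_training=False, start_split=10000):
        # Both terms are pre-scaled by their lambda weights, as in DINN.train.
        obs = lambda_obs * loss_obs
        physics = lambda_physics * loss_physics
        if do_split_training and epoch < start_split:
            return obs            # warm-up phase: fit the data only
        return obs + physics      # afterwards: data fit plus ODE residual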

+ 224 - 17
src/plotter.py

@@ -1,9 +1,12 @@
 import os
 import torch
 import imageio
+import numpy as np
 import matplotlib.pyplot as plt
+import matplotlib.ticker as ticker 
 
 from matplotlib import rcParams
+from itertools import cycle
 
 FRAME_DIR = 'visualizations/temp/'
 VISUALISATION_DIR = 'visualizations/'
@@ -13,7 +16,7 @@ INFECTIOUS_COLOR = '#f56262'
 REMOVED_COLOR = '#83eb5e'
 
 class Plotter:
-    def __init__(self, additional_colors=[], font_size=12, font='Comfortaa', font_color='#595959') -> None:
+    def __init__(self, additional_colors=[], font_size=20, font='serif', font_color='#000000') -> None:
         """Plotter of scientific plots and animations, for dinn.py.
 
         Args:
@@ -24,14 +27,26 @@ class Plotter:
         """
         self.__colors = [SUSCEPTIBLE_COLOR, INFECTIOUS_COLOR, REMOVED_COLOR] + additional_colors
 
+        self.__lines_styles = ['solid', 'dotted', 'dashdot', (5, (10, 3)), (0, (5, 1)), (0, (3, 5, 1, 5)), (0, (3, 1, 1, 1)), (0, (3, 5, 1, 5, 1, 5)), (0, (3, 1, 1, 1, 1, 1))]
+
+        self.__marker_styles = ['o', '^', 's']
+
+        rcParams['text.usetex'] = True
+        rcParams['text.color'] = font_color
+
         rcParams['font.family'] = font
         rcParams['font.size'] = font_size
 
-        rcParams['text.color'] = font_color
+        rcParams['pgf.texsystem'] = 'pdflatex'
+        rcParams['pgf.rcfonts'] = False
+
         rcParams['axes.labelcolor'] = font_color
         rcParams['xtick.color'] = font_color
         rcParams['ytick.color'] = font_color
 
+        plt.rc("text", usetex=True)
+        plt.rc("text.latex", preamble=r"\usepackage{amssymb} \usepackage{wasysym}")
+
         self.__frames = []
 
     def __generate_figure(self, shape=(4, 4)):
@@ -43,7 +58,7 @@ class Plotter:
         Returns:
             Figure: plt.Figure that was generated.
         """
-        fig = plt.figure(figsize=shape)
+        fig = plt.figure(figsize=shape, constrained_layout=True)
         return fig
 
     def reset_animation(self):
@@ -51,6 +66,7 @@ class Plotter:
         """
         self.__frames = []
     
+    #TODO comments
     def plot(self, 
              x, 
              y:list, 
@@ -58,16 +74,19 @@ class Plotter:
              file_name:str,
              title:str, 
              figure_shape:tuple, 
+             event_lookup={},
              is_frame=False, 
              is_background=[], 
+             fill_between=[],
              plot_legend=True, 
              y_log_scale=False, 
              lw=3, 
              legend_loc='best', 
              ylim=(None, None),
+             number_xlabels = 5,
              xlabel='',
              ylabel='', 
-             xlabel_rotation=0):
+             xlabel_rotation=None):
         """Plotting method.
 
         Args:
@@ -87,20 +106,26 @@ class Plotter:
             xlabel (str, optional): Label for the x axis. Defaults to ''.
             ylabel (str, optional): Label for the y axis. Defaults to ''.
         """
-        assert len(y) == len(labels), "There must be the same amount of labels as there are plots."
+        assert len(y) == len(labels), f"There must be the same amount of labels as there are plots.\nNumber Plots: {len(y)}\nNumber Labels: {len(labels)}"
         assert len(is_background) == 0 or len(y) == len(is_background), "If given the back_foreground list must have the same length as labels has."
         fig = self.__generate_figure(shape=figure_shape)
 
         ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
         ax.set_facecolor('xkcd:white')
-        ax.yaxis.set_tick_params(length=0, which='both')
-        ax.xaxis.set_tick_params(length=0, which='both')
-        ax.grid(which='major', c='black', lw=0.2, ls='-')
+        #ax.yaxis.set_tick_params(length=0, which='both')
+        #ax.xaxis.set_tick_params(length=0, which='both')
+
+        #ax.grid(which='major', c='black', lw=0.2, ls='-')
         ax.set_title(title)
 
-        for spine in ('top', 'right', 'bottom', 'left'):
-            ax.spines[spine].set_visible(False)
+        #for spine in ('top', 'right', 'bottom', 'left'):
+         #   ax.spines[spine].set_visible(False)
+
+        if torch.is_tensor(x):
+            x = x.cpu().detach().numpy()
 
+        linecycler = cycle(self.__lines_styles)
+        j = 0
         for i, array in enumerate(y):
             alpha = 1
             if len(is_background) != 0:
@@ -111,12 +136,19 @@ class Plotter:
                 data = array.cpu().detach().numpy()
             else:
                 data = array
-            
-            if len(is_background) != 0 and alpha == 1:
-                ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, linestyle='dashed', c=self.__colors[i % len(self.__colors)])
-            else:
-                ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, c=self.__colors[i % len(self.__colors)])
 
+            space = int(len(x) / number_xlabels)
+            ax.xaxis.set_major_locator(ticker.MultipleLocator(space)) 
+        
+            ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, linestyle=next(linecycler), c=self.__colors[i % len(self.__colors)])
+            if i < len(fill_between):
+                ax.fill_between(x, data+fill_between[i], data-fill_between[i], facecolor=self.__colors[i % len(self.__colors)], alpha=0.5)
+            j = i
+
+        for event in event_lookup.keys():
+            j += 1
+            plt.axvline(x = event_lookup[event], color = self.__colors[j % len(self.__colors)], label = event, ls=next(linecycler))
+            
         if plot_legend:
             plt.legend(loc=legend_loc)
 
@@ -132,7 +164,8 @@ class Plotter:
         if ylabel != '':
             plt.ylabel(ylabel)
 
-        plt.xticks(rotation=xlabel_rotation)
+        if xlabel_rotation != None:
+            plt.xticks(rotation=xlabel_rotation)
  
         if not os.path.exists(FRAME_DIR):
             os.makedirs(FRAME_DIR)
@@ -144,7 +177,181 @@ class Plotter:
             self.__frames.append(imageio.imread(frame_path))
             os.remove(frame_path)
         else:
-            plt.savefig(VISUALISATION_DIR + f'{file_name}.png')
+            plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
+
+        plt.close()
+
+    def cluster_plot(self, 
+                     x, 
+                     y:list, 
+                     labels, 
+                     shape, 
+                     plots_shape, 
+                     file_name:str, 
+                     titles:list, 
+                     number_xlabels=5,
+                     lw=3, 
+                     fill_between=[],
+                     event_lookup={},
+                     xlabel='',
+                     ylabel='',
+                     ylim=(None, None),
+                     y_lim_exception=None,
+                     y_log_scale=False, 
+                     legend_loc=(0.5,0.992),
+                     add_y_space=0.05,
+                     number_of_legend_columns=1,
+                     same_axes=True,
+                     free_axis=(None, None),
+                     plot_all_labels = True):
+        real_shape = (shape[1] * plots_shape[0], shape[0] * plots_shape[1])
+        fig, axes = plt.subplots(*shape, figsize=real_shape, sharex=same_axes, sharey=same_axes)
+        plot_idx = 0
+
+        if torch.is_tensor(x):
+            x = x.cpu().detach().numpy()
+
+        if 1 in shape:
+            for i in range(len(axes)):
+                linecycler = cycle(self.__lines_styles)
+                colorcycler = cycle(self.__colors)
+                for j, array in enumerate(y[i]):
+                    color = next(colorcycler)
+                    if torch.is_tensor(array):
+                        data = array.cpu().detach().numpy()
+                    else:
+                        data = array
+
+                    space = int(len(x) / number_xlabels)
+                    axes[i].xaxis.set_major_locator(ticker.MultipleLocator(space))
+                    axes[i].plot(x, 
+                                 data, 
+                                 linestyle=next(linecycler),
+                                 label=labels[j],
+                                 c=color,
+                                 lw=lw)
+                    axes[i].set_title(titles[i])
+                    if i < len(fill_between) and j < len(fill_between[i]):
+                        axes[i].fill_between(x, data+fill_between[i][j], data-fill_between[i][j], facecolor=color, alpha=0.5)
+                for event in event_lookup.keys():
+                    axes[i].axvline(x=event_lookup[event], 
+                                    color=next(colorcycler), 
+                                    label=event, 
+                                    ls=next(linecycler),
+                                    lw=lw)
+            
+                if ylim[0] != None and y_lim_exception != i:
+                    axes[i].set_ylim(ylim)
+                
+                if y_log_scale:
+                    plt.yscale('log')
+        else:
+            for i in range(shape[0]):
+                for j in range(shape[1]):
+                    if (i, j) == free_axis:
+                        axes[i, j].axis('off')
+                    else:
+                        linecycler = cycle(self.__lines_styles)
+                        colorcycler = cycle(self.__colors)
+                        if plot_idx < len(y):
+                            for k, array in enumerate(y[plot_idx]):
+                                if torch.is_tensor(array):
+                                    data = array.cpu().detach().numpy()
+                                else:
+                                    data = array
+                                space = int(len(x) / number_xlabels)
+                                axes[i, j].xaxis.set_major_locator(ticker.MultipleLocator(space))
+                                if len(x) > len(data):
+                                    c = len(data)
+                                else:
+                                    c = len(x)
+                                axes[i, j].plot(x[:c], 
+                                                data, 
+                                                label=labels[k], 
+                                                c=next(colorcycler), 
+                                                lw=lw, 
+                                                linestyle=next(linecycler))
+                            axes[i, j].set_title(titles[plot_idx])
+                            if ylim[0] != None:
+                                axes[i, j].set_ylim(ylim)
+                        plot_idx += 1
+                    if y_log_scale:
+                        plt.yscale('log')
+        if 1 in shape:
+            lines, labels = axes[0].get_legend_handles_labels()
+        else:
+            lines, labels = axes[0, 0].get_legend_handles_labels()
+        fig.legend(lines, labels, loc='upper center', ncol=number_of_legend_columns, bbox_to_anchor=legend_loc)
+
+        for ax in axes.flat:
+            ax.set(xlabel=xlabel, ylabel=ylabel)
+
+        # Hide x labels and tick labels for top plots and y ticks for right plots.
+        if plot_all_labels:
+            for ax in axes.flat:
+                ax.label_outer()
+
+        # Adjust layout to prevent overlap
+        plt.tight_layout(rect=[0, 0, 1, 1-add_y_space])
+        plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
+
+    def scatter(self, 
+                x, 
+                y:list, 
+                labels:list, 
+                figure_shape:tuple, 
+                file_name:str, 
+                title:str, 
+                std=[], 
+                true_values=[],
+                true_label='true',
+                plot_legend=True, 
+                legend_loc='best',
+                xlabel='',
+                ylabel='', 
+                xlabel_rotation=None):
+        assert len(y) == len(labels), f"There must be the same amount of labels as there are plots.\nNumber Plots: {len(y)}\nNumber Labels: {len(labels)}"
+        fig = self.__generate_figure(shape=figure_shape)
+
+        ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
+        ax.set_facecolor('xkcd:white')
+        ax.set_title(title)
+
+        if torch.is_tensor(x):
+            x = x.cpu().detach().numpy()
+
+        markercycler = cycle(self.__marker_styles)
+        for i, array in enumerate(y):
+            
+            if torch.is_tensor(array):
+                data = array.cpu().detach().numpy()
+            else:
+                data = array
+
+            if i >= len(std):
+                ax.scatter(x, data, label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
+            if i < len(std):
+                ax.errorbar(x, data, std[i], label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
+        
+        linecycler = cycle(self.__lines_styles)
+        for i, true_value in enumerate(true_values):
+            ax.plot(x, np.ones_like(x, dtype='float64') * true_value, label=f'{true_label} {labels[i]}',c=self.__colors[i % len(self.__colors)], ls=next(linecycler))
+            
+        if plot_legend:
+            plt.legend(loc=legend_loc)
+
+        if xlabel != '':
+            plt.xlabel(xlabel, )
+
+        if ylabel != '':
+            plt.ylabel(ylabel)
+
+        if xlabel_rotation != None:
+            plt.xticks(rotation=xlabel_rotation, ha='right')
+
+        plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
+
+        plt.close()
 
     def animate(self, name: str):
         """Builds animation from images saved in self.frames. Then saves animation as gif.

+ 55 - 80
src/preprocessing/synthetic_data.py

@@ -29,7 +29,7 @@ class SyntheticDeseaseData:
         """
         self.generated = True
 
-    def plot(self, labels: tuple, title:str):
+    def plot(self, labels: tuple, title:str, file_name:str):
         """Plot the data which was generated.
 
         Args:
@@ -38,31 +38,31 @@ class SyntheticDeseaseData:
         """
         assert len(labels) == len(self.data), 'The number of labels needs to be the same as the number of plots.'
         if self.generated:
-            self.plotter.plot(self.t, self.data, labels, title, title, (6, 6), xlabel='time / days', ylabel='amount of people')
+            self.plotter.plot(self.t, self.data, labels, file_name, title, (6, 6), xlabel='time / days', ylabel='amount of people')
         else: 
             print('Data has to be generated before plotting!')
 
-class SI(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N=59e6, I_0=1, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
-        """This class is able to generate synthetic data of the SI groups for the reduced SIR model. This is done by utiling the SIR model.
+class SIR(SyntheticDeseaseData):
+    def __init__(self, plotter:Plotter, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
+        """This class is able to generate synthetic data for the SIR model.
 
         Args:
             plotter (Plotter): Plotter object to plot dataset curves.
             N (int, optional): Size of the population. Defaults to 59e6.
             I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
+            R_0 (int, optional): Initial size of the removed group. Defaults to 0.
             simulation_time (int, optional): Real time for that the synthetic data is supposed to be generated in days. Defaults to 500.
             time_points (int, optional): Number of time sample points. Defaults to 100.
             alpha (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Removed'. Defaults to 0.191.
             beta (float, optional): Factor dictating how many people per timestep go from 'Susceptible' to 'Infectious'. Defaults to 0.05.
         """
-
         self.N = N
-        self.S_0 = N - I_0
+        self.S_0 = N - I_0 - R_0
         self.I_0 = I_0
+        self.R_0 = R_0
 
         self.alpha = alpha
         self.beta = beta
-
         super().__init__(simulation_time, time_points, plotter)
 
     def differential_eq(self, y, t, alpha, beta):
@@ -77,33 +77,34 @@ class SI(SyntheticDeseaseData):
         Returns:
             tuple: Change amount for each group.
         """
-        S, I = y
-        dSdt = -self.beta * ((S * I) / self.N)
-        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I
-        return dSdt, dIdt
-    
+        S, I, _ = y
+        dSdt = -self.beta * ((S * I) / self.N) # -self.beta * S * I
+        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I # self.beta * S * I - self.alpha * I
+        dRdt = self.alpha * I
+        return dSdt, dIdt, dRdt
+
     def generate(self):
         """This funtion generates the data for this configuration of the SIR model.
         """
-        y_0 = self.S_0, self.I_0
+        y_0 = self.S_0, self.I_0, self.R_0
         self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
         super().generate()
 
-    def plot(self, title=''):
+    def plot(self, title='', file_name='SIR_plot'):
         """Plot the data which was generated.
         """
-        super().plot(('Susceptible', 'Infectious'), title=title)
+        super().plot(('Susceptible', 'Infectious', 'Removed'), title=title, file_name=file_name)
 
     def save(self, name=''):
         if self.generated:
             COVID_Data = np.asarray([self.t, *self.data]) 
 
-            np.savetxt('datasets/SI_data.csv', COVID_Data, delimiter=",")
+            np.savetxt('datasets/SIR_data.csv', COVID_Data, delimiter=",")
         else: 
             print('Data has to be generated before plotting!')
 
-class I(SI):
-    def __init__(self, plotter:Plotter, N=59e6, I_0=1, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
+class I(SyntheticDeseaseData):
+    def __init__(self, plotter:Plotter, N:int, C:int, I_0=1, time_points=100, alpha=1/3) -> None:
         """This class is able to generate synthetic data of the SI groups for the reduced SIR model. This is done by utiling the SIR model.
 
         Args:
@@ -115,94 +116,68 @@ class I(SI):
             alpha (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Removed'. Defaults to 1/3.
         """
-
-        super().__init__(plotter, N=N, I_0=I_0, simulation_time=simulation_time, time_points=time_points, alpha=alpha, beta=beta)
-
-    def generate(self):
-        """This funtion generates the data for this configuration of the SIR model.
-        """
-        super().generate()
-        self.data = self.data[1]
-        print(self.data.shape)
-
-    def plot(self, title=''):
-        """Plot the data which was generated.
-        """
-        if self.generated:
-            self.plotter.plot(self.t, [self.data], ['Infectious'], title, title, (6, 6), xlabel='time / days', ylabel='amount of people')
-        else: 
-            print('Data has to be generated before plotting!')
-
-    def save(self, name=''):
-        if self.generated:
-            COVID_Data = np.asarray([self.t, self.data]) 
-
-            np.savetxt('datasets/I_data.csv', COVID_Data, delimiter=",")
-        else: 
-            print('Data has to be generated before plotting!')
-
-
-class SIR(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
-        """This class is able to generate synthetic data for the SIR model.
-
-        Args:
-            plotter (Plotter): Plotter object to plot dataset curves.
-            N (int, optional): Size of the population. Defaults to 59e6.
-            I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
-            R_0 (int, optional): Initial size of the removed group. Defaults to 0.
-            simulation_time (int, optional): Real time for that the synthetic data is supposed to be generated in days. Defaults to 500.
-            time_points (int, optional): Number of time sample points. Defaults to 100.
-            alpha (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Removed'. Defaults to 0.191.
-            beta (float, optional): Factor dictating how many people per timestep go from 'Susceptible' to 'Infectious'. Defaults to 0.05.
-        """
         self.N = N
-        self.S_0 = N - I_0 - R_0
+        self.C = C
         self.I_0 = I_0
-        self.R_0 = R_0
-
+
         self.alpha = alpha
-        self.beta = beta
 
-        super().__init__(simulation_time, time_points, plotter)
+        self.t = np.linspace(0, 1, time_points)
+        self.t_save = np.linspace(1, time_points, time_points)
+        self.t_f = time_points
+        self.reproduction_value = []
+        self.data = None
+        self.generated = False
+        self.plotter = plotter
+        
+    def R_t(self, t):
+        descaled_t = t * self.t_f
+        # smooth tanh step: R_t falls from about 1.74 at day 0 towards 0.95, crossing 1 near day 67
+        return -np.tanh(descaled_t * 0.05 - 2) * 0.4 + 1.35
 
-    def differential_eq(self, y, t, alpha, beta):
+
+    def differential_eq(self, I, t):
         """In this function implements the differential equation of the SIR model will be implemented.
 
         Args:
             I (float): Current size of the infectious group.
             t (float): Current point in scaled time.
-            alpha (_): not used
-            beta (_): not used
 
         Returns:
             float: Change of the infectious group.
         """
-        S, I, R = y
-        dSdt = -self.beta * ((S * I) / self.N) # -self.beta * S * I
-        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I # self.beta * S * I - self.alpha * I
-        dRdt = self.alpha * I
-        return dSdt, dIdt, dRdt
+        dIdt = self.alpha * self.t_f * (self.R_t(t) - 1) * I
+        return dIdt
 
     def generate(self):
         """This funtion generates the data for this configuration of the SIR model.
         """
-        y_0 = self.S_0, self.I_0, self.R_0
-        self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
-        super().generate()
+        self.data = odeint(self.differential_eq, self.I_0/self.C, self.t).T
+        self.data = self.data[0] * self.C
+        self.t_counter = 0
+        self.generated = True
 
-    def plot(self, title=''):
+    def plot(self, title='', file_name=''):
         """Plot the data which was generated.
         """
-        super().plot(('Susceptible', 'Infectious', 'Removed'), title=title)
+        if self.generated:
+            t = np.linspace(0, len(self.t), len(self.t))
+            self.plotter.plot(t, [self.data], ['Infectious'], file_name, title, (6, 6), xlabel='time / days', ylabel='amount of people')
+            for time in self.t:
+                self.reproduction_value.append(self.R_t(time))
+            self.plotter.plot(t, [np.array(self.reproduction_value)], [r'$\mathcal{R}_t$'], file_name + '_r_t', title + r' $\mathcal{R}_t$', (6, 6), xlabel='time / days')
+        else: 
+            print('Data has to be generated before plotting!')
 
     def save(self, name=''):
         if self.generated:
-            COVID_Data = np.asarray([self.t, *self.data]) 
+            COVID_Data = np.asarray([self.t_save, self.data]) 
 
-            np.savetxt('datasets/SIR_data.csv', COVID_Data, delimiter=",")
+            np.savetxt('datasets/I_data.csv', COVID_Data, delimiter=",")
         else: 
             print('Data has to be generated before saving!')
+
         
 
 class SIDR(SyntheticDeseaseData):

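The new I class integrates the reduced SIR equation dI/dt = alpha * t_f * (R_t(t) - 1) * I on scaled time t in [0, 1]; the factor t_f compensates for the change of variables from days to unit time. A minimal self-contained sketch of the same integration, assuming alpha = 1/3, 150 time points, and the tanh-shaped R_t from above (all values illustrative, not taken from the repository):

import numpy as np
from scipy.integrate import odeint

ALPHA, T_F, C = 1/3, 150, 10**6      # recovery rate, days, scaling constant (assumed)
t = np.linspace(0, 1, T_F)           # scaled time, as in the I class

def R_t(t_scaled):
    day = t_scaled * T_F
    return -np.tanh(day * 0.05 - 2) * 0.4 + 1.35

def dIdt(I, t_scaled):
    # d/dt_scaled = T_F * d/dt_day, hence the extra factor T_F
    return ALPHA * T_F * (R_t(t_scaled) - 1) * I

I = odeint(dIdt, 1 / C, t).T[0] * C  # integrate I_0 / C, then scale back up

The same (t_final - t_init) rescaling reappears in ReducedSIRProblem.residual further below.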
+ 88 - 138
src/preprocessing/transform_data.py

@@ -3,7 +3,24 @@ import pandas as pd
 
 from src.plotter import Plotter
 
-def transform_general_to_SIR(plotter:Plotter, dataset_path='datasets/COVID-19-Todesfaelle_in_Deutschland/', plot_name='', plot_title='', sample_rate=1, exclude=[], plot_size=(12,6), yscale_log=False, plot_legend=True):
+state_lookup = {'Schleswig Holstein' : (1, 2897000),
+                'Hamburg' : (2, 1841000), 
+                'Niedersachsen' : (3, 7982000), 
+                'Bremen' : (4, 569352),
+                'Nordrhein-Westfalen' : (5, 17930000),
+                'Hessen' : (6, 6266000),
+                'Rheinland-Pfalz' : (7, 4085000),
+                'Baden-Württemberg' : (8, 11070000),
+                'Bayern' : (9, 13080000),
+                'Saarland' : (10, 990509),
+                'Berlin' : (11, 3645000),
+                'Brandenburg' : (12, 2641000),
+                'Mecklenburg-Vorpommern' : (13, 1610000),
+                'Sachsen' : (14, 4078000),
+                'Sachsen-Anhalt' : (15, 2208000),
+                'Thüringen' : (16, 2143000)}
+
+def transform_data(plotter:Plotter, alpha=1/14, state_name='Germany', time_range=1200, plot_name='', plot_title='', sample_rate=1, model='SIR', plot_size=(12,6), yscale_log=False, plot_legend=True):
     """Function to generate the SIR split from the data in the COVID-19-Todesfaelle_in_Deutschland dataset.
 
     Args:
@@ -18,147 +35,80 @@ def transform_general_to_SIR(plotter:Plotter, dataset_path='datasets/COVID-19-To
         plot_legend (bool, optional): Controls if the legend is to be plotted. Defaults to True.
     """
     # prepare the daily case arrays and read the raw data
-    df = pd.read_csv(dataset_path + 'COVID-19-Todesfaelle_Deutschland.csv')
-
-    df = df.drop(df.index[1200:])
-    
-    # population of germany at the end of 2019
-    N = 83100000
-    S, I, R = np.zeros(df.shape[0]), np.zeros(df.shape[0]), np.zeros(df.shape[0])
-
-    # S_0 = N - I_0
-    S[0] = N - df['Faelle_gesamt'][0]
-    # I_0 = overall cases at the day - overall death cases at the day
-    I[0] = df['Faelle_gesamt'][0] - df['Todesfaelle_gesamt'][0]
-    # R_0 = overall death cases at the day
-    R[0] = df['Todesfaelle_gesamt'][0]
-
-    # the recovery time is 14 days
-    recovery_queue = np.zeros(14)
-    
-    for day in range(1, df.shape[0]):
-        infections = df['Faelle_gesamt'][day] - df['Faelle_gesamt'][day-1]
-        deaths = df['Todesfaelle_neu'][day]
-        recoveries = recovery_queue[0]
-
-        S[day] = S[day-1] - infections
-        I[day] = I[day-1] + infections - deaths - recoveries
-        R[day] = R[day-1] + deaths + recoveries
-
-        # update recovery queue
-        if I[day] < 0:
-            recovery_queue[-1] -= I[day] 
-            I[day] = 0
-
-        recovery_queue[:-1] = recovery_queue[1:]
-        recovery_queue[-1] = infections
-
-    t = np.arange(0, df.shape[0], 1)
-    if plotter != None:
-        # plot graphs
-        plots = []
-        labels = []
-
-        if 'S' not in exclude:
-            plots.append(S)
-            labels.append('S')
-        
-        if 'I' not in exclude:
-            plots.append(I)
-            labels.append('I')
-
-        if 'R' not in exclude:
-            plots.append(R)
-            labels.append('R')
-
-        plotter.plot(t, plots, labels, plot_name, plot_title, plot_size, y_log_scale=yscale_log, plot_legend=plot_legend, xlabel='time / days', ylabel='amount of poeple')
-
-    COVID_Data = np.asarray([t[0::sample_rate], 
-                             S[0::sample_rate], 
-                             I[0::sample_rate], 
-                             R[0::sample_rate]]) 
-
-    np.savetxt(f"datasets/SIR_RKI_{sample_rate}.csv", COVID_Data, delimiter=",")
-
 
 
-def get_state_cases(county_id, state_id):
-    id = county_id // 1000
-    return id == state_id
-
-def state_based_data(plotter:Plotter, state_name:str, model='SIR', alpha=1/14, time_range=1200, sample_rate=1, dataset_path='datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv'):
-    """Transforms the RKI infection cases dataset to a SIR dataset.
-
-    Args:
-        plotter (Plotter): Plotter object to plot dataset curves.
-        state_name (str): Name of the state that is to be singled out in the new dataset.
-        time_range (int, optional): Number of days that will be looked at in the new dataset. Defaults to 1200.
-        sample_rate (int, optional): Sample rate used to sample the timepoints. Defaults to 1.
-        dataset_path (str, optional): Path to the CSV file, where the data is stored. Defaults to 'datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv'.
-    """
-    df = pd.read_csv(dataset_path)
-
-    state_lookup = {'Schleswig Holstein' : (1, 2897000),
-                    'Hamburg' : (2, 1841000), 
-                    'Niedersachsen' : (3, 7982000), 
-                    'Bremen' : (4, 569352),
-                    'Nordrhein-Westfalen' : (5, 17930000),
-                    'Hessen' : (6, 6266000),
-                    'Rheinland-Pfalz' : (7, 4085000),
-                    'Baden-Württemberg' : (8, 11070000),
-                    'Bayern' : (9, 13080000),
-                    'Saarland' : (10, 990509),
-                    'Berlin' : (11, 3645000),
-                    'Brandenburg' : (12, 2641000),
-                    'Mecklenburg-Vorpommern' : (13, 1610000),
-                    'Sachsen' : (14, 4078000),
-                    'Sachsen-Anhalt' : (15, 2208000),
-                    'Thüringen' : (16, 2143000)}
-    state_ID, N = state_lookup[state_name]
-
-    # single out a state
-    state_IDs = df['IdLandkreis'] // 1000
-    state_df = df.loc[state_IDs == state_ID]
-
-    # sort entries by state
-    state_df = state_df.sort_values('Refdatum')
-    state_df = state_df.reset_index(drop=True)
-
-
-    # collect cases    
     infections = np.zeros(time_range)
-    dead = np.zeros(time_range)
-    recovered = np.zeros(time_range)
-    entry_idx = 0
-    day = 0
-    date = state_df['Refdatum'][entry_idx]
-    # check for each date all entries
-    while day < time_range:
-        # use the date sorted characteristic and take all entries with current date
-        while state_df['Refdatum'][entry_idx] == date:
-            # TODO use further parameters
-            infections[day] += state_df['AnzahlFall'][entry_idx]
-            dead[day] += state_df['AnzahlTodesfall'][entry_idx]
-            recovered[day] += state_df['AnzahlGenesen'][entry_idx]
-            entry_idx += 1
-        # move day index by difference between the current and next date
-        day += (pd.to_datetime(state_df['Refdatum'][entry_idx])-pd.to_datetime(date)).days
-        date = state_df['Refdatum'][entry_idx]
-
-    S = np.zeros(time_range)
-    I = np.zeros(time_range)
-    R = np.zeros(time_range)
-
+    deaths = np.zeros(time_range)
+    recoveries = np.zeros(time_range)
+    if state_name == 'Germany':
+        df = pd.read_csv('datasets/COVID-19-Todesfaelle_in_Deutschland/COVID-19-Todesfaelle_Deutschland.csv')
+        N = 83100000
+        infections[0] = df['Faelle_gesamt'][0]
+        deaths[0] = df['Todesfaelle_neu'][0]
+
+        recovery_queue = np.zeros(14)
+        for i in range(1, time_range):
+            infections[i] = df['Faelle_gesamt'][i] - df['Faelle_gesamt'][i-1]
+            deaths[i] = df['Todesfaelle_neu'][i]
+            recoveries[i] = recovery_queue[0]
+
+            recovery_queue[:-1] = recovery_queue[1:]
+            recovery_queue[-1] = infections[i]
+    else:
+        df = pd.read_csv('datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv')
+        state_ID, N = state_lookup[state_name]
+
+        # single out a state
+        state_IDs = df['IdLandkreis'] // 1000
+        df = df.loc[state_IDs == state_ID]
+
+        # sort entries by state
+        df = df.sort_values('Refdatum')
+        df = df.reset_index(drop=True)
+
+        # collect cases    
+        entry_idx = 0
+        day = 0
+        date = df['Refdatum'][entry_idx]
+        # check for each date all entries
+        while day < time_range:
+            # use the date sorted characteristic and take all entries with current date
+            while df['Refdatum'][entry_idx] == date:
+                infections[day] += df['AnzahlFall'][entry_idx]
+                deaths[day] += df['AnzahlTodesfall'][entry_idx]
+                entry_idx += 1
+            # move day index by difference between the current and next date
+            day += (pd.to_datetime(df['Refdatum'][entry_idx])-pd.to_datetime(date)).days
+            date = df['Refdatum'][entry_idx]
+
+        recovery_queue = np.zeros(14)
+        for i in range(1, time_range):
+            recoveries[i] = recovery_queue[0]
+
+            recovery_queue[:-1] = recovery_queue[1:]
+            recovery_queue[-1] = infections[i]
+
+    S, I, R = np.zeros(time_range), np.zeros(time_range), np.zeros(time_range)
     # generate groups
     S[0] = N - infections[0]
     I[0] = infections[0]
     R[0] = 0
-
-    for day in range(1, time_range):
-        S[day] = S[day-1] - infections[day]
-        I[day] = I[day-1] + infections[day] - I[day-1] * alpha
-        R[day] = R[day-1] + I[day-1] * alpha
-
+    if model == 'I':
+        # reduced model: recoveries leave I at the constant rate alpha
+        for day in range(1, time_range):
+            S[day] = S[day-1] - infections[day]
+            I[day] = I[day-1] + infections[day] - I[day-1] * alpha
+            R[day] = R[day-1] + I[day-1] * alpha
+    else:
+        # full SIR split: reported deaths plus the 14-day delayed recoveries
+        for day in range(1, time_range):
+            S[day] = S[day-1] - infections[day]
+            I[day] = I[day-1] + infections[day] - deaths[day] - recoveries[day]
+            R[day] = R[day-1] + deaths[day] + recoveries[day]
+            if I[day] < 0:
+                I[day] = 0
+
     t = np.arange(0, time_range, 1)
 
     # select, which group is to be outputted
@@ -175,12 +125,12 @@ def state_based_data(plotter:Plotter, state_name:str, model='SIR', alpha=1/14, t
     plotter.plot(t, 
                  groups, 
                  [*model], 
-                 state_name.replace(' ', '_').replace('-', '_').replace('ü','ue'), 
-                 state_name +' SI', 
+                 state_name.replace(' ', '_').replace('-', '_').replace('ü','ue') + f"_{model}" + f"_{int(1/alpha)}", 
+                 state_name, 
                  (6,6), 
                  xlabel='time / days', 
                  ylabel='amount of people')
 
     COVID_Data = np.asarray([t[0::sample_rate]] + [group[0::sample_rate] for group in groups]) 
 
-    np.savetxt(f"datasets/{model}_RKI_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}.csv", COVID_Data, delimiter=",")
+    np.savetxt(f"datasets/{model}_RKI_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}_{int(1/alpha)}.csv", COVID_Data, delimiter=",")

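Both data branches above (Germany as a whole and a single state) model recovery with a fixed 14-day delay queue: today's new infections enter at the back, and whatever reaches the front counts as recovered. A toy illustration with hypothetical daily counts, not RKI data:

import numpy as np

infections = np.array([0, 5, 8, 13, 2, 0, 7] + [0] * 20)  # hypothetical daily cases
recovery_queue = np.zeros(14)
recoveries = np.zeros(len(infections))

for day in range(1, len(infections)):
    recoveries[day] = recovery_queue[0]       # cases infected 14 days ago recover today
    recovery_queue[:-1] = recovery_queue[1:]  # shift the queue forward by one day
    recovery_queue[-1] = infections[day]      # today's cases will recover in 14 days

With this layout recoveries[15] == 5, i.e. the cases reported on day 1 re-emerge exactly 14 days later.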
+ 9 - 8
src/problem.py

@@ -58,17 +58,18 @@ class ReducedSIRProblem(PandemicProblem):
         super().__init__(data)
         self.alpha = alpha
 
-    def residual(self, SI_pred):
+    def residual(self, I_pred):
         super().residual()
 
-        SI_pred.backward(self._gradients[0], retain_graph=True)
-        dIdt = self._data.t_raw.grad.clone()
-        self._data.t_raw.grad.zero_()
-        
-        I = SI_pred[:, 0]
-        R_t = SI_pred[:, 1]
+        I_pred.backward(self._gradients[0], retain_graph=True)
+        dIdt = self._data.t_scaled.grad.clone()
+        self._data.t_scaled.grad.zero_()
 
-        I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
+        I = I_pred[:, 0]
+        R_t = I_pred[:, 1]
+
+        # dIdt = torch.autograd.grad(I, self._data.t_scaled, torch.ones_like(I), create_graph=True)[0]
 
+        I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
         return I_residual
 

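The rewritten residual differentiates the network's output with respect to the scaled time input and rescales by (t_final - t_init), mirroring the t_f factor in the synthetic generator. A minimal sketch of the same autograd pattern with a hypothetical stand-in network (not the DINN class itself; constants assumed):

import torch

t_scaled = torch.linspace(0, 1, 100, requires_grad=True).reshape(-1, 1)
net = torch.nn.Sequential(torch.nn.Linear(1, 100), torch.nn.Tanh(),
                          torch.nn.Linear(100, 2))

pred = net(t_scaled)                 # column 0: I, column 1: R_t
I, R_t = pred[:, 0], pred[:, 1]
dIdt = torch.autograd.grad(I, t_scaled, torch.ones_like(I),
                           create_graph=True)[0].squeeze()

alpha, t_init, t_final = 1/14, 0.0, 1200.0   # assumed constants
residual = dIdt - alpha * (t_final - t_init) * (R_t - 1) * I

This is the torch.autograd.grad variant kept as a comment in the diff; the committed code reaches the same derivative through I_pred.backward with a preset output gradient.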
Diff not shown because the file is too large
+ 52 - 1576
state_data_dinn_sir.ipynb


+ 118 - 0
states_training.py

@@ -0,0 +1,118 @@
+import torch
+import numpy as np
+import csv
+import sys
+
+from src.dataset import PandemicDataset, Norms
+from src.problem import ReducedSIRProblem
+from src.dinn import DINN, Scheduler, Activation
+
+ALPHA = [1/14, 1/5]
+DO_STATES = True
+DO_SYNTHETIC = False
+
+ITERATIONS = 13
+
+state_starting_index = 0
+
+if "1" in sys.argv:
+    state_starting_index = 8
+
+STATE_LOOKUP = {'Schleswig_Holstein' : 2897000,
+                'Hamburg' : 1841000, 
+                'Niedersachsen' : 7982000, 
+                'Bremen' : 569352,
+                'Nordrhein_Westfalen' : 17930000,
+                'Hessen' : 6266000,
+                'Rheinland_Pfalz' : 4085000,
+                'Baden_Wuerttemberg' : 11070000,
+                'Bayern' : 13080000,
+                'Saarland' : 990509,
+                'Berlin' : 3645000,
+                'Brandenburg' : 2641000,
+                'Mecklenburg_Vorpommern' : 1610000,
+                'Sachsen' : 4078000,
+                'Sachsen_Anhalt' : 2208000,
+                'Thueringen' : 2143000}
+
+if DO_SYNTHETIC:
+    alpha = 1/3
+    covid_data = np.genfromtxt(f'./datasets/I_data.csv', delimiter=',')
+    for i in range(ITERATIONS):
+        dataset = PandemicDataset('Synthetic I', 
+                                    ['I'], 
+                                    7.6e6, 
+                                    *covid_data, 
+                                    norm_name=Norms.CONSTANT, 
+                                    use_scaled_time=True)
+
+        problem = ReducedSIRProblem(dataset, alpha)
+        dinn = DINN(2, 
+                    dataset, 
+                    [], 
+                    problem, 
+                    None, 
+                    state_variables=['R_t'], 
+                    hidden_size=100, 
+                    hidden_layers=4, 
+                    activation_layer=torch.nn.Tanh(),
+                    activation_output=Activation.POWER)
+
+        dinn.configure_training(1e-3, 
+                                20000, 
+                                scheduler_class=Scheduler.POLYNOMIAL, 
+                                lambda_physics=1e-6,
+                                verbose=True)
+        dinn.train(verbose=True, do_split_training=True)
+
+        dinn.save_training_process(f'synthetic_{i}')
+        # r_t = dinn.get_output(1).detach().cpu().numpy()
+
+        # with open(f'./results/synthetic_{i}.csv', 'w', newline='') as csvfile:
+        #     writer = csv.writer(csvfile, delimiter=',')
+        #     writer.writerow(r_t)
+
+for iteration in range(ITERATIONS):
+    if iteration <= 2:
+        print('skipping the first three iterations, as they were already done')
+        continue
+    if DO_STATES:
+        for state_idx in range(state_starting_index, state_starting_index + 8):
+            state = list(STATE_LOOKUP.keys())[state_idx]
+            exclude = ['Schleswig_Holstein', 'Hamburg', 'Niedersachsen']
+            if iteration == 3 and state in exclude:
+                print(f'skipping the third iteration for {state}, as it was already done')
+                continue
+            for i, alpha in enumerate(ALPHA):
+                print(f'training for {state} ({state_idx}), alpha: {alpha}, iter: {iteration}')
+
+                covid_data = np.genfromtxt(f'./datasets/I_RKI_{state}_1_{int(1/alpha)}.csv', delimiter=',')
+                dataset = PandemicDataset(state, ['I'], STATE_LOOKUP[state], *covid_data, norm_name=Norms.CONSTANT, use_scaled_time=True)
+
+                problem = ReducedSIRProblem(dataset, alpha)
+
+                dinn = DINN(2, 
+                            dataset, 
+                            [], 
+                            problem, 
+                            None, 
+                            state_variables=['R_t'], 
+                            hidden_size=100, 
+                            hidden_layers=4, 
+                            activation_layer=torch.nn.Tanh(),
+                            activation_output=Activation.POWER)
+
+                dinn.configure_training(1e-3, 
+                                        25000, 
+                                        scheduler_class=Scheduler.POLYNOMIAL,
+                                        lambda_obs=1e2, 
+                                        lambda_physics=1e-6, 
+                                        verbose=True)
+                dinn.train(verbose=True, do_split_training=True)
+
+                dinn.save_training_process(f'{state}_{i}_{iteration}')
+
+                r_t = dinn.get_output(1).detach().cpu().numpy()
+                with open(f'./results/{state}_{i}_{iteration}.csv', 'w', newline='') as csvfile:
+                    writer = csv.writer(csvfile, delimiter=',')
+                    writer.writerow(r_t)

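Run as python states_training.py for the first eight states and python states_training.py 1 for the remaining eight. Each run writes a single-row CSV of R_t values per state, alpha index, and iteration; a hypothetical post-processing sketch that averages the runs actually produced above (file names assumed from the writer calls):

import numpy as np

# iterations 0-2 are skipped above, so only runs 3..12 exist on disk
runs = [np.genfromtxt(f'./results/Sachsen_0_{it}.csv', delimiter=',')
        for it in range(3, 13)]
r_t_mean = np.mean(np.stack(runs), axis=0)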
Diff not shown because the file is too large
+ 257 - 78
synth_dinn_reduced_sir.ipynb


Diff not shown because the file is too large
+ 472 - 90
synth_dinn_sir.ipynb


Some files were not shown because too many files changed in this diff