3 Commits db75cf404e ... f2ced484b6

Author SHA1 Message Date
  phillip.rothenbeck f2ced484b6 reformat 1 year ago
  phillip.rothenbeck c97674cbe1 set alpha sir problem 1 year ago
  phillip.rothenbeck 84e7847058 get data from paper 1 year ago
6 changed files with 400 additions and 236 deletions
  1. src/dataset.py (+32 -30)
  2. src/plotter.py (+71 -70)
  3. src/preprocessing/synthetic_data.py (+33 -27)
  4. src/preprocessing/transform_data.py (+176 -49)
  5. src/problem.py (+32 -4)
  6. states_training.py (+56 -56)

+ 32 - 30
src/dataset.py

@@ -1,20 +1,23 @@
 import torch
+import numpy as np
 from enum import Enum
 
+
 class Norms(Enum):
-    POPULATION=0
-    MIN_MAX=1
-    CONSTANT=2
+    POPULATION = 0
+    MIN_MAX = 1
+    CONSTANT = 2
+
 
 class PandemicDataset:
-    def __init__(self, 
-                 name:str,
-                 group_names:list, 
-                 N: int, 
-                 t, 
-                 *groups, 
+    def __init__(self,
+                 name: str,
+                 group_names: list,
+                 N: int,
+                 t,
+                 *groups,
                  norm_name=Norms.MIN_MAX,
-                 C = 10**5,
+                 C=10**5,
                  use_scaled_time=False):
         """Class to hold all data for one training process.
 
@@ -60,12 +63,12 @@ class PandemicDataset:
 
         self.__group_dict = {}
         for i, name in enumerate(group_names):
-            self.__group_dict.update({name : i})
+            self.__group_dict.update({name: i})
 
         self.__group_names = group_names
 
         self.__groups = [torch.tensor(group, device=self.device_name) for group in groups]
-        
+
         self.__mins = [torch.min(group) for group in self.__groups]
         self.__maxs = [torch.max(group) for group in self.__groups]
         self.__norms = self.__norm(self.__groups)
@@ -73,50 +76,49 @@ class PandemicDataset:
     @property
     def number_groups(self):
         return len(self.__group_names)
-    
+
     @property
     def data(self):
         return self.__groups
-    
+
     @property
-    def group_names(self):
+    def group_names(self) -> np.ndarray:
         return self.__group_names
-    
+
     def __population_norm(self, data):
         return [(data[i] / self.N) for i in range(self.number_groups)]
-    
+
    def __population_denorm(self, data):
         return [(data[i] * self.N) for i in range(self.number_groups)]
 
     def __min_max_norm(self, data):
         return [(data[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(self.number_groups)]
-    
+
     def __min_max_denorm(self, data):
         return [(self.__mins[i] + (self.__maxs[i] - self.__mins[i]) * data[i]) for i in range(self.number_groups)]
-    
+
     def __constant_norm(self, data):
         return [(data[i] / self.C) for i in range(self.number_groups)]
 
     def __constant_denorm(self, data):
         return [(data[i] * self.C) for i in range(self.number_groups)]
 
-    def get_normalized_data(self, data:list):
+    def get_normalized_data(self, data: list):
         assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
         return self.__norm(data)
-    
-    def get_denormalized_data(self, data:list):
+
+    def get_denormalized_data(self, data: list):
         assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
         return self.__denorm(data)
 
-    def get_group(self, name:str):
+    def get_group(self, name: str):
         return self.__groups[self.__group_dict[name]]
-    
-    def get_min(self, name:str):
+
+    def get_min(self, name: str):
         return self.__mins[self.__group_dict[name]]
-    
-    def get_max(self, name:str):
+
+    def get_max(self, name: str):
         return self.__maxs[self.__group_dict[name]]
-    
-    def get_norm(self, name:str):
+
+    def get_norm(self, name: str):
         return self.__norms[self.__group_dict[name]]
-    
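
For orientation, a minimal usage sketch of the reformatted class. The arrays, population size and dataset name below are invented, and the parts of __init__ that this hunk does not show (for example how self.device_name and the norm/denorm selection are set up) are assumed to work as before:

import numpy as np
from src.dataset import PandemicDataset, Norms

# hypothetical SIR curves; real runs read them from the CSV files in datasets/
t = np.arange(0, 100, 1.0)
I = 1e4 * np.exp(-((t - 40.0) / 15.0) ** 2)
S = 1e6 - np.cumsum(I) * 0.1
R = 1e6 - S - I

data = PandemicDataset('demo', ['S', 'I', 'R'], 1e6, t, S, I, R,
                       norm_name=Norms.MIN_MAX)
I_norm = data.get_norm('I')                     # min-max normalised infectious curve
I_span = data.get_max('I') - data.get_min('I')  # per-group range used by that norm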

+ 71 - 70
src/plotter.py

@@ -3,7 +3,7 @@ import torch
 import imageio
 import numpy as np
 import matplotlib.pyplot as plt
-import matplotlib.ticker as ticker 
+import matplotlib.ticker as ticker
 
 from matplotlib import rcParams
 from itertools import cycle
@@ -15,6 +15,7 @@ SUSCEPTIBLE_COLOR = '#6399f7'
 INFECTIOUS_COLOR = '#f56262'
 REMOVED_COLOR = '#83eb5e'
 
+
 class Plotter:
     def __init__(self, additional_colors=[], font_size=20, font='serif', font_color='#000000') -> None:
         """Plotter of scientific plots and animations, for dinn.py.
@@ -65,27 +66,27 @@ class Plotter:
         """Delete all frames that were saved.
         """
         self.__frames = []
-    
-    #TODO comments
-    def plot(self, 
-             x, 
-             y:list, 
-             labels:list,
-             file_name:str,
-             title:str, 
-             figure_shape:tuple, 
+
+    # TODO comments
+    def plot(self,
+             x,
+             y: list,
+             labels: list,
+             file_name: str,
+             title: str,
+             figure_shape: tuple,
              event_lookup={},
-             is_frame=False, 
-             is_background=[], 
+             is_frame=False,
+             is_background=[],
             fill_between=[],
-             plot_legend=True, 
-             y_log_scale=False, 
-             lw=3, 
-             legend_loc='best', 
+             plot_legend=True,
+             y_log_scale=False,
+             lw=3,
+             legend_loc='best',
              ylim=(None, None),
-             number_xlabels = 5,
+             number_xlabels=5,
             xlabel='',
-             ylabel='', 
+             ylabel='',
             xlabel_rotation=None):
         """Plotting method.
 
@@ -112,14 +113,14 @@ class Plotter:
 
         ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
         ax.set_facecolor('xkcd:white')
-        #ax.yaxis.set_tick_params(length=0, which='both')
-        #ax.xaxis.set_tick_params(length=0, which='both')
+        # ax.yaxis.set_tick_params(length=0, which='both')
+        # ax.xaxis.set_tick_params(length=0, which='both')
 
-        #ax.grid(which='major', c='black', lw=0.2, ls='-')
+        # ax.grid(which='major', c='black', lw=0.2, ls='-')
         ax.set_title(title)
 
-        #for spine in ('top', 'right', 'bottom', 'left'):
-         #   ax.spines[spine].set_visible(False)
+        # for spine in ('top', 'right', 'bottom', 'left'):
+        #   ax.spines[spine].set_visible(False)
 
         if torch.is_tensor(x):
             x = x.cpu().detach().numpy()
@@ -131,24 +132,24 @@ class Plotter:
             if len(is_background) != 0:
                 if is_background[i]:
                     alpha = 0.25
-            
+
             if torch.is_tensor(array):
                 data = array.cpu().detach().numpy()
             else:
                 data = array
 
             space = int(len(x) / number_xlabels)
-            ax.xaxis.set_major_locator(ticker.MultipleLocator(space)) 
-        
+            ax.xaxis.set_major_locator(ticker.MultipleLocator(space))
+
             ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, linestyle=next(linecycler), c=self.__colors[i % len(self.__colors)])
             if i < len(fill_between):
-                ax.fill_between(x, data+fill_between[i], data-fill_between[i], facecolor=self.__colors[i % len(self.__colors)], alpha=0.5)
+                ax.fill_between(x, data + fill_between[i], data - fill_between[i], facecolor=self.__colors[i % len(self.__colors)], alpha=0.5)
             j = i
 
         for event in event_lookup.keys():
             j += 1
-            plt.axvline(x = event_lookup[event], color = self.__colors[j % len(self.__colors)], label = event, ls=next(linecycler))
-            
+            plt.axvline(x=event_lookup[event], color=self.__colors[j % len(self.__colors)], label=event, ls=next(linecycler))
+
         if plot_legend:
             plt.legend(loc=legend_loc)
 
@@ -166,7 +167,7 @@ class Plotter:
 
         if xlabel_rotation != None:
             plt.xticks(rotation=xlabel_rotation)
- 
+
         if not os.path.exists(FRAME_DIR):
             os.makedirs(FRAME_DIR)
 
@@ -181,29 +182,29 @@ class Plotter:
 
         plt.close()
 
-    def cluster_plot(self, 
-                     x, 
-                     y:list, 
-                     labels, 
-                     shape, 
-                     plots_shape, 
-                     file_name:str, 
-                     titles:list, 
+    def cluster_plot(self,
+                     x,
+                     y: list,
+                     labels,
+                     shape,
+                     plots_shape,
+                     file_name: str,
+                     titles: list,
                      number_xlabels=5,
-                     lw=3, 
+                     lw=3,
                      fill_between=[],
                      event_lookup={},
                      xlabel='',
                      ylabel='',
                      ylim=(None, None),
                      y_lim_exception=None,
-                     y_log_scale=False, 
-                     legend_loc=(0.5,0.992),
+                     y_log_scale=False,
+                     legend_loc=(0.5, 0.992),
                      add_y_space=0.05,
                      number_of_legend_columns=1,
                      same_axes=True,
                      free_axis=(None, None),
-                     plot_all_labels = True):
+                     plot_all_labels=True):
         real_shape = (shape[1] * plots_shape[0], shape[0] * plots_shape[1])
         fig, axes = plt.subplots(*shape, figsize=real_shape, sharex=same_axes, sharey=same_axes)
         plot_idx = 0
@@ -224,25 +225,25 @@ class Plotter:
 
                     space = int(len(x) / number_xlabels)
                     axes[i].xaxis.set_major_locator(ticker.MultipleLocator(space))
-                    axes[i].plot(x, 
-                                 data, 
+                    axes[i].plot(x,
+                                 data,
                                  linestyle=next(linecycler),
                                  label=labels[j],
                                  c=color,
                                  lw=lw)
                     axes[i].set_title(titles[i])
                     if j < len(fill_between[i]):
-                        axes[i].fill_between(x, data+fill_between[i][j], data-fill_between[i][j], facecolor=color, alpha=0.5)
+                        axes[i].fill_between(x, data + fill_between[i][j], data - fill_between[i][j], facecolor=color, alpha=0.5)
                 for event in event_lookup.keys():
-                    axes[i].axvline(x=event_lookup[event], 
-                                    color=next(colorcycler), 
-                                    label=event, 
+                    axes[i].axvline(x=event_lookup[event],
+                                    color=next(colorcycler),
+                                    label=event,
                                     ls=next(linecycler),
                                     lw=lw)
-            
+
                 if ylim[0] != None and y_lim_exception != i:
                     axes[i].set_ylim(ylim)
-                
+
                 if y_log_scale:
                     plt.yscale('log')
         else:
@@ -265,11 +266,11 @@ class Plotter:
                                     c = len(data)
                                 else:
                                     c = len(x)
-                                axes[i, j].plot(x[:c], 
-                                                data, 
-                                                label=labels[k], 
-                                                c=next(colorcycler), 
-                                                lw=lw, 
+                                axes[i, j].plot(x[:c],
+                                                data,
+                                                label=labels[k],
+                                                c=next(colorcycler),
+                                                lw=lw,
                                                 linestyle=next(linecycler))
                             axes[i, j].set_title(titles[plot_idx])
                             if ylim[0] != None:
@@ -292,23 +293,23 @@ class Plotter:
                 ax.label_outer()
 
         # Adjust layout to prevent overlap
-        plt.tight_layout(rect=[0, 0, 1, 1-add_y_space])
+        plt.tight_layout(rect=[0, 0, 1, 1 - add_y_space])
         plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
 
-    def scatter(self, 
-                x, 
-                y:list, 
-                labels:list, 
-                figure_shape:tuple, 
-                file_name:str, 
-                title:str, 
-                std=[], 
+    def scatter(self,
+                x,
+                y: list,
+                labels: list,
+                figure_shape: tuple,
+                file_name: str,
+                title: str,
+                std=[],
                 true_values=[],
                 true_label='true',
-                plot_legend=True, 
+                plot_legend=True,
                 legend_loc='best',
                 xlabel='',
-                ylabel='', 
+                ylabel='',
                 xlabel_rotation=None):
         assert len(y) == len(labels), f"There must be the same amount of labels as there are plots.\nNumber Plots: {len(y)}\nNumber Labels: {len(labels)}"
         fig = self.__generate_figure(shape=figure_shape)
@@ -322,7 +323,7 @@ class Plotter:
 
         markercycler = cycle(self.__marker_styles)
         for i, array in enumerate(y):
-            
+
             if torch.is_tensor(array):
                 data = array.cpu().detach().numpy()
             else:
@@ -332,11 +333,11 @@ class Plotter:
                 ax.scatter(x, data, label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
             if i < len(std):
                 ax.errorbar(x, data, std[i], label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
-        
+
         linecycler = cycle(self.__lines_styles)
         for i, true_value in enumerate(true_values):
-            ax.plot(x, np.ones_like(x, dtype='float64') * true_value, label=f'{true_label} {labels[i]}',c=self.__colors[i % len(self.__colors)], ls=next(linecycler))
-            
+            ax.plot(x, np.ones_like(x, dtype='float64') * true_value, label=f'{true_label} {labels[i]}', c=self.__colors[i % len(self.__colors)], ls=next(linecycler))
+
         if plot_legend:
             plt.legend(loc=legend_loc)
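
The changes in this file are whitespace-only and leave the Plotter API untouched. As a reference point, a hedged sketch of a call against the reformatted plot() signature; the curve, file name and figure size are invented, and the FRAME_DIR and VISUALISATION_DIR output locations are assumed to be configured elsewhere in the module:

import numpy as np
from src.plotter import Plotter

plotter = Plotter()
t = np.arange(0, 150)
I = 5e4 * np.exp(-((t - 60) / 20.0) ** 2)   # hypothetical infection curve

# positional order follows the signature above: x, y, labels, file_name,
# title, figure_shape; the result is written into the module's plot directories
plotter.plot(t, [I], ['I'], 'demo_infections', 'Demo run', (6, 6),
             xlabel='time / days', ylabel='amount of people')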
 
 

+ 33 - 27
src/preprocessing/synthetic_data.py

@@ -5,8 +5,9 @@ from scipy.integrate import odeint
 
 from src.plotter import Plotter
 
+
 class SyntheticDeseaseData:
-    def __init__(self, simulation_time:int, time_points:int, plotter:Plotter):
+    def __init__(self, simulation_time: int, time_points: int, plotter: Plotter):
         """This class is the parent class for every class, that is supposed to generate synthetic data.
 
         Args:
@@ -29,7 +30,7 @@ class SyntheticDeseaseData:
         """
         self.generated = True
 
-    def plot(self, labels: tuple, title:str, file_name:str):
+    def plot(self, labels: tuple, title: str, file_name: str, leave_out_indices):
         """Plot the data which was generated.
 
         Args:
@@ -37,13 +38,20 @@ class SyntheticDeseaseData:
             title (str): The name of the plot.
         """
         assert len(labels) == len(self.data), 'The number labels needs to be the same as the number of plots.'
+        groups = []
+        used_labels = []
+        for i, group in enumerate(self.data):
+            if not i in leave_out_indices:
+                groups.append(group)
+                used_labels.append(labels[i])
         if self.generated:
-            self.plotter.plot(self.t, self.data, labels, file_name, title, (6, 6), xlabel='time / days', ylabel='amount of people')
-        else: 
+            self.plotter.plot(self.t, groups, used_labels, file_name, title, (6, 6), xlabel='time / days', ylabel='amount of people')
+        else:
             print('Data has to be generated before plotting!')
 
+
 class SIR(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
+    def __init__(self, plotter: Plotter, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
         """This class is able to generate synthetic data for the SIR model.
 
         Args:
@@ -78,8 +86,8 @@ class SIR(SyntheticDeseaseData):
             tuple: Change amount for each group.
         """
         S, I, _ = y
-        dSdt = -self.beta * ((S * I) / self.N) # -self.beta * S * I
-        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I # self.beta * S * I - self.alpha * I
+        dSdt = -self.beta * ((S * I) / self.N)  # -self.beta * S * I
+        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I  # self.beta * S * I - self.alpha * I
         dRdt = self.alpha * I
         return dSdt, dIdt, dRdt
 
@@ -90,21 +98,22 @@ class SIR(SyntheticDeseaseData):
         self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
         super().generate()
 
-    def plot(self, title='', file_name='SIR_plot'):
+    def plot(self, title='', file_name='SIR_plot', leave_out_indices=[]):
         """Plot the data which was generated.
         """
-        super().plot(('Susceptible', 'Infectious', 'Removed'), title=title, file_name=file_name)
+        super().plot(('Susceptible', 'Infectious', 'Removed'), title=title, file_name=file_name, leave_out_indices=leave_out_indices)
 
     def save(self, name=''):
         if self.generated:
-            COVID_Data = np.asarray([self.t, *self.data]) 
+            COVID_Data = np.asarray([self.t, *self.data])
 
             np.savetxt('datasets/SIR_data.csv', COVID_Data, delimiter=",")
-        else: 
+        else:
             print('Data has to be generated before plotting!')
 
+
 class I(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N:int, C:int, I_0=1, time_points=100, alpha=1/3) -> None:
+    def __init__(self, plotter: Plotter, N: int, C: int, I_0=1, time_points=100, alpha=1 / 3) -> None:
         """This class is able to generate synthetic data of the SI groups for the reduced SIR model. This is done by utiling the SIR model.
 
         Args:
@@ -119,7 +128,7 @@ class I(SyntheticDeseaseData):
         self.N = N
         self.C = C
         self.I_0 = I_0
- 
+
         self.alpha = alpha
 
         self.t = np.linspace(0, 1, time_points)
@@ -129,14 +138,12 @@ class I(SyntheticDeseaseData):
         self.data = None
         self.generated = False
         self.plotter = plotter
-        
+
     def R_t(self, t):
         descaled_t = t * self.t_f
         # if descaled_t < threshold1:
         return -np.tanh(descaled_t * 0.05 - 2) * 0.4 + 1.35
 
-
-            
     def differential_eq(self, I, t):
         """In this function implements the differential equation of the SIR model will be implemented.
 
@@ -153,10 +160,10 @@ class I(SyntheticDeseaseData):
     def generate(self):
         """This funtion generates the data for this configuration of the SIR model.
         """
-        self.data = odeint(self.differential_eq, self.I_0/self.C, self.t).T
+        self.data = odeint(self.differential_eq, self.I_0 / self.C, self.t).T
         self.data = self.data[0] * self.C
         self.t_counter = 0
-        self.generated =True
+        self.generated = True
 
     def plot(self, title='', file_name=''):
         """Plot the data which was generated.
@@ -167,21 +174,20 @@ class I(SyntheticDeseaseData):
             for time in self.t:
                 self.reproduction_value.append(self.R_t(time))
             self.plotter.plot(t, [np.array(self.reproduction_value)], [r'$\mathcal{R}_t$'], file_name + '_r_t', title + r' $\mathcal{R}_t$', (6, 6), xlabel='time / days')
-        else: 
+        else:
             print('Data has to be generated before plotting!')
 
     def save(self, name=''):
         if self.generated:
-            COVID_Data = np.asarray([self.t_save, self.data]) 
+            COVID_Data = np.asarray([self.t_save, self.data])
 
             np.savetxt('datasets/I_data.csv', COVID_Data, delimiter=",")
-        else: 
+        else:
             print('Data has to be generated before plotting!')
 
-        
 
 class SIDR(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N=59e6, I_0=1, D_0=0, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05, gamma=0.0294) -> None:
+    def __init__(self, plotter: Plotter, N=59e6, I_0=1, D_0=0, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05, gamma=0.0294) -> None:
         """This class is able to generate synthetic data for the SIDR model.
 
         Args:
@@ -207,7 +213,7 @@ class SIDR(SyntheticDeseaseData):
         self.gamma = gamma
 
         super().__init__(simulation_time, time_points, plotter)
-    
+
     def differential_eq(self, y, t, alpha, beta, gamma):
         """In this function implements the differential equation of the SIDR model will be implemented.
 
@@ -223,7 +229,7 @@ class SIDR(SyntheticDeseaseData):
         """
         S, I, D, R = y
         dSdt = - (self.alpha / self.N) * S * I
-        dIdt = (self.alpha / self.N) * S * I - self.beta * I - self.gamma * I 
+        dIdt = (self.alpha / self.N) * S * I - self.beta * I - self.gamma * I
         dDdt = self.gamma * I
         dRdt = self.beta * I
         return dSdt, dIdt, dDdt, dRdt
@@ -242,8 +248,8 @@ class SIDR(SyntheticDeseaseData):
 
     def save(self, name=''):
         if self.generated:
-            COVID_Data = np.asarray([self.t, *self.data]) 
+            COVID_Data = np.asarray([self.t, *self.data])
 
             np.savetxt('datasets/SIDR_data.csv', COVID_Data, delimiter=",")
-        else: 
+        else:
             print('Data has to be generated before plotting!')
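
The functional change in this file is the new leave_out_indices argument, which drops selected compartments from the generated plot. A hedged usage sketch, using the constructor defaults shown above; save() writes datasets/SIR_data.csv:

from src.plotter import Plotter
from src.preprocessing.synthetic_data import SIR

sir = SIR(Plotter(), N=59e6, simulation_time=150, time_points=150)
sir.generate()
# index 0 is the Susceptible curve, so only Infectious and Removed are plotted
sir.plot(title='Synthetic SIR', file_name='SIR_plot', leave_out_indices=[0])
sir.save()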

+ 176 - 49
src/preprocessing/transform_data.py

@@ -1,72 +1,124 @@
 import numpy as np
 import pandas as pd
+from datetime import date, timedelta
 
 from src.plotter import Plotter
 
-state_lookup = {'Schleswig Holstein' : (1, 2897000),
-                'Hamburg' : (2, 1841000), 
-                'Niedersachsen' : (3, 7982000), 
-                'Bremen' : (4, 569352),
-                'Nordrhein-Westfalen' : (5, 17930000),
-                'Hessen' : (6, 6266000),
-                'Rheinland-Pfalz' : (7, 4085000),
-                'Baden-Württemberg' : (8, 11070000),
-                'Bayern' : (9, 13080000),
-                'Saarland' : (10, 990509),
-                'Berlin' : (11, 3645000),
-                'Brandenburg' : (12, 2641000),
-                'Mecklenburg-Vorpommern' : (13, 1610000),
-                'Sachsen' : (14, 4078000),
-                'Sachsen-Anhalt' : (15, 2208000),
-                'Thüringen' : (16, 2143000)}
-
-def transform_data(plotter:Plotter, alpha=1/14, state_name='Germany', time_range=1200, plot_name='', plot_title='', sample_rate=1, model='SIR', plot_size=(12,6), yscale_log=False, plot_legend=True):
+state_lookup = {'Schleswig Holstein': (1, 2897000),
+                'Hamburg': (2, 1841000),
+                'Niedersachsen': (3, 7982000),
+                'Bremen': (4, 569352),
+                'Nordrhein-Westfalen': (5, 17930000),
+                'Hessen': (6, 6266000),
+                'Rheinland-Pfalz': (7, 4085000),
+                'Baden-Württemberg': (8, 11070000),
+                'Bayern': (9, 13080000),
+                'Saarland': (10, 990509),
+                'Berlin': (11, 3645000),
+                'Brandenburg': (12, 2641000),
+                'Mecklenburg-Vorpommern': (13, 1610000),
+                'Sachsen': (14, 4078000),
+                'Sachsen-Anhalt': (15, 2208000),
+                'Thüringen': (16, 2143000)}
+
+
+def daterange(start_date: date, end_date: date):
+    days = int((end_date - start_date).days)
+    for n in range(days):
+        yield start_date + timedelta(n)
+
+
+def transform_jh_germany_data(plotter: Plotter,
+                              time_range=50,
+                              sample_rate=1,
+                              model='SIR'):
+    N = 83100000
+    state_name = 'Germany'
+    infections = np.zeros(time_range)
+    deaths = np.zeros(time_range)
+    recoveries = np.zeros(time_range)
+
+    # extract data
+    data_directory = 'datasets/COVID-19/csse_covid_19_data/csse_covid_19_daily_reports'
+    start_date = date(2020, 1, 31)
+    end_date = date(2020, 3, 20)
+    for i, single_date in enumerate(daterange(start_date, end_date)):
+        file_date = single_date.strftime("%m-%d-%Y")
+        date_df = pd.read_csv(data_directory + "/" + file_date + ".csv")
+        date_df = date_df.loc[date_df['Country/Region'] == state_name]
+
+        infections[i] = date_df['Confirmed'].fillna(0).astype(int)
+        deaths[i] = date_df['Deaths'].fillna(0).astype(int)
+        recoveries[i] = date_df['Recovered'].fillna(0).astype(int)
+
+    S, I, R = np.zeros(infections.shape[0]), np.zeros(
+        infections.shape[0]), np.zeros(infections.shape[0])
+    S[0] = N - infections[0]
+    I[0] = infections[0]
+    R[0] = 0
+
+    for day in range(1, time_range):
+        S[day] = S[day - 1] - infections[day]
+        I[day] = I[day - 1] + infections[day] - deaths[day] - recoveries[day]
+        R[day] = R[day - 1] + deaths[day] + recoveries[day]
+        if I[day] < 0:
+            I[day] = 0
+
+    t = np.arange(0, time_range, 1)
+
+    plotter.plot(t, [I, R], ["I", "R"], "JH_data", "JH Data", (6, 6))
+
+    groups = [S, I, R]
+    COVID_Data = np.asarray([t[0::sample_rate]] +
+                            [group[0::sample_rate] for group in groups])
+
+    np.savetxt(
+        f"datasets/{model}_JH_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}.csv", COVID_Data, delimiter=",")
+
+
+def transform_data(plotter: Plotter, alpha=1 / 14, state_name='Germany', time_range=1200, sample_rate=1, model='SIR'):
     """Function to generate the SIR split from the data in the COVID-19-Todesfaelle_in_Deutschland dataset.
 
     Args:
         plotter (Plotter): Plotter object to plot dataset curves.
         dataset_path (str, optional): Path to the dataset directory. Defaults to 'datasets/COVID-19-Todesfaelle_in_Deutschland/'.
-        plot_name (str, optional): Name of the plot file. Defaults to ''.
-        plot_title (str, optional): Title of the plot. Defaults to ''.
         sample_rate (int, optional): Sample rate used to sample the timepoints. Defaults to 1.
         exclude (list, optional): List of groups that are to excluded from the plot. Defaults to [].
-        plot_size (tuple, optional): Size of the plot in (x, y) format. Defaults to (12,6).
-        yscale_log (bool, optional): Controls if the y axis of the plot will be scaled in log scale. Defaults to False.
-        plot_legend (bool, optional): Controls if the legend is to be plotted. Defaults to True.
     """
     # read the data
 
-
     infections = np.zeros(time_range)
     deaths = np.zeros(time_range)
     recoveries = np.zeros(time_range)
     if state_name == 'Germany':
-        df = pd.read_csv('datasets/COVID-19-Todesfaelle_in_Deutschland/COVID-19-Todesfaelle_Deutschland.csv')
+        df = pd.read_csv(
+            'datasets/COVID-19-Todesfaelle_in_Deutschland/COVID-19-Todesfaelle_Deutschland.csv')
         N = 83100000
         infections[0] = df['Faelle_gesamt'][0]
         deaths[0] = df['Todesfaelle_neu'][0]
 
         recovery_queue = np.zeros(14)
         for i in range(1, time_range):
-            infections[i] = df['Faelle_gesamt'][i] - df['Faelle_gesamt'][i-1]
+            infections[i] = df['Faelle_gesamt'][i] - df['Faelle_gesamt'][i - 1]
             deaths[i] = df['Todesfaelle_neu'][i]
             recoveries[i] = recovery_queue[0]
 
             recovery_queue[:-1] = recovery_queue[1:]
             recovery_queue[-1] = infections[i]
     else:
-        df = pd.read_csv('datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv')
+        df = pd.read_csv(
+            'datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv')
         state_ID, N = state_lookup[state_name]
 
         # single out a state
         state_IDs = df['IdLandkreis'] // 1000
         df = df.loc[state_IDs == state_ID]
 
-        # sort entries by state
+        # sort entries by date
         df = df.sort_values('Refdatum')
         df = df.reset_index(drop=True)
 
-        # collect cases    
+        # collect cases
         entry_idx = 0
         day = 0
         date = df['Refdatum'][entry_idx]
@@ -78,7 +130,8 @@ def transform_data(plotter:Plotter, alpha=1/14, state_name='Germany', time_range
                 deaths[day] += df['AnzahlTodesfall'][entry_idx]
                 entry_idx += 1
             # move day index by difference between the current and next date
-            day += (pd.to_datetime(df['Refdatum'][entry_idx])-pd.to_datetime(date)).days
+            day += (pd.to_datetime(df['Refdatum']
+                    [entry_idx]) - pd.to_datetime(date)).days
             date = df['Refdatum'][entry_idx]
 
         recovery_queue = np.zeros(14)
@@ -89,48 +142,122 @@ def transform_data(plotter:Plotter, alpha=1/14, state_name='Germany', time_range
             recovery_queue[:-1] = recovery_queue[1:]
             recovery_queue[-1] = infections[i]
             week_counter -= 1
-        
+
     df = df.drop(df.index[time_range:])
-    S, I, R = np.zeros(df.shape[0]), np.zeros(df.shape[0]), np.zeros(df.shape[0])
+    S, I, R = np.zeros(df.shape[0]), np.zeros(
+        df.shape[0]), np.zeros(df.shape[0])
     # generate groups
     S[0] = N - infections[0]
     I[0] = infections[0]
     R[0] = 0
     if model == 'I':
         for day in range(1, time_range):
-            S[day] = S[day-1] - infections[day]
-            I[day] = I[day-1] + infections[day] - I[day-1] * alpha
-            R[day] = R[day-1] + I[day-1] * alpha
+            S[day] = S[day - 1] - infections[day]
+            I[day] = I[day - 1] + infections[day] - I[day - 1] * alpha
+            R[day] = R[day - 1] + I[day - 1] * alpha
     else:
         for day in range(1, time_range):
-            S[day] = S[day-1] - infections[day]
-            I[day] = I[day-1] + infections[day] - deaths[day] - recoveries[day]
-            R[day] = R[day-1] + deaths[day] + recoveries[day]
+            S[day] = S[day - 1] - infections[day]
+            I[day] = I[day - 1] + infections[day] - \
+                deaths[day] - recoveries[day]
+            R[day] = R[day - 1] + deaths[day] + recoveries[day]
             if I[day] < 0:
                 I[day] = 0
-    
+
     t = np.arange(0, time_range, 1)
 
     # select, which group is to be outputted
     groups = []
     if 'S' in model:
         groups.append(S)
-    
+
     if 'I' in model:
         groups.append(I)
 
     if 'R' in model:
         groups.append(R)
 
-    plotter.plot(t, 
-                 groups, 
-                 [*model], 
-                 state_name.replace(' ', '_').replace('-', '_').replace('ü','ue') + f"_{model}" + f"_{int(1/alpha)}", 
-                 state_name, 
-                 (6,6), 
-                 xlabel='time / days', 
+    plotter.plot(t,
+                 groups,
+                 [*model],
+                 state_name.replace(' ', '_').replace(
+                     '-', '_').replace('ü', 'ue') + f"_{model}" + f"_{int(1/alpha)}",
+                 state_name,
+                 (6, 6),
+                 xlabel='time / days',
                  ylabel='amount of people')
 
-    COVID_Data = np.asarray([t[0::sample_rate]] + [group[0::sample_rate] for group in groups]) 
+    COVID_Data = np.asarray([t[0::sample_rate]] +
+                            [group[0::sample_rate] for group in groups])
+
+    np.savetxt(
+        f"datasets/{model}_RKI_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}_{int(1/alpha)}.csv", COVID_Data, delimiter=",")
+
+
+def transform_paper_data():
+    N = 70000000
+    time_range = 36
+    alpha = 0.07
+    state_name = 'Germany'
+
+    infections = np.zeros(time_range)
+    deaths = np.zeros(time_range)
+    recoveries = np.zeros(time_range)
+    # Data
+    data = [
+        [1.30000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.40000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.50000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.60000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.70000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.80000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.90000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.00000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.10000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.20000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.30000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.40000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.50000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.60000000e+01, 2.00000000e+00, 1.70000000e+01],
+        [2.70000000e+01, 2.00000000e+00, 2.10000000e+01],
+        [2.80000000e+01, 2.00000000e+00, 4.70000000e+01],
+        [2.90000000e+01, 2.00000000e+00, 5.70000000e+01],
+        [1.00000000e+00, 3.00000000e+00, 1.11000000e+02],
+        [2.00000000e+00, 3.00000000e+00, 1.29000000e+02],
+        [3.00000000e+00, 3.00000000e+00, 1.57000000e+02],
+        [4.00000000e+00, 3.00000000e+00, 1.96000000e+02],
+        [5.00000000e+00, 3.00000000e+00, 2.62000000e+02],
+        [6.00000000e+00, 3.00000000e+00, 4.00000000e+02],
+        [7.00000000e+00, 3.00000000e+00, 6.84000000e+02],
+        [8.00000000e+00, 3.00000000e+00, 8.47000000e+02],
+        [9.00000000e+00, 3.00000000e+00, 9.02000000e+02],
+        [1.00000000e+01, 3.00000000e+00, 1.13900000e+03],
+        [1.10000000e+01, 3.00000000e+00, 1.29600000e+03],
+        [1.20000000e+01, 3.00000000e+00, 1.56700000e+03],
+        [1.30000000e+01, 3.00000000e+00, 2.36900000e+03],
+        [1.40000000e+01, 3.00000000e+00, 3.06200000e+03],
+        [1.50000000e+01, 3.00000000e+00, 3.79500000e+03],
+        [1.60000000e+01, 3.00000000e+00, 4.83800000e+03],
+        [1.70000000e+01, 3.00000000e+00, 6.01200000e+03],
+        [1.80000000e+01, 3.00000000e+00, 7.15600000e+03],
+        [1.90000000e+01, 3.00000000e+00, 8.19800000e+03],
+    ]
+
+    # Creating a Pandas DataFrame
+    df = pd.DataFrame(data, columns=["Day", "Month", "Infected people"])
+    S, I, R = np.zeros(df.shape[0]), np.zeros(
+        df.shape[0]), np.zeros(df.shape[0])
+    # generate groups
+    S[0] = N - infections[0]
+    I[0] = infections[0]
+    R[0] = 0
+    for day in range(1, time_range):
+        S[day] = S[day - 1] - df["Infected people"][day]
+        I[day] = I[day - 1] + df["Infected people"][day] - I[day - 1] * alpha
+        R[day] = R[day - 1] + I[day - 1] * alpha
+
+    COVID_Data = np.asarray([np.arange(0, time_range, 1)] +
+                            [S, I, R])
 
 
-    np.savetxt(f"datasets/{model}_RKI_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}_{int(1/alpha)}.csv", COVID_Data, delimiter=",")
+    np.savetxt(
+        f"datasets/SIR_Paper_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{int(1/alpha)}.csv", COVID_Data, delimiter=",")
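
A hedged sketch of how the three entry points might be called; it assumes the hard-coded input paths exist (the RKI CSVs under datasets/ and the JHU CSSE daily reports under datasets/COVID-19/), while transform_paper_data only uses the table embedded above:

from src.plotter import Plotter
from src.preprocessing.transform_data import (transform_data,
                                              transform_jh_germany_data,
                                              transform_paper_data)

plotter = Plotter()
# RKI data with a 14-day recovery assumption -> datasets/SIR_RKI_Sachsen_1_14.csv
transform_data(plotter, alpha=1 / 14, state_name='Sachsen')
# Johns Hopkins daily reports for Germany -> datasets/SIR_JH_Germany_1.csv
transform_jh_germany_data(plotter, time_range=50)
# hard-coded case numbers from the paper -> datasets/SIR_Paper_Germany_14.csv
transform_paper_data()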

+ 32 - 4
src/problem.py

@@ -1,6 +1,7 @@
 import torch
 from .dataset import PandemicDataset
 
+
 class PandemicProblem:
     def __init__(self, data: PandemicDataset) -> None:
         """Parent class for all pandemic problem classes. Holding the function, that calculates the residuals of the differential system.
@@ -18,14 +19,14 @@ class PandemicProblem:
         """NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS
         """
         assert self._gradients != None, 'Gradientmatrix need to be defined'
-        
 
-    def def_grad_matrix(self, number:int):
+    def def_grad_matrix(self, number: int):
         assert self._gradients == None, 'Gradientmatrix is already defined'
         self._gradients = [torch.zeros((len(self._data.t_raw), number), device=self._device_name) for _ in range(number)]
         for i in range(number):
             self._gradients[i][:, i] = 1
 
+
 class SIRProblem(PandemicProblem):
     def __init__(self, data: PandemicDataset):
         super().__init__(data)
@@ -53,8 +54,36 @@ class SIRProblem(PandemicProblem):
         return S_residual, I_residual, R_residual
 
 
+class SIRAlphaProblem(PandemicProblem):
+    def __init__(self, data: PandemicDataset, alpha):
+        super().__init__(data)
+        self.alpha = alpha
+
+    def residual(self, SIR_pred, beta):
+        super().residual()
+        SIR_pred.backward(self._gradients[0], retain_graph=True)
+        dSdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
+
+        SIR_pred.backward(self._gradients[1], retain_graph=True)
+        dIdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
+
+        SIR_pred.backward(self._gradients[2], retain_graph=True)
+        dRdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
+
+        S, I, _ = self._data.get_denormalized_data([SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]])
+
+        S_residual = dSdt - (-beta * ((S * I) / self._data.N)) / (self._data.get_max('S') - self._data.get_min('S'))
+        I_residual = dIdt - (beta * ((S * I) / self._data.N) - self.alpha * I) / (self._data.get_max('I') - self._data.get_min('I'))
+        R_residual = dRdt - (self.alpha * I) / (self._data.get_max('R') - self._data.get_min('R'))
+
+        return S_residual, I_residual, R_residual
+
+
 class ReducedSIRProblem(PandemicProblem):
-    def __init__(self, data: PandemicDataset, alpha:float):
+    def __init__(self, data: PandemicDataset, alpha: float):
         super().__init__(data)
         self.alpha = alpha
 
@@ -72,4 +101,3 @@ class ReducedSIRProblem(PandemicProblem):
 
         I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
         return I_residual
-
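
The new SIRAlphaProblem mirrors SIRProblem but takes the recovery rate alpha as a fixed constructor argument, so only the transmission rate beta is left to identify; each residual is rescaled by the corresponding group's min-max range from the dataset. A hedged sketch of how it would be wired up (the PandemicDataset construction follows the sketch after src/dataset.py above; SIR_pred and beta would come from the DINN training loop in src/dinn.py, which is not part of this diff):

from src.problem import SIRAlphaProblem

# dataset: a PandemicDataset with the groups 'S', 'I' and 'R' (see src/dataset.py)
problem = SIRAlphaProblem(dataset, alpha=1 / 14)
problem.def_grad_matrix(3)   # one one-hot gradient matrix per SIR compartment

# inside a training step, with SIR_pred the network output on dataset.t_raw and
# beta a trainable tensor, the physics loss is built from the three residuals:
#   S_res, I_res, R_res = problem.residual(SIR_pred, beta)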

+ 56 - 56
states_training.py

@@ -7,7 +7,7 @@ from src.dataset import PandemicDataset, Norms
 from src.problem import ReducedSIRProblem
 from src.dinn import DINN, Scheduler, Activation
 
-ALPHA = [1/14, 1/5]
+ALPHA = [1 / 14, 1 / 5]
 DO_STATES = True
 DO_SYNTHETIC = False
 
@@ -16,66 +16,66 @@ ITERATIONS = 13
 state_starting_index = 0
 
 if "1" in sys.argv:
-     state_starting_index = 8
-
-STATE_LOOKUP = {'Schleswig_Holstein' : 2897000,
-                'Hamburg' : 1841000, 
-                'Niedersachsen' : 7982000, 
-                'Bremen' : 569352,
-                'Nordrhein_Westfalen' : 17930000,
-                'Hessen' : 6266000,
-                'Rheinland_Pfalz' : 4085000,
-                'Baden_Wuerttemberg' : 11070000,
-                'Bayern' : 13080000,
-                'Saarland' : 990509,
-                'Berlin' : 3645000,
-                'Brandenburg' : 2641000,
-                'Mecklenburg_Vorpommern' : 1610000,
-                'Sachsen' : 4078000,
-                'Sachsen_Anhalt' : 2208000,
-                'Thueringen' : 2143000}
+    state_starting_index = 8
+
+STATE_LOOKUP = {'Schleswig_Holstein': 2897000,
+                'Hamburg': 1841000,
+                'Niedersachsen': 7982000,
+                'Bremen': 569352,
+                'Nordrhein_Westfalen': 17930000,
+                'Hessen': 6266000,
+                'Rheinland_Pfalz': 4085000,
+                'Baden_Wuerttemberg': 11070000,
+                'Bayern': 13080000,
+                'Saarland': 990509,
+                'Berlin': 3645000,
+                'Brandenburg': 2641000,
+                'Mecklenburg_Vorpommern': 1610000,
+                'Sachsen': 4078000,
+                'Sachsen_Anhalt': 2208000,
+                'Thueringen': 2143000}
 
 if DO_SYNTHETIC:
-    alpha = 1/3
+    alpha = 1 / 3
     covid_data = np.genfromtxt(f'./datasets/I_data.csv', delimiter=',')
     for i in range(ITERATIONS):
-        dataset = PandemicDataset('Synthetic I', 
-                                    ['I'], 
-                                    7.6e6, 
-                                    *covid_data, 
-                                    norm_name=Norms.CONSTANT, 
-                                    use_scaled_time=True)
+        dataset = PandemicDataset('Synthetic I',
+                                  ['I'],
+                                  7.6e6,
+                                  *covid_data,
+                                  norm_name=Norms.CONSTANT,
+                                  use_scaled_time=True)
 
         problem = ReducedSIRProblem(dataset, alpha)
-        dinn = DINN(2, 
-                    dataset, 
-                    [], 
-                    problem, 
-                    None, 
-                    state_variables=['R_t'], 
-                    hidden_size=100, 
-                    hidden_layers=4, 
+        dinn = DINN(2,
+                    dataset,
+                    [],
+                    problem,
+                    None,
+                    state_variables=['R_t'],
+                    hidden_size=100,
+                    hidden_layers=4,
                     activation_layer=torch.nn.Tanh(),
                     activation_output=Activation.POWER)
 
-        dinn.configure_training(1e-3, 
-                                20000, 
-                                scheduler_class=Scheduler.POLYNOMIAL, 
+        dinn.configure_training(1e-3,
+                                20000,
+                                scheduler_class=Scheduler.POLYNOMIAL,
                                 lambda_physics=1e-6,
                                 verbose=True)
         dinn.train(verbose=True, do_split_training=True)
 
         dinn.save_training_process(f'synthetic_{i}')
-        #r_t = dinn.get_output(1).detach().cpu().numpy()
+        # r_t = dinn.get_output(1).detach().cpu().numpy()
 
-        #with open(f'./results/synthetic_{i}.csv', 'w', newline='') as csvfile:
-                #writer = csv.writer(csvfile, delimiter=',')
-                #writer.writerow(r_t)
+        # with open(f'./results/synthetic_{i}.csv', 'w', newline='') as csvfile:
+        # writer = csv.writer(csvfile, delimiter=',')
+        # writer.writerow(r_t)
 
 for iteration in range(ITERATIONS):
     if iteration <= 2:
-         print('skip first three iteration, as it was already done')
-         continue
+        print('skip first three iteration, as it was already done')
+        continue
     if DO_STATES:
         for state_idx in range(state_starting_index, state_starting_index + 8):
             state = list(STATE_LOOKUP.keys())[state_idx]
@@ -91,22 +91,22 @@ for iteration in range(ITERATIONS):
 
                 problem = ReducedSIRProblem(dataset, alpha)
 
-                dinn = DINN(2, 
-                            dataset, 
-                            [], 
-                            problem, 
-                            None, 
-                            state_variables=['R_t'], 
-                            hidden_size=100, 
-                            hidden_layers=4, 
+                dinn = DINN(2,
+                            dataset,
+                            [],
+                            problem,
+                            None,
+                            state_variables=['R_t'],
+                            hidden_size=100,
+                            hidden_layers=4,
                             activation_layer=torch.nn.Tanh(),
                             activation_output=Activation.POWER)
 
-                dinn.configure_training(1e-3, 
-                                        25000, 
+                dinn.configure_training(1e-3,
+                                        25000,
                                         scheduler_class=Scheduler.POLYNOMIAL,
-                                        lambda_obs=1e2, 
-                                        lambda_physics=1e-6, 
+                                        lambda_obs=1e2,
+                                        lambda_physics=1e-6,
                                         verbose=True)
                 dinn.train(verbose=True, do_split_training=True)
 
@@ -115,4 +115,4 @@ for iteration in range(ITERATIONS):
                 r_t = dinn.get_output(1).detach().cpu().numpy()
                 with open(f'./results/{state}_{i}_{iteration}.csv', 'w', newline='') as csvfile:
                     writer = csv.writer(csvfile, delimiter=',')
-                    writer.writerow(r_t)
+                    writer.writerow(r_t)
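
Note on the training script: STATE_LOOKUP lists the 16 German federal states with their populations, and each run trains the reduced SIR problem for eight of them starting at state_starting_index; passing the argument "1" on the command line shifts that index to 8, presumably so a second process can cover the remaining states in parallel. The resulting reproduction-number curves are written to ./results/{state}_{i}_{iteration}.csv.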