3 Commits db75cf404e ... f2ced484b6

Autor SHA1 Mensagem Data
  phillip.rothenbeck f2ced484b6 reformat 7 meses atrás
  phillip.rothenbeck c97674cbe1 set alpha sir problem 7 meses atrás
  phillip.rothenbeck 84e7847058 get data from paper 7 meses atrás
6 arquivos alterados com 400 adições e 236 exclusões
  1. 32 30
      src/dataset.py
  2. 71 70
      src/plotter.py
  3. 33 27
      src/preprocessing/synthetic_data.py
  4. 176 49
      src/preprocessing/transform_data.py
  5. 32 4
      src/problem.py
  6. 56 56
      states_training.py

+ 32 - 30
src/dataset.py

@@ -1,20 +1,23 @@
 import torch
+import numpy as np
 from enum import Enum
 
+
 class Norms(Enum):
-    POPULATION=0
-    MIN_MAX=1
-    CONSTANT=2
+    POPULATION = 0
+    MIN_MAX = 1
+    CONSTANT = 2
+
 
 class PandemicDataset:
-    def __init__(self, 
-                 name:str,
-                 group_names:list, 
-                 N: int, 
-                 t, 
-                 *groups, 
+    def __init__(self,
+                 name: str,
+                 group_names: list,
+                 N: int,
+                 t,
+                 *groups,
                  norm_name=Norms.MIN_MAX,
-                 C = 10**5,
+                 C=10**5,
                  use_scaled_time=False):
         """Class to hold all data for one training process.
 
@@ -60,12 +63,12 @@ class PandemicDataset:
 
         self.__group_dict = {}
         for i, name in enumerate(group_names):
-            self.__group_dict.update({name : i})
+            self.__group_dict.update({name: i})
 
         self.__group_names = group_names
 
         self.__groups = [torch.tensor(group, device=self.device_name) for group in groups]
-        
+
         self.__mins = [torch.min(group) for group in self.__groups]
         self.__maxs = [torch.max(group) for group in self.__groups]
         self.__norms = self.__norm(self.__groups)
@@ -73,50 +76,49 @@ class PandemicDataset:
     @property
     def number_groups(self):
         return len(self.__group_names)
-    
+
     @property
     def data(self):
         return self.__groups
-    
+
     @property
-    def group_names(self):
+    def group_names(self) -> np.ndarray:
         return self.__group_names
-    
+
     def __population_norm(self, data):
         return [(data[i] / self.N) for i in range(self.number_groups)]
-    
+
     def __population_denorm(self, data):
         return [(data[i] * self.N) for i in range(self.number_groups)]
 
     def __min_max_norm(self, data):
         return [(data[i] - self.__mins[i]) / (self.__maxs[i] - self.__mins[i]) for i in range(self.number_groups)]
-    
+
     def __min_max_denorm(self, data):
         return [(self.__mins[i] + (self.__maxs[i] - self.__mins[i]) * data[i]) for i in range(self.number_groups)]
-    
+
     def __constant_norm(self, data):
         return [(data[i] / self.C) for i in range(self.number_groups)]
 
     def __constant_denorm(self, data):
         return [(data[i] * self.C) for i in range(self.number_groups)]
 
-    def get_normalized_data(self, data:list):
+    def get_normalized_data(self, data: list):
         assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
         return self.__norm(data)
-    
-    def get_denormalized_data(self, data:list):
+
+    def get_denormalized_data(self, data: list):
         assert len(data) == self.number_groups, f'data parameter needs same length as there are groups in the dataset ({self.number_groups})'
         return self.__denorm(data)
 
-    def get_group(self, name:str):
+    def get_group(self, name: str):
         return self.__groups[self.__group_dict[name]]
-    
-    def get_min(self, name:str):
+
+    def get_min(self, name: str):
         return self.__mins[self.__group_dict[name]]
-    
-    def get_max(self, name:str):
+
+    def get_max(self, name: str):
         return self.__maxs[self.__group_dict[name]]
-    
-    def get_norm(self, name:str):
+
+    def get_norm(self, name: str):
         return self.__norms[self.__group_dict[name]]
-    

+ 71 - 70
src/plotter.py

@@ -3,7 +3,7 @@ import torch
 import imageio
 import numpy as np
 import matplotlib.pyplot as plt
-import matplotlib.ticker as ticker 
+import matplotlib.ticker as ticker
 
 from matplotlib import rcParams
 from itertools import cycle
@@ -15,6 +15,7 @@ SUSCEPTIBLE_COLOR = '#6399f7'
 INFECTIOUS_COLOR = '#f56262'
 REMOVED_COLOR = '#83eb5e'
 
+
 class Plotter:
     def __init__(self, additional_colors=[], font_size=20, font='serif', font_color='#000000') -> None:
         """Plotter of scientific plots and animations, for dinn.py.
@@ -65,27 +66,27 @@ class Plotter:
         """Delete all frames that were saved.
         """
         self.__frames = []
-    
-    #TODO comments
-    def plot(self, 
-             x, 
-             y:list, 
-             labels:list,
-             file_name:str,
-             title:str, 
-             figure_shape:tuple, 
+
+    # TODO comments
+    def plot(self,
+             x,
+             y: list,
+             labels: list,
+             file_name: str,
+             title: str,
+             figure_shape: tuple,
              event_lookup={},
-             is_frame=False, 
-             is_background=[], 
+             is_frame=False,
+             is_background=[],
              fill_between=[],
-             plot_legend=True, 
-             y_log_scale=False, 
-             lw=3, 
-             legend_loc='best', 
+             plot_legend=True,
+             y_log_scale=False,
+             lw=3,
+             legend_loc='best',
              ylim=(None, None),
-             number_xlabels = 5,
+             number_xlabels=5,
              xlabel='',
-             ylabel='', 
+             ylabel='',
              xlabel_rotation=None):
         """Plotting method.
 
@@ -112,14 +113,14 @@ class Plotter:
 
         ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
         ax.set_facecolor('xkcd:white')
-        #ax.yaxis.set_tick_params(length=0, which='both')
-        #ax.xaxis.set_tick_params(length=0, which='both')
+        # ax.yaxis.set_tick_params(length=0, which='both')
+        # ax.xaxis.set_tick_params(length=0, which='both')
 
-        #ax.grid(which='major', c='black', lw=0.2, ls='-')
+        # ax.grid(which='major', c='black', lw=0.2, ls='-')
         ax.set_title(title)
 
-        #for spine in ('top', 'right', 'bottom', 'left'):
-         #   ax.spines[spine].set_visible(False)
+        # for spine in ('top', 'right', 'bottom', 'left'):
+        #   ax.spines[spine].set_visible(False)
 
         if torch.is_tensor(x):
             x = x.cpu().detach().numpy()
@@ -131,24 +132,24 @@ class Plotter:
             if len(is_background) != 0:
                 if is_background[i]:
                     alpha = 0.25
-            
+
             if torch.is_tensor(array):
                 data = array.cpu().detach().numpy()
             else:
                 data = array
 
             space = int(len(x) / number_xlabels)
-            ax.xaxis.set_major_locator(ticker.MultipleLocator(space)) 
-        
+            ax.xaxis.set_major_locator(ticker.MultipleLocator(space))
+
             ax.plot(x, data, label=labels[i], alpha=alpha, lw=lw, linestyle=next(linecycler), c=self.__colors[i % len(self.__colors)])
             if i < len(fill_between):
-                ax.fill_between(x, data+fill_between[i], data-fill_between[i], facecolor=self.__colors[i % len(self.__colors)], alpha=0.5)
+                ax.fill_between(x, data + fill_between[i], data - fill_between[i], facecolor=self.__colors[i % len(self.__colors)], alpha=0.5)
             j = i
 
         for event in event_lookup.keys():
             j += 1
-            plt.axvline(x = event_lookup[event], color = self.__colors[j % len(self.__colors)], label = event, ls=next(linecycler))
-            
+            plt.axvline(x=event_lookup[event], color=self.__colors[j % len(self.__colors)], label=event, ls=next(linecycler))
+
         if plot_legend:
             plt.legend(loc=legend_loc)
 
@@ -166,7 +167,7 @@ class Plotter:
 
         if xlabel_rotation != None:
             plt.xticks(rotation=xlabel_rotation)
- 
+
         if not os.path.exists(FRAME_DIR):
             os.makedirs(FRAME_DIR)
 
@@ -181,29 +182,29 @@ class Plotter:
 
         plt.close()
 
-    def cluster_plot(self, 
-                     x, 
-                     y:list, 
-                     labels, 
-                     shape, 
-                     plots_shape, 
-                     file_name:str, 
-                     titles:list, 
+    def cluster_plot(self,
+                     x,
+                     y: list,
+                     labels,
+                     shape,
+                     plots_shape,
+                     file_name: str,
+                     titles: list,
                      number_xlabels=5,
-                     lw=3, 
+                     lw=3,
                      fill_between=[],
                      event_lookup={},
                      xlabel='',
                      ylabel='',
                      ylim=(None, None),
                      y_lim_exception=None,
-                     y_log_scale=False, 
-                     legend_loc=(0.5,0.992),
+                     y_log_scale=False,
+                     legend_loc=(0.5, 0.992),
                      add_y_space=0.05,
                      number_of_legend_columns=1,
                      same_axes=True,
                      free_axis=(None, None),
-                     plot_all_labels = True):
+                     plot_all_labels=True):
         real_shape = (shape[1] * plots_shape[0], shape[0] * plots_shape[1])
         fig, axes = plt.subplots(*shape, figsize=real_shape, sharex=same_axes, sharey=same_axes)
         plot_idx = 0
@@ -224,25 +225,25 @@ class Plotter:
 
                     space = int(len(x) / number_xlabels)
                     axes[i].xaxis.set_major_locator(ticker.MultipleLocator(space))
-                    axes[i].plot(x, 
-                                 data, 
+                    axes[i].plot(x,
+                                 data,
                                  linestyle=next(linecycler),
                                  label=labels[j],
                                  c=color,
                                  lw=lw)
                     axes[i].set_title(titles[i])
                     if j < len(fill_between[i]):
-                        axes[i].fill_between(x, data+fill_between[i][j], data-fill_between[i][j], facecolor=color, alpha=0.5)
+                        axes[i].fill_between(x, data + fill_between[i][j], data - fill_between[i][j], facecolor=color, alpha=0.5)
                 for event in event_lookup.keys():
-                    axes[i].axvline(x=event_lookup[event], 
-                                    color=next(colorcycler), 
-                                    label=event, 
+                    axes[i].axvline(x=event_lookup[event],
+                                    color=next(colorcycler),
+                                    label=event,
                                     ls=next(linecycler),
                                     lw=lw)
-            
+
                 if ylim[0] != None and y_lim_exception != i:
                     axes[i].set_ylim(ylim)
-                
+
                 if y_log_scale:
                     plt.yscale('log')
         else:
@@ -265,11 +266,11 @@ class Plotter:
                                     c = len(data)
                                 else:
                                     c = len(x)
-                                axes[i, j].plot(x[:c], 
-                                                data, 
-                                                label=labels[k], 
-                                                c=next(colorcycler), 
-                                                lw=lw, 
+                                axes[i, j].plot(x[:c],
+                                                data,
+                                                label=labels[k],
+                                                c=next(colorcycler),
+                                                lw=lw,
                                                 linestyle=next(linecycler))
                             axes[i, j].set_title(titles[plot_idx])
                             if ylim[0] != None:
@@ -292,23 +293,23 @@ class Plotter:
                 ax.label_outer()
 
         # Adjust layout to prevent overlap
-        plt.tight_layout(rect=[0, 0, 1, 1-add_y_space])
+        plt.tight_layout(rect=[0, 0, 1, 1 - add_y_space])
         plt.savefig(VISUALISATION_DIR + f'{file_name}.pdf')
 
-    def scatter(self, 
-                x, 
-                y:list, 
-                labels:list, 
-                figure_shape:tuple, 
-                file_name:str, 
-                title:str, 
-                std=[], 
+    def scatter(self,
+                x,
+                y: list,
+                labels: list,
+                figure_shape: tuple,
+                file_name: str,
+                title: str,
+                std=[],
                 true_values=[],
                 true_label='true',
-                plot_legend=True, 
+                plot_legend=True,
                 legend_loc='best',
                 xlabel='',
-                ylabel='', 
+                ylabel='',
                 xlabel_rotation=None):
         assert len(y) == len(labels), f"There must be the same amount of labels as there are plots.\nNumber Plots: {len(y)}\nNumber Labels: {len(labels)}"
         fig = self.__generate_figure(shape=figure_shape)
@@ -322,7 +323,7 @@ class Plotter:
 
         markercycler = cycle(self.__marker_styles)
         for i, array in enumerate(y):
-            
+
             if torch.is_tensor(array):
                 data = array.cpu().detach().numpy()
             else:
@@ -332,11 +333,11 @@ class Plotter:
                 ax.scatter(x, data, label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
             if i < len(std):
                 ax.errorbar(x, data, std[i], label=labels[i], c=self.__colors[i % len(self.__colors)], linestyle='None', marker=next(markercycler))
-        
+
         linecycler = cycle(self.__lines_styles)
         for i, true_value in enumerate(true_values):
-            ax.plot(x, np.ones_like(x, dtype='float64') * true_value, label=f'{true_label} {labels[i]}',c=self.__colors[i % len(self.__colors)], ls=next(linecycler))
-            
+            ax.plot(x, np.ones_like(x, dtype='float64') * true_value, label=f'{true_label} {labels[i]}', c=self.__colors[i % len(self.__colors)], ls=next(linecycler))
+
         if plot_legend:
             plt.legend(loc=legend_loc)
 

+ 33 - 27
src/preprocessing/synthetic_data.py

@@ -5,8 +5,9 @@ from scipy.integrate import odeint
 
 from src.plotter import Plotter
 
+
 class SyntheticDeseaseData:
-    def __init__(self, simulation_time:int, time_points:int, plotter:Plotter):
+    def __init__(self, simulation_time: int, time_points: int, plotter: Plotter):
         """This class is the parent class for every class, that is supposed to generate synthetic data.
 
         Args:
@@ -29,7 +30,7 @@ class SyntheticDeseaseData:
         """
         self.generated = True
 
-    def plot(self, labels: tuple, title:str, file_name:str):
+    def plot(self, labels: tuple, title: str, file_name: str, leave_out_indices):
         """Plot the data which was generated.
 
         Args:
@@ -37,13 +38,20 @@ class SyntheticDeseaseData:
             title (str): The name of the plot.
         """
         assert len(labels) == len(self.data), 'The number labels needs to be the same as the number of plots.'
+        groups = []
+        used_labels = []
+        for i, group in enumerate(self.data):
+            if not i in leave_out_indices:
+                groups.append(group)
+                used_labels.append(labels[i])
         if self.generated:
-            self.plotter.plot(self.t, self.data, labels, file_name, title, (6, 6), xlabel='time / days', ylabel='amount of people')
-        else: 
+            self.plotter.plot(self.t, groups, used_labels, file_name, title, (6, 6), xlabel='time / days', ylabel='amount of people')
+        else:
             print('Data has to be generated before plotting!')
 
+
 class SIR(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
+    def __init__(self, plotter: Plotter, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
         """This class is able to generate synthetic data for the SIR model.
 
         Args:
@@ -78,8 +86,8 @@ class SIR(SyntheticDeseaseData):
             tuple: Change amount for each group.
         """
         S, I, _ = y
-        dSdt = -self.beta * ((S * I) / self.N) # -self.beta * S * I
-        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I # self.beta * S * I - self.alpha * I
+        dSdt = -self.beta * ((S * I) / self.N)  # -self.beta * S * I
+        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I  # self.beta * S * I - self.alpha * I
         dRdt = self.alpha * I
         return dSdt, dIdt, dRdt
 
@@ -90,21 +98,22 @@ class SIR(SyntheticDeseaseData):
         self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
         super().generate()
 
-    def plot(self, title='', file_name='SIR_plot'):
+    def plot(self, title='', file_name='SIR_plot', leave_out_indices=[]):
         """Plot the data which was generated.
         """
-        super().plot(('Susceptible', 'Infectious', 'Removed'), title=title, file_name=file_name)
+        super().plot(('Susceptible', 'Infectious', 'Removed'), title=title, file_name=file_name, leave_out_indices=leave_out_indices)
 
     def save(self, name=''):
         if self.generated:
-            COVID_Data = np.asarray([self.t, *self.data]) 
+            COVID_Data = np.asarray([self.t, *self.data])
 
             np.savetxt('datasets/SIR_data.csv', COVID_Data, delimiter=",")
-        else: 
+        else:
             print('Data has to be generated before plotting!')
 
+
 class I(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N:int, C:int, I_0=1, time_points=100, alpha=1/3) -> None:
+    def __init__(self, plotter: Plotter, N: int, C: int, I_0=1, time_points=100, alpha=1 / 3) -> None:
         """This class is able to generate synthetic data of the SI groups for the reduced SIR model. This is done by utiling the SIR model.
 
         Args:
@@ -119,7 +128,7 @@ class I(SyntheticDeseaseData):
         self.N = N
         self.C = C
         self.I_0 = I_0
- 
+
         self.alpha = alpha
 
         self.t = np.linspace(0, 1, time_points)
@@ -129,14 +138,12 @@ class I(SyntheticDeseaseData):
         self.data = None
         self.generated = False
         self.plotter = plotter
-        
+
     def R_t(self, t):
         descaled_t = t * self.t_f
         # if descaled_t < threshold1:
         return -np.tanh(descaled_t * 0.05 - 2) * 0.4 + 1.35
 
-
-            
     def differential_eq(self, I, t):
         """In this function implements the differential equation of the SIR model will be implemented.
 
@@ -153,10 +160,10 @@ class I(SyntheticDeseaseData):
     def generate(self):
         """This funtion generates the data for this configuration of the SIR model.
         """
-        self.data = odeint(self.differential_eq, self.I_0/self.C, self.t).T
+        self.data = odeint(self.differential_eq, self.I_0 / self.C, self.t).T
         self.data = self.data[0] * self.C
         self.t_counter = 0
-        self.generated =True
+        self.generated = True
 
     def plot(self, title='', file_name=''):
         """Plot the data which was generated.
@@ -167,21 +174,20 @@ class I(SyntheticDeseaseData):
             for time in self.t:
                 self.reproduction_value.append(self.R_t(time))
             self.plotter.plot(t, [np.array(self.reproduction_value)], [r'$\mathcal{R}_t$'], file_name + '_r_t', title + r' $\mathcal{R}_t$', (6, 6), xlabel='time / days')
-        else: 
+        else:
             print('Data has to be generated before plotting!')
 
     def save(self, name=''):
         if self.generated:
-            COVID_Data = np.asarray([self.t_save, self.data]) 
+            COVID_Data = np.asarray([self.t_save, self.data])
 
             np.savetxt('datasets/I_data.csv', COVID_Data, delimiter=",")
-        else: 
+        else:
             print('Data has to be generated before plotting!')
 
-        
 
 class SIDR(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N=59e6, I_0=1, D_0=0, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05, gamma=0.0294) -> None:
+    def __init__(self, plotter: Plotter, N=59e6, I_0=1, D_0=0, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05, gamma=0.0294) -> None:
         """This class is able to generate synthetic data for the SIDR model.
 
         Args:
@@ -207,7 +213,7 @@ class SIDR(SyntheticDeseaseData):
         self.gamma = gamma
 
         super().__init__(simulation_time, time_points, plotter)
-    
+
     def differential_eq(self, y, t, alpha, beta, gamma):
         """In this function implements the differential equation of the SIDR model will be implemented.
 
@@ -223,7 +229,7 @@ class SIDR(SyntheticDeseaseData):
         """
         S, I, D, R = y
         dSdt = - (self.alpha / self.N) * S * I
-        dIdt = (self.alpha / self.N) * S * I - self.beta * I - self.gamma * I 
+        dIdt = (self.alpha / self.N) * S * I - self.beta * I - self.gamma * I
         dDdt = self.gamma * I
         dRdt = self.beta * I
         return dSdt, dIdt, dDdt, dRdt
@@ -242,8 +248,8 @@ class SIDR(SyntheticDeseaseData):
 
     def save(self, name=''):
         if self.generated:
-            COVID_Data = np.asarray([self.t, *self.data]) 
+            COVID_Data = np.asarray([self.t, *self.data])
 
             np.savetxt('datasets/SIDR_data.csv', COVID_Data, delimiter=",")
-        else: 
+        else:
             print('Data has to be generated before plotting!')

+ 176 - 49
src/preprocessing/transform_data.py

@@ -1,72 +1,124 @@
 import numpy as np
 import pandas as pd
+from datetime import date, timedelta
 
 from src.plotter import Plotter
 
-state_lookup = {'Schleswig Holstein' : (1, 2897000),
-                'Hamburg' : (2, 1841000), 
-                'Niedersachsen' : (3, 7982000), 
-                'Bremen' : (4, 569352),
-                'Nordrhein-Westfalen' : (5, 17930000),
-                'Hessen' : (6, 6266000),
-                'Rheinland-Pfalz' : (7, 4085000),
-                'Baden-Württemberg' : (8, 11070000),
-                'Bayern' : (9, 13080000),
-                'Saarland' : (10, 990509),
-                'Berlin' : (11, 3645000),
-                'Brandenburg' : (12, 2641000),
-                'Mecklenburg-Vorpommern' : (13, 1610000),
-                'Sachsen' : (14, 4078000),
-                'Sachsen-Anhalt' : (15, 2208000),
-                'Thüringen' : (16, 2143000)}
-
-def transform_data(plotter:Plotter, alpha=1/14, state_name='Germany', time_range=1200, plot_name='', plot_title='', sample_rate=1, model='SIR', plot_size=(12,6), yscale_log=False, plot_legend=True):
+state_lookup = {'Schleswig Holstein': (1, 2897000),
+                'Hamburg': (2, 1841000),
+                'Niedersachsen': (3, 7982000),
+                'Bremen': (4, 569352),
+                'Nordrhein-Westfalen': (5, 17930000),
+                'Hessen': (6, 6266000),
+                'Rheinland-Pfalz': (7, 4085000),
+                'Baden-Württemberg': (8, 11070000),
+                'Bayern': (9, 13080000),
+                'Saarland': (10, 990509),
+                'Berlin': (11, 3645000),
+                'Brandenburg': (12, 2641000),
+                'Mecklenburg-Vorpommern': (13, 1610000),
+                'Sachsen': (14, 4078000),
+                'Sachsen-Anhalt': (15, 2208000),
+                'Thüringen': (16, 2143000)}
+
+
+def daterange(start_date: date, end_date: date):
+    days = int((end_date - start_date).days)
+    for n in range(days):
+        yield start_date + timedelta(n)
+
+
+def transform_jh_germany_data(plotter: Plotter,
+                              time_range=50,
+                              sample_rate=1,
+                              model='SIR'):
+    N = 83100000
+    state_name = 'Germany'
+    infections = np.zeros(time_range)
+    deaths = np.zeros(time_range)
+    recoveries = np.zeros(time_range)
+
+    # extract data
+    data_directory = 'datasets/COVID-19/csse_covid_19_data/csse_covid_19_daily_reports'
+    start_date = date(2020, 1, 31)
+    end_date = date(2020, 3, 20)
+    for i, single_date in enumerate(daterange(start_date, end_date)):
+        file_date = single_date.strftime("%m-%d-%Y")
+        date_df = pd.read_csv(data_directory + "/" + file_date + ".csv")
+        date_df = date_df.loc[date_df['Country/Region'] == state_name]
+
+        infections[i] = date_df['Confirmed'].fillna(0).astype(int)
+        deaths[i] = date_df['Deaths'].fillna(0).astype(int)
+        recoveries[i] = date_df['Recovered'].fillna(0).astype(int)
+
+    S, I, R = np.zeros(infections.shape[0]), np.zeros(
+        infections.shape[0]), np.zeros(infections.shape[0])
+    S[0] = N - infections[0]
+    I[0] = infections[0]
+    R[0] = 0
+
+    for day in range(1, time_range):
+        S[day] = S[day - 1] - infections[day]
+        I[day] = I[day - 1] + infections[day] - deaths[day] - recoveries[day]
+        R[day] = R[day - 1] + deaths[day] + recoveries[day]
+        if I[day] < 0:
+            I[day] = 0
+
+    t = np.arange(0, time_range, 1)
+
+    plotter.plot(t, [I, R], ["I", "R"], "JH_data", "JH Data", (6, 6))
+
+    groups = [S, I, R]
+    COVID_Data = np.asarray([t[0::sample_rate]] +
+                            [group[0::sample_rate] for group in groups])
+
+    np.savetxt(
+        f"datasets/{model}_JH_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}.csv", COVID_Data, delimiter=",")
+
+
+def transform_data(plotter: Plotter, alpha=1 / 14, state_name='Germany', time_range=1200, sample_rate=1, model='SIR'):
     """Function to generate the SIR split from the data in the COVID-19-Todesfaelle_in_Deutschland dataset.
 
     Args:
         plotter (Plotter): Plotter object to plot dataset curves.
         dataset_path (str, optional): Path to the dataset directory. Defaults to 'datasets/COVID-19-Todesfaelle_in_Deutschland/'.
-        plot_name (str, optional): Name of the plot file. Defaults to ''.
-        plot_title (str, optional): Title of the plot. Defaults to ''.
         sample_rate (int, optional): Sample rate used to sample the timepoints. Defaults to 1.
         exclude (list, optional): List of groups that are to excluded from the plot. Defaults to [].
-        plot_size (tuple, optional): Size of the plot in (x, y) format. Defaults to (12,6).
-        yscale_log (bool, optional): Controls if the y axis of the plot will be scaled in log scale. Defaults to False.
-        plot_legend (bool, optional): Controls if the legend is to be plotted. Defaults to True.
     """
     # read the data
 
-
     infections = np.zeros(time_range)
     deaths = np.zeros(time_range)
     recoveries = np.zeros(time_range)
     if state_name == 'Germany':
-        df = pd.read_csv('datasets/COVID-19-Todesfaelle_in_Deutschland/COVID-19-Todesfaelle_Deutschland.csv')
+        df = pd.read_csv(
+            'datasets/COVID-19-Todesfaelle_in_Deutschland/COVID-19-Todesfaelle_Deutschland.csv')
         N = 83100000
         infections[0] = df['Faelle_gesamt'][0]
         deaths[0] = df['Todesfaelle_neu'][0]
 
         recovery_queue = np.zeros(14)
         for i in range(1, time_range):
-            infections[i] = df['Faelle_gesamt'][i] - df['Faelle_gesamt'][i-1]
+            infections[i] = df['Faelle_gesamt'][i] - df['Faelle_gesamt'][i - 1]
             deaths[i] = df['Todesfaelle_neu'][i]
             recoveries[i] = recovery_queue[0]
 
             recovery_queue[:-1] = recovery_queue[1:]
             recovery_queue[-1] = infections[i]
     else:
-        df = pd.read_csv('datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv')
+        df = pd.read_csv(
+            'datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv')
         state_ID, N = state_lookup[state_name]
 
         # single out a state
         state_IDs = df['IdLandkreis'] // 1000
         df = df.loc[state_IDs == state_ID]
 
-        # sort entries by state
+        # sort entries by date
         df = df.sort_values('Refdatum')
         df = df.reset_index(drop=True)
 
-        # collect cases    
+        # collect cases
         entry_idx = 0
         day = 0
         date = df['Refdatum'][entry_idx]
@@ -78,7 +130,8 @@ def transform_data(plotter:Plotter, alpha=1/14, state_name='Germany', time_range
                 deaths[day] += df['AnzahlTodesfall'][entry_idx]
                 entry_idx += 1
             # move day index by difference between the current and next date
-            day += (pd.to_datetime(df['Refdatum'][entry_idx])-pd.to_datetime(date)).days
+            day += (pd.to_datetime(df['Refdatum']
+                    [entry_idx]) - pd.to_datetime(date)).days
             date = df['Refdatum'][entry_idx]
 
         recovery_queue = np.zeros(14)
@@ -89,48 +142,122 @@ def transform_data(plotter:Plotter, alpha=1/14, state_name='Germany', time_range
             recovery_queue[:-1] = recovery_queue[1:]
             recovery_queue[-1] = infections[i]
             week_counter -= 1
-        
+
     df = df.drop(df.index[time_range:])
-    S, I, R = np.zeros(df.shape[0]), np.zeros(df.shape[0]), np.zeros(df.shape[0])
+    S, I, R = np.zeros(df.shape[0]), np.zeros(
+        df.shape[0]), np.zeros(df.shape[0])
     # generate groups
     S[0] = N - infections[0]
     I[0] = infections[0]
     R[0] = 0
     if model == 'I':
         for day in range(1, time_range):
-            S[day] = S[day-1] - infections[day]
-            I[day] = I[day-1] + infections[day] - I[day-1] * alpha
-            R[day] = R[day-1] + I[day-1] * alpha
+            S[day] = S[day - 1] - infections[day]
+            I[day] = I[day - 1] + infections[day] - I[day - 1] * alpha
+            R[day] = R[day - 1] + I[day - 1] * alpha
     else:
         for day in range(1, time_range):
-            S[day] = S[day-1] - infections[day]
-            I[day] = I[day-1] + infections[day] - deaths[day] - recoveries[day]
-            R[day] = R[day-1] + deaths[day] + recoveries[day]
+            S[day] = S[day - 1] - infections[day]
+            I[day] = I[day - 1] + infections[day] - \
+                deaths[day] - recoveries[day]
+            R[day] = R[day - 1] + deaths[day] + recoveries[day]
             if I[day] < 0:
                 I[day] = 0
-    
+
     t = np.arange(0, time_range, 1)
 
     # select, which group is to be outputted
     groups = []
     if 'S' in model:
         groups.append(S)
-    
+
     if 'I' in model:
         groups.append(I)
 
     if 'R' in model:
         groups.append(R)
 
-    plotter.plot(t, 
-                 groups, 
-                 [*model], 
-                 state_name.replace(' ', '_').replace('-', '_').replace('ü','ue') + f"_{model}" + f"_{int(1/alpha)}", 
-                 state_name, 
-                 (6,6), 
-                 xlabel='time / days', 
+    plotter.plot(t,
+                 groups,
+                 [*model],
+                 state_name.replace(' ', '_').replace(
+                     '-', '_').replace('ü', 'ue') + f"_{model}" + f"_{int(1/alpha)}",
+                 state_name,
+                 (6, 6),
+                 xlabel='time / days',
                  ylabel='amount of people')
 
-    COVID_Data = np.asarray([t[0::sample_rate]] + [group[0::sample_rate] for group in groups]) 
+    COVID_Data = np.asarray([t[0::sample_rate]] +
+                            [group[0::sample_rate] for group in groups])
+
+    np.savetxt(
+        f"datasets/{model}_RKI_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}_{int(1/alpha)}.csv", COVID_Data, delimiter=",")
+
+
+def transform_paper_data():
+    N = 70000000
+    time_range = 36
+    alpha = 0.07
+    state_name = 'Germany'
+
+    infections = np.zeros(time_range)
+    deaths = np.zeros(time_range)
+    recoveries = np.zeros(time_range)
+    # Data
+    data = [
+        [1.30000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.40000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.50000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.60000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.70000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.80000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.90000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.00000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.10000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.20000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.30000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.40000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.50000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.60000000e+01, 2.00000000e+00, 1.70000000e+01],
+        [2.70000000e+01, 2.00000000e+00, 2.10000000e+01],
+        [2.80000000e+01, 2.00000000e+00, 4.70000000e+01],
+        [2.90000000e+01, 2.00000000e+00, 5.70000000e+01],
+        [1.00000000e+00, 3.00000000e+00, 1.11000000e+02],
+        [2.00000000e+00, 3.00000000e+00, 1.29000000e+02],
+        [3.00000000e+00, 3.00000000e+00, 1.57000000e+02],
+        [4.00000000e+00, 3.00000000e+00, 1.96000000e+02],
+        [5.00000000e+00, 3.00000000e+00, 2.62000000e+02],
+        [6.00000000e+00, 3.00000000e+00, 4.00000000e+02],
+        [7.00000000e+00, 3.00000000e+00, 6.84000000e+02],
+        [8.00000000e+00, 3.00000000e+00, 8.47000000e+02],
+        [9.00000000e+00, 3.00000000e+00, 9.02000000e+02],
+        [1.00000000e+01, 3.00000000e+00, 1.13900000e+03],
+        [1.10000000e+01, 3.00000000e+00, 1.29600000e+03],
+        [1.20000000e+01, 3.00000000e+00, 1.56700000e+03],
+        [1.30000000e+01, 3.00000000e+00, 2.36900000e+03],
+        [1.40000000e+01, 3.00000000e+00, 3.06200000e+03],
+        [1.50000000e+01, 3.00000000e+00, 3.79500000e+03],
+        [1.60000000e+01, 3.00000000e+00, 4.83800000e+03],
+        [1.70000000e+01, 3.00000000e+00, 6.01200000e+03],
+        [1.80000000e+01, 3.00000000e+00, 7.15600000e+03],
+        [1.90000000e+01, 3.00000000e+00, 8.19800000e+03],
+    ]
+
+    # Creating a Pandas DataFrame
+    df = pd.DataFrame(data, columns=["Day", "Month", "Infected people"])
+    S, I, R = np.zeros(df.shape[0]), np.zeros(
+        df.shape[0]), np.zeros(df.shape[0])
+    # generate groups
+    S[0] = N - infections[0]
+    I[0] = infections[0]
+    R[0] = 0
+    for day in range(1, time_range):
+        S[day] = S[day - 1] - df["Infected people"][day]
+        I[day] = I[day - 1] + df["Infected people"][day] - I[day - 1] * alpha
+        R[day] = R[day - 1] + I[day - 1] * alpha
+
+    COVID_Data = np.asarray([np.arange(0, time_range, 1)] +
+                            [S, I, R])
 
-    np.savetxt(f"datasets/{model}_RKI_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}_{int(1/alpha)}.csv", COVID_Data, delimiter=",")
+    np.savetxt(
+        f"datasets/SIR_Paper_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{int(1/alpha)}.csv", COVID_Data, delimiter=",")

+ 32 - 4
src/problem.py

@@ -1,6 +1,7 @@
 import torch
 from .dataset import PandemicDataset
 
+
 class PandemicProblem:
     def __init__(self, data: PandemicDataset) -> None:
         """Parent class for all pandemic problem classes. Holding the function, that calculates the residuals of the differential system.
@@ -18,14 +19,14 @@ class PandemicProblem:
         """NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS
         """
         assert self._gradients != None, 'Gradientmatrix need to be defined'
-        
 
-    def def_grad_matrix(self, number:int):
+    def def_grad_matrix(self, number: int):
         assert self._gradients == None, 'Gradientmatrix is already defined'
         self._gradients = [torch.zeros((len(self._data.t_raw), number), device=self._device_name) for _ in range(number)]
         for i in range(number):
             self._gradients[i][:, i] = 1
 
+
 class SIRProblem(PandemicProblem):
     def __init__(self, data: PandemicDataset):
         super().__init__(data)
@@ -53,8 +54,36 @@ class SIRProblem(PandemicProblem):
         return S_residual, I_residual, R_residual
 
 
+class SIRAlphaProblem(PandemicProblem):
+    def __init__(self, data: PandemicDataset, alpha):
+        super().__init__(data)
+        self.alpha = alpha
+
+    def residual(self, SIR_pred, beta):
+        super().residual()
+        SIR_pred.backward(self._gradients[0], retain_graph=True)
+        dSdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
+
+        SIR_pred.backward(self._gradients[1], retain_graph=True)
+        dIdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
+
+        SIR_pred.backward(self._gradients[2], retain_graph=True)
+        dRdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
+
+        S, I, _ = self._data.get_denormalized_data([SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]])
+
+        S_residual = dSdt - (-beta * ((S * I) / self._data.N)) / (self._data.get_max('S') - self._data.get_min('S'))
+        I_residual = dIdt - (beta * ((S * I) / self._data.N) - self.alpha * I) / (self._data.get_max('I') - self._data.get_min('I'))
+        R_residual = dRdt - (self.alpha * I) / (self._data.get_max('R') - self._data.get_min('R'))
+
+        return S_residual, I_residual, R_residual
+
+
 class ReducedSIRProblem(PandemicProblem):
-    def __init__(self, data: PandemicDataset, alpha:float):
+    def __init__(self, data: PandemicDataset, alpha: float):
         super().__init__(data)
         self.alpha = alpha
 
@@ -72,4 +101,3 @@ class ReducedSIRProblem(PandemicProblem):
 
         I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
         return I_residual
-

+ 56 - 56
states_training.py

@@ -7,7 +7,7 @@ from src.dataset import PandemicDataset, Norms
 from src.problem import ReducedSIRProblem
 from src.dinn import DINN, Scheduler, Activation
 
-ALPHA = [1/14, 1/5]
+ALPHA = [1 / 14, 1 / 5]
 DO_STATES = True
 DO_SYNTHETIC = False
 
@@ -16,66 +16,66 @@ ITERATIONS = 13
 state_starting_index = 0
 
 if "1" in sys.argv:
-     state_starting_index = 8
-
-STATE_LOOKUP = {'Schleswig_Holstein' : 2897000,
-                'Hamburg' : 1841000, 
-                'Niedersachsen' : 7982000, 
-                'Bremen' : 569352,
-                'Nordrhein_Westfalen' : 17930000,
-                'Hessen' : 6266000,
-                'Rheinland_Pfalz' : 4085000,
-                'Baden_Wuerttemberg' : 11070000,
-                'Bayern' : 13080000,
-                'Saarland' : 990509,
-                'Berlin' : 3645000,
-                'Brandenburg' : 2641000,
-                'Mecklenburg_Vorpommern' : 1610000,
-                'Sachsen' : 4078000,
-                'Sachsen_Anhalt' : 2208000,
-                'Thueringen' : 2143000}
+    state_starting_index = 8
+
+STATE_LOOKUP = {'Schleswig_Holstein': 2897000,
+                'Hamburg': 1841000,
+                'Niedersachsen': 7982000,
+                'Bremen': 569352,
+                'Nordrhein_Westfalen': 17930000,
+                'Hessen': 6266000,
+                'Rheinland_Pfalz': 4085000,
+                'Baden_Wuerttemberg': 11070000,
+                'Bayern': 13080000,
+                'Saarland': 990509,
+                'Berlin': 3645000,
+                'Brandenburg': 2641000,
+                'Mecklenburg_Vorpommern': 1610000,
+                'Sachsen': 4078000,
+                'Sachsen_Anhalt': 2208000,
+                'Thueringen': 2143000}
 
 if DO_SYNTHETIC:
-    alpha = 1/3
+    alpha = 1 / 3
     covid_data = np.genfromtxt(f'./datasets/I_data.csv', delimiter=',')
     for i in range(ITERATIONS):
-        dataset = PandemicDataset('Synthetic I', 
-                                    ['I'], 
-                                    7.6e6, 
-                                    *covid_data, 
-                                    norm_name=Norms.CONSTANT, 
-                                    use_scaled_time=True)
+        dataset = PandemicDataset('Synthetic I',
+                                  ['I'],
+                                  7.6e6,
+                                  *covid_data,
+                                  norm_name=Norms.CONSTANT,
+                                  use_scaled_time=True)
 
         problem = ReducedSIRProblem(dataset, alpha)
-        dinn = DINN(2, 
-                    dataset, 
-                    [], 
-                    problem, 
-                    None, 
-                    state_variables=['R_t'], 
-                    hidden_size=100, 
-                    hidden_layers=4, 
+        dinn = DINN(2,
+                    dataset,
+                    [],
+                    problem,
+                    None,
+                    state_variables=['R_t'],
+                    hidden_size=100,
+                    hidden_layers=4,
                     activation_layer=torch.nn.Tanh(),
                     activation_output=Activation.POWER)
 
-        dinn.configure_training(1e-3, 
-                                20000, 
-                                scheduler_class=Scheduler.POLYNOMIAL, 
+        dinn.configure_training(1e-3,
+                                20000,
+                                scheduler_class=Scheduler.POLYNOMIAL,
                                 lambda_physics=1e-6,
                                 verbose=True)
         dinn.train(verbose=True, do_split_training=True)
 
         dinn.save_training_process(f'synthetic_{i}')
-        #r_t = dinn.get_output(1).detach().cpu().numpy()
+        # r_t = dinn.get_output(1).detach().cpu().numpy()
 
-        #with open(f'./results/synthetic_{i}.csv', 'w', newline='') as csvfile:
-                #writer = csv.writer(csvfile, delimiter=',')
-                #writer.writerow(r_t)
+        # with open(f'./results/synthetic_{i}.csv', 'w', newline='') as csvfile:
+        # writer = csv.writer(csvfile, delimiter=',')
+        # writer.writerow(r_t)
 
 for iteration in range(ITERATIONS):
     if iteration <= 2:
-         print('skip first three iteration, as it was already done')
-         continue
+        print('skip first three iteration, as it was already done')
+        continue
     if DO_STATES:
         for state_idx in range(state_starting_index, state_starting_index + 8):
             state = list(STATE_LOOKUP.keys())[state_idx]
@@ -91,22 +91,22 @@ for iteration in range(ITERATIONS):
 
                 problem = ReducedSIRProblem(dataset, alpha)
 
-                dinn = DINN(2, 
-                            dataset, 
-                            [], 
-                            problem, 
-                            None, 
-                            state_variables=['R_t'], 
-                            hidden_size=100, 
-                            hidden_layers=4, 
+                dinn = DINN(2,
+                            dataset,
+                            [],
+                            problem,
+                            None,
+                            state_variables=['R_t'],
+                            hidden_size=100,
+                            hidden_layers=4,
                             activation_layer=torch.nn.Tanh(),
                             activation_output=Activation.POWER)
 
-                dinn.configure_training(1e-3, 
-                                        25000, 
+                dinn.configure_training(1e-3,
+                                        25000,
                                         scheduler_class=Scheduler.POLYNOMIAL,
-                                        lambda_obs=1e2, 
-                                        lambda_physics=1e-6, 
+                                        lambda_obs=1e2,
+                                        lambda_physics=1e-6,
                                         verbose=True)
                 dinn.train(verbose=True, do_split_training=True)
 
@@ -115,4 +115,4 @@ for iteration in range(ITERATIONS):
                 r_t = dinn.get_output(1).detach().cpu().numpy()
                 with open(f'./results/{state}_{i}_{iteration}.csv', 'w', newline='') as csvfile:
                     writer = csv.writer(csvfile, delimiter=',')
-                    writer.writerow(r_t)
+                    writer.writerow(r_t)