Browse Source

get data from paper

phillip.rothenbeck 4 months ago
parent
commit
84e7847058
2 changed files with 209 additions and 76 deletions
  1. 33 27
      src/preprocessing/synthetic_data.py
  2. 176 49
      src/preprocessing/transform_data.py

+ 33 - 27
src/preprocessing/synthetic_data.py

@@ -5,8 +5,9 @@ from scipy.integrate import odeint
 
 
 from src.plotter import Plotter
 from src.plotter import Plotter
 
 
+
 class SyntheticDeseaseData:
 class SyntheticDeseaseData:
-    def __init__(self, simulation_time:int, time_points:int, plotter:Plotter):
+    def __init__(self, simulation_time: int, time_points: int, plotter: Plotter):
         """This class is the parent class for every class, that is supposed to generate synthetic data.
         """This class is the parent class for every class, that is supposed to generate synthetic data.
 
 
         Args:
         Args:
@@ -29,7 +30,7 @@ class SyntheticDeseaseData:
         """
         """
         self.generated = True
         self.generated = True
 
 
-    def plot(self, labels: tuple, title:str, file_name:str):
+    def plot(self, labels: tuple, title: str, file_name: str, leave_out_indices):
         """Plot the data which was generated.
         """Plot the data which was generated.
 
 
         Args:
         Args:
@@ -37,13 +38,20 @@ class SyntheticDeseaseData:
             title (str): The name of the plot.
             title (str): The name of the plot.
         """
         """
         assert len(labels) == len(self.data), 'The number labels needs to be the same as the number of plots.'
         assert len(labels) == len(self.data), 'The number labels needs to be the same as the number of plots.'
+        groups = []
+        used_labels = []
+        for i, group in enumerate(self.data):
+            if not i in leave_out_indices:
+                groups.append(group)
+                used_labels.append(labels[i])
         if self.generated:
         if self.generated:
-            self.plotter.plot(self.t, self.data, labels, file_name, title, (6, 6), xlabel='time / days', ylabel='amount of people')
-        else: 
+            self.plotter.plot(self.t, groups, used_labels, file_name, title, (6, 6), xlabel='time / days', ylabel='amount of people')
+        else:
             print('Data has to be generated before plotting!')
             print('Data has to be generated before plotting!')
 
 
+
 class SIR(SyntheticDeseaseData):
 class SIR(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
+    def __init__(self, plotter: Plotter, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
         """This class is able to generate synthetic data for the SIR model.
         """This class is able to generate synthetic data for the SIR model.
 
 
         Args:
         Args:
@@ -78,8 +86,8 @@ class SIR(SyntheticDeseaseData):
             tuple: Change amount for each group.
             tuple: Change amount for each group.
         """
         """
         S, I, _ = y
         S, I, _ = y
-        dSdt = -self.beta * ((S * I) / self.N) # -self.beta * S * I
-        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I # self.beta * S * I - self.alpha * I
+        dSdt = -self.beta * ((S * I) / self.N)  # -self.beta * S * I
+        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I  # self.beta * S * I - self.alpha * I
         dRdt = self.alpha * I
         dRdt = self.alpha * I
         return dSdt, dIdt, dRdt
         return dSdt, dIdt, dRdt
 
 
@@ -90,21 +98,22 @@ class SIR(SyntheticDeseaseData):
         self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
         self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
         super().generate()
         super().generate()
 
 
-    def plot(self, title='', file_name='SIR_plot'):
+    def plot(self, title='', file_name='SIR_plot', leave_out_indices=[]):
         """Plot the data which was generated.
         """Plot the data which was generated.
         """
         """
-        super().plot(('Susceptible', 'Infectious', 'Removed'), title=title, file_name=file_name)
+        super().plot(('Susceptible', 'Infectious', 'Removed'), title=title, file_name=file_name, leave_out_indices=leave_out_indices)
 
 
     def save(self, name=''):
     def save(self, name=''):
         if self.generated:
         if self.generated:
-            COVID_Data = np.asarray([self.t, *self.data]) 
+            COVID_Data = np.asarray([self.t, *self.data])
 
 
             np.savetxt('datasets/SIR_data.csv', COVID_Data, delimiter=",")
             np.savetxt('datasets/SIR_data.csv', COVID_Data, delimiter=",")
-        else: 
+        else:
             print('Data has to be generated before plotting!')
             print('Data has to be generated before plotting!')
 
 
+
 class I(SyntheticDeseaseData):
 class I(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N:int, C:int, I_0=1, time_points=100, alpha=1/3) -> None:
+    def __init__(self, plotter: Plotter, N: int, C: int, I_0=1, time_points=100, alpha=1 / 3) -> None:
         """This class is able to generate synthetic data of the SI groups for the reduced SIR model. This is done by utiling the SIR model.
         """This class is able to generate synthetic data of the SI groups for the reduced SIR model. This is done by utiling the SIR model.
 
 
         Args:
         Args:
@@ -119,7 +128,7 @@ class I(SyntheticDeseaseData):
         self.N = N
         self.N = N
         self.C = C
         self.C = C
         self.I_0 = I_0
         self.I_0 = I_0
- 
+
         self.alpha = alpha
         self.alpha = alpha
 
 
         self.t = np.linspace(0, 1, time_points)
         self.t = np.linspace(0, 1, time_points)
@@ -129,14 +138,12 @@ class I(SyntheticDeseaseData):
         self.data = None
         self.data = None
         self.generated = False
         self.generated = False
         self.plotter = plotter
         self.plotter = plotter
-        
+
     def R_t(self, t):
     def R_t(self, t):
         descaled_t = t * self.t_f
         descaled_t = t * self.t_f
         # if descaled_t < threshold1:
         # if descaled_t < threshold1:
         return -np.tanh(descaled_t * 0.05 - 2) * 0.4 + 1.35
         return -np.tanh(descaled_t * 0.05 - 2) * 0.4 + 1.35
 
 
-
-            
     def differential_eq(self, I, t):
     def differential_eq(self, I, t):
         """In this function implements the differential equation of the SIR model will be implemented.
         """In this function implements the differential equation of the SIR model will be implemented.
 
 
@@ -153,10 +160,10 @@ class I(SyntheticDeseaseData):
     def generate(self):
     def generate(self):
         """This funtion generates the data for this configuration of the SIR model.
         """This funtion generates the data for this configuration of the SIR model.
         """
         """
-        self.data = odeint(self.differential_eq, self.I_0/self.C, self.t).T
+        self.data = odeint(self.differential_eq, self.I_0 / self.C, self.t).T
         self.data = self.data[0] * self.C
         self.data = self.data[0] * self.C
         self.t_counter = 0
         self.t_counter = 0
-        self.generated =True
+        self.generated = True
 
 
     def plot(self, title='', file_name=''):
     def plot(self, title='', file_name=''):
         """Plot the data which was generated.
         """Plot the data which was generated.
@@ -167,21 +174,20 @@ class I(SyntheticDeseaseData):
             for time in self.t:
             for time in self.t:
                 self.reproduction_value.append(self.R_t(time))
                 self.reproduction_value.append(self.R_t(time))
             self.plotter.plot(t, [np.array(self.reproduction_value)], [r'$\mathcal{R}_t$'], file_name + '_r_t', title + r' $\mathcal{R}_t$', (6, 6), xlabel='time / days')
             self.plotter.plot(t, [np.array(self.reproduction_value)], [r'$\mathcal{R}_t$'], file_name + '_r_t', title + r' $\mathcal{R}_t$', (6, 6), xlabel='time / days')
-        else: 
+        else:
             print('Data has to be generated before plotting!')
             print('Data has to be generated before plotting!')
 
 
     def save(self, name=''):
     def save(self, name=''):
         if self.generated:
         if self.generated:
-            COVID_Data = np.asarray([self.t_save, self.data]) 
+            COVID_Data = np.asarray([self.t_save, self.data])
 
 
             np.savetxt('datasets/I_data.csv', COVID_Data, delimiter=",")
             np.savetxt('datasets/I_data.csv', COVID_Data, delimiter=",")
-        else: 
+        else:
             print('Data has to be generated before plotting!')
             print('Data has to be generated before plotting!')
 
 
-        
 
 
 class SIDR(SyntheticDeseaseData):
 class SIDR(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N=59e6, I_0=1, D_0=0, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05, gamma=0.0294) -> None:
+    def __init__(self, plotter: Plotter, N=59e6, I_0=1, D_0=0, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05, gamma=0.0294) -> None:
         """This class is able to generate synthetic data for the SIDR model.
         """This class is able to generate synthetic data for the SIDR model.
 
 
         Args:
         Args:
@@ -207,7 +213,7 @@ class SIDR(SyntheticDeseaseData):
         self.gamma = gamma
         self.gamma = gamma
 
 
         super().__init__(simulation_time, time_points, plotter)
         super().__init__(simulation_time, time_points, plotter)
-    
+
     def differential_eq(self, y, t, alpha, beta, gamma):
     def differential_eq(self, y, t, alpha, beta, gamma):
         """In this function implements the differential equation of the SIDR model will be implemented.
         """In this function implements the differential equation of the SIDR model will be implemented.
 
 
@@ -223,7 +229,7 @@ class SIDR(SyntheticDeseaseData):
         """
         """
         S, I, D, R = y
         S, I, D, R = y
         dSdt = - (self.alpha / self.N) * S * I
         dSdt = - (self.alpha / self.N) * S * I
-        dIdt = (self.alpha / self.N) * S * I - self.beta * I - self.gamma * I 
+        dIdt = (self.alpha / self.N) * S * I - self.beta * I - self.gamma * I
         dDdt = self.gamma * I
         dDdt = self.gamma * I
         dRdt = self.beta * I
         dRdt = self.beta * I
         return dSdt, dIdt, dDdt, dRdt
         return dSdt, dIdt, dDdt, dRdt
@@ -242,8 +248,8 @@ class SIDR(SyntheticDeseaseData):
 
 
     def save(self, name=''):
     def save(self, name=''):
         if self.generated:
         if self.generated:
-            COVID_Data = np.asarray([self.t, *self.data]) 
+            COVID_Data = np.asarray([self.t, *self.data])
 
 
             np.savetxt('datasets/SIDR_data.csv', COVID_Data, delimiter=",")
             np.savetxt('datasets/SIDR_data.csv', COVID_Data, delimiter=",")
-        else: 
+        else:
             print('Data has to be generated before plotting!')
             print('Data has to be generated before plotting!')

+ 176 - 49
src/preprocessing/transform_data.py

@@ -1,72 +1,124 @@
 import numpy as np
 import numpy as np
 import pandas as pd
 import pandas as pd
+from datetime import date, timedelta
 
 
 from src.plotter import Plotter
 from src.plotter import Plotter
 
 
-state_lookup = {'Schleswig Holstein' : (1, 2897000),
-                'Hamburg' : (2, 1841000), 
-                'Niedersachsen' : (3, 7982000), 
-                'Bremen' : (4, 569352),
-                'Nordrhein-Westfalen' : (5, 17930000),
-                'Hessen' : (6, 6266000),
-                'Rheinland-Pfalz' : (7, 4085000),
-                'Baden-Württemberg' : (8, 11070000),
-                'Bayern' : (9, 13080000),
-                'Saarland' : (10, 990509),
-                'Berlin' : (11, 3645000),
-                'Brandenburg' : (12, 2641000),
-                'Mecklenburg-Vorpommern' : (13, 1610000),
-                'Sachsen' : (14, 4078000),
-                'Sachsen-Anhalt' : (15, 2208000),
-                'Thüringen' : (16, 2143000)}
-
-def transform_data(plotter:Plotter, alpha=1/14, state_name='Germany', time_range=1200, plot_name='', plot_title='', sample_rate=1, model='SIR', plot_size=(12,6), yscale_log=False, plot_legend=True):
+state_lookup = {'Schleswig Holstein': (1, 2897000),
+                'Hamburg': (2, 1841000),
+                'Niedersachsen': (3, 7982000),
+                'Bremen': (4, 569352),
+                'Nordrhein-Westfalen': (5, 17930000),
+                'Hessen': (6, 6266000),
+                'Rheinland-Pfalz': (7, 4085000),
+                'Baden-Württemberg': (8, 11070000),
+                'Bayern': (9, 13080000),
+                'Saarland': (10, 990509),
+                'Berlin': (11, 3645000),
+                'Brandenburg': (12, 2641000),
+                'Mecklenburg-Vorpommern': (13, 1610000),
+                'Sachsen': (14, 4078000),
+                'Sachsen-Anhalt': (15, 2208000),
+                'Thüringen': (16, 2143000)}
+
+
+def daterange(start_date: date, end_date: date):
+    days = int((end_date - start_date).days)
+    for n in range(days):
+        yield start_date + timedelta(n)
+
+
+def transform_jh_germany_data(plotter: Plotter,
+                              time_range=50,
+                              sample_rate=1,
+                              model='SIR'):
+    N = 83100000
+    state_name = 'Germany'
+    infections = np.zeros(time_range)
+    deaths = np.zeros(time_range)
+    recoveries = np.zeros(time_range)
+
+    # extract data
+    data_directory = 'datasets/COVID-19/csse_covid_19_data/csse_covid_19_daily_reports'
+    start_date = date(2020, 1, 31)
+    end_date = date(2020, 3, 20)
+    for i, single_date in enumerate(daterange(start_date, end_date)):
+        file_date = single_date.strftime("%m-%d-%Y")
+        date_df = pd.read_csv(data_directory + "/" + file_date + ".csv")
+        date_df = date_df.loc[date_df['Country/Region'] == state_name]
+
+        infections[i] = date_df['Confirmed'].fillna(0).astype(int)
+        deaths[i] = date_df['Deaths'].fillna(0).astype(int)
+        recoveries[i] = date_df['Recovered'].fillna(0).astype(int)
+
+    S, I, R = np.zeros(infections.shape[0]), np.zeros(
+        infections.shape[0]), np.zeros(infections.shape[0])
+    S[0] = N - infections[0]
+    I[0] = infections[0]
+    R[0] = 0
+
+    for day in range(1, time_range):
+        S[day] = S[day - 1] - infections[day]
+        I[day] = I[day - 1] + infections[day] - deaths[day] - recoveries[day]
+        R[day] = R[day - 1] + deaths[day] + recoveries[day]
+        if I[day] < 0:
+            I[day] = 0
+
+    t = np.arange(0, time_range, 1)
+
+    plotter.plot(t, [I, R], ["I", "R"], "JH_data", "JH Data", (6, 6))
+
+    groups = [S, I, R]
+    COVID_Data = np.asarray([t[0::sample_rate]] +
+                            [group[0::sample_rate] for group in groups])
+
+    np.savetxt(
+        f"datasets/{model}_JH_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}.csv", COVID_Data, delimiter=",")
+
+
+def transform_data(plotter: Plotter, alpha=1 / 14, state_name='Germany', time_range=1200, sample_rate=1, model='SIR'):
     """Function to generate the SIR split from the data in the COVID-19-Todesfaelle_in_Deutschland dataset.
     """Function to generate the SIR split from the data in the COVID-19-Todesfaelle_in_Deutschland dataset.
 
 
     Args:
     Args:
         plotter (Plotter): Plotter object to plot dataset curves.
         plotter (Plotter): Plotter object to plot dataset curves.
         dataset_path (str, optional): Path to the dataset directory. Defaults to 'datasets/COVID-19-Todesfaelle_in_Deutschland/'.
         dataset_path (str, optional): Path to the dataset directory. Defaults to 'datasets/COVID-19-Todesfaelle_in_Deutschland/'.
-        plot_name (str, optional): Name of the plot file. Defaults to ''.
-        plot_title (str, optional): Title of the plot. Defaults to ''.
         sample_rate (int, optional): Sample rate used to sample the timepoints. Defaults to 1.
         sample_rate (int, optional): Sample rate used to sample the timepoints. Defaults to 1.
         exclude (list, optional): List of groups that are to excluded from the plot. Defaults to [].
         exclude (list, optional): List of groups that are to excluded from the plot. Defaults to [].
-        plot_size (tuple, optional): Size of the plot in (x, y) format. Defaults to (12,6).
-        yscale_log (bool, optional): Controls if the y axis of the plot will be scaled in log scale. Defaults to False.
-        plot_legend (bool, optional): Controls if the legend is to be plotted. Defaults to True.
     """
     """
     # read the data
     # read the data
 
 
-
     infections = np.zeros(time_range)
     infections = np.zeros(time_range)
     deaths = np.zeros(time_range)
     deaths = np.zeros(time_range)
     recoveries = np.zeros(time_range)
     recoveries = np.zeros(time_range)
     if state_name == 'Germany':
     if state_name == 'Germany':
-        df = pd.read_csv('datasets/COVID-19-Todesfaelle_in_Deutschland/COVID-19-Todesfaelle_Deutschland.csv')
+        df = pd.read_csv(
+            'datasets/COVID-19-Todesfaelle_in_Deutschland/COVID-19-Todesfaelle_Deutschland.csv')
         N = 83100000
         N = 83100000
         infections[0] = df['Faelle_gesamt'][0]
         infections[0] = df['Faelle_gesamt'][0]
         deaths[0] = df['Todesfaelle_neu'][0]
         deaths[0] = df['Todesfaelle_neu'][0]
 
 
         recovery_queue = np.zeros(14)
         recovery_queue = np.zeros(14)
         for i in range(1, time_range):
         for i in range(1, time_range):
-            infections[i] = df['Faelle_gesamt'][i] - df['Faelle_gesamt'][i-1]
+            infections[i] = df['Faelle_gesamt'][i] - df['Faelle_gesamt'][i - 1]
             deaths[i] = df['Todesfaelle_neu'][i]
             deaths[i] = df['Todesfaelle_neu'][i]
             recoveries[i] = recovery_queue[0]
             recoveries[i] = recovery_queue[0]
 
 
             recovery_queue[:-1] = recovery_queue[1:]
             recovery_queue[:-1] = recovery_queue[1:]
             recovery_queue[-1] = infections[i]
             recovery_queue[-1] = infections[i]
     else:
     else:
-        df = pd.read_csv('datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv')
+        df = pd.read_csv(
+            'datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv')
         state_ID, N = state_lookup[state_name]
         state_ID, N = state_lookup[state_name]
 
 
         # single out a state
         # single out a state
         state_IDs = df['IdLandkreis'] // 1000
         state_IDs = df['IdLandkreis'] // 1000
         df = df.loc[state_IDs == state_ID]
         df = df.loc[state_IDs == state_ID]
 
 
-        # sort entries by state
+        # sort entries by date
         df = df.sort_values('Refdatum')
         df = df.sort_values('Refdatum')
         df = df.reset_index(drop=True)
         df = df.reset_index(drop=True)
 
 
-        # collect cases    
+        # collect cases
         entry_idx = 0
         entry_idx = 0
         day = 0
         day = 0
         date = df['Refdatum'][entry_idx]
         date = df['Refdatum'][entry_idx]
@@ -78,7 +130,8 @@ def transform_data(plotter:Plotter, alpha=1/14, state_name='Germany', time_range
                 deaths[day] += df['AnzahlTodesfall'][entry_idx]
                 deaths[day] += df['AnzahlTodesfall'][entry_idx]
                 entry_idx += 1
                 entry_idx += 1
             # move day index by difference between the current and next date
             # move day index by difference between the current and next date
-            day += (pd.to_datetime(df['Refdatum'][entry_idx])-pd.to_datetime(date)).days
+            day += (pd.to_datetime(df['Refdatum']
+                    [entry_idx]) - pd.to_datetime(date)).days
             date = df['Refdatum'][entry_idx]
             date = df['Refdatum'][entry_idx]
 
 
         recovery_queue = np.zeros(14)
         recovery_queue = np.zeros(14)
@@ -89,48 +142,122 @@ def transform_data(plotter:Plotter, alpha=1/14, state_name='Germany', time_range
             recovery_queue[:-1] = recovery_queue[1:]
             recovery_queue[:-1] = recovery_queue[1:]
             recovery_queue[-1] = infections[i]
             recovery_queue[-1] = infections[i]
             week_counter -= 1
             week_counter -= 1
-        
+
     df = df.drop(df.index[time_range:])
     df = df.drop(df.index[time_range:])
-    S, I, R = np.zeros(df.shape[0]), np.zeros(df.shape[0]), np.zeros(df.shape[0])
+    S, I, R = np.zeros(df.shape[0]), np.zeros(
+        df.shape[0]), np.zeros(df.shape[0])
     # generate groups
     # generate groups
     S[0] = N - infections[0]
     S[0] = N - infections[0]
     I[0] = infections[0]
     I[0] = infections[0]
     R[0] = 0
     R[0] = 0
     if model == 'I':
     if model == 'I':
         for day in range(1, time_range):
         for day in range(1, time_range):
-            S[day] = S[day-1] - infections[day]
-            I[day] = I[day-1] + infections[day] - I[day-1] * alpha
-            R[day] = R[day-1] + I[day-1] * alpha
+            S[day] = S[day - 1] - infections[day]
+            I[day] = I[day - 1] + infections[day] - I[day - 1] * alpha
+            R[day] = R[day - 1] + I[day - 1] * alpha
     else:
     else:
         for day in range(1, time_range):
         for day in range(1, time_range):
-            S[day] = S[day-1] - infections[day]
-            I[day] = I[day-1] + infections[day] - deaths[day] - recoveries[day]
-            R[day] = R[day-1] + deaths[day] + recoveries[day]
+            S[day] = S[day - 1] - infections[day]
+            I[day] = I[day - 1] + infections[day] - \
+                deaths[day] - recoveries[day]
+            R[day] = R[day - 1] + deaths[day] + recoveries[day]
             if I[day] < 0:
             if I[day] < 0:
                 I[day] = 0
                 I[day] = 0
-    
+
     t = np.arange(0, time_range, 1)
     t = np.arange(0, time_range, 1)
 
 
     # select, which group is to be outputted
     # select, which group is to be outputted
     groups = []
     groups = []
     if 'S' in model:
     if 'S' in model:
         groups.append(S)
         groups.append(S)
-    
+
     if 'I' in model:
     if 'I' in model:
         groups.append(I)
         groups.append(I)
 
 
     if 'R' in model:
     if 'R' in model:
         groups.append(R)
         groups.append(R)
 
 
-    plotter.plot(t, 
-                 groups, 
-                 [*model], 
-                 state_name.replace(' ', '_').replace('-', '_').replace('ü','ue') + f"_{model}" + f"_{int(1/alpha)}", 
-                 state_name, 
-                 (6,6), 
-                 xlabel='time / days', 
+    plotter.plot(t,
+                 groups,
+                 [*model],
+                 state_name.replace(' ', '_').replace(
+                     '-', '_').replace('ü', 'ue') + f"_{model}" + f"_{int(1/alpha)}",
+                 state_name,
+                 (6, 6),
+                 xlabel='time / days',
                  ylabel='amount of people')
                  ylabel='amount of people')
 
 
-    COVID_Data = np.asarray([t[0::sample_rate]] + [group[0::sample_rate] for group in groups]) 
+    COVID_Data = np.asarray([t[0::sample_rate]] +
+                            [group[0::sample_rate] for group in groups])
+
+    np.savetxt(
+        f"datasets/{model}_RKI_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}_{int(1/alpha)}.csv", COVID_Data, delimiter=",")
+
+
+def transform_paper_data():
+    N = 70000000
+    time_range = 36
+    alpha = 0.07
+    state_name = 'Germany'
+
+    infections = np.zeros(time_range)
+    deaths = np.zeros(time_range)
+    recoveries = np.zeros(time_range)
+    # Data
+    data = [
+        [1.30000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.40000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.50000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.60000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.70000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.80000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [1.90000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.00000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.10000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.20000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.30000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.40000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.50000000e+01, 2.00000000e+00, 1.50000000e+01],
+        [2.60000000e+01, 2.00000000e+00, 1.70000000e+01],
+        [2.70000000e+01, 2.00000000e+00, 2.10000000e+01],
+        [2.80000000e+01, 2.00000000e+00, 4.70000000e+01],
+        [2.90000000e+01, 2.00000000e+00, 5.70000000e+01],
+        [1.00000000e+00, 3.00000000e+00, 1.11000000e+02],
+        [2.00000000e+00, 3.00000000e+00, 1.29000000e+02],
+        [3.00000000e+00, 3.00000000e+00, 1.57000000e+02],
+        [4.00000000e+00, 3.00000000e+00, 1.96000000e+02],
+        [5.00000000e+00, 3.00000000e+00, 2.62000000e+02],
+        [6.00000000e+00, 3.00000000e+00, 4.00000000e+02],
+        [7.00000000e+00, 3.00000000e+00, 6.84000000e+02],
+        [8.00000000e+00, 3.00000000e+00, 8.47000000e+02],
+        [9.00000000e+00, 3.00000000e+00, 9.02000000e+02],
+        [1.00000000e+01, 3.00000000e+00, 1.13900000e+03],
+        [1.10000000e+01, 3.00000000e+00, 1.29600000e+03],
+        [1.20000000e+01, 3.00000000e+00, 1.56700000e+03],
+        [1.30000000e+01, 3.00000000e+00, 2.36900000e+03],
+        [1.40000000e+01, 3.00000000e+00, 3.06200000e+03],
+        [1.50000000e+01, 3.00000000e+00, 3.79500000e+03],
+        [1.60000000e+01, 3.00000000e+00, 4.83800000e+03],
+        [1.70000000e+01, 3.00000000e+00, 6.01200000e+03],
+        [1.80000000e+01, 3.00000000e+00, 7.15600000e+03],
+        [1.90000000e+01, 3.00000000e+00, 8.19800000e+03],
+    ]
+
+    # Creating a Pandas DataFrame
+    df = pd.DataFrame(data, columns=["Day", "Month", "Infected people"])
+    S, I, R = np.zeros(df.shape[0]), np.zeros(
+        df.shape[0]), np.zeros(df.shape[0])
+    # generate groups
+    S[0] = N - infections[0]
+    I[0] = infections[0]
+    R[0] = 0
+    for day in range(1, time_range):
+        S[day] = S[day - 1] - df["Infected people"][day]
+        I[day] = I[day - 1] + df["Infected people"][day] - I[day - 1] * alpha
+        R[day] = R[day - 1] + I[day - 1] * alpha
+
+    COVID_Data = np.asarray([np.arange(0, time_range, 1)] +
+                            [S, I, R])
 
 
-    np.savetxt(f"datasets/{model}_RKI_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}_{int(1/alpha)}.csv", COVID_Data, delimiter=",")
+    np.savetxt(
+        f"datasets/SIR_Paper_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{int(1/alpha)}.csv", COVID_Data, delimiter=",")