Explorar el Código

add generalized data transformation algorithm

phillip.rothenbeck hace 4 meses
padre
commit
3b05d7d641
Se han modificado 1 ficheros con 88 adiciones y 138 borrados
  1. 88 138
      src/preprocessing/transform_data.py

+ 88 - 138
src/preprocessing/transform_data.py

@@ -3,7 +3,24 @@ import pandas as pd
 
 from src.plotter import Plotter
 
-def transform_general_to_SIR(plotter:Plotter, dataset_path='datasets/COVID-19-Todesfaelle_in_Deutschland/', plot_name='', plot_title='', sample_rate=1, exclude=[], plot_size=(12,6), yscale_log=False, plot_legend=True):
+state_lookup = {'Schleswig Holstein' : (1, 2897000),
+                'Hamburg' : (2, 1841000), 
+                'Niedersachsen' : (3, 7982000), 
+                'Bremen' : (4, 569352),
+                'Nordrhein-Westfalen' : (5, 17930000),
+                'Hessen' : (6, 6266000),
+                'Rheinland-Pfalz' : (7, 4085000),
+                'Baden-Württemberg' : (8, 11070000),
+                'Bayern' : (9, 13080000),
+                'Saarland' : (10, 990509),
+                'Berlin' : (11, 3645000),
+                'Brandenburg' : (12, 2641000),
+                'Mecklenburg-Vorpommern' : (13, 1610000),
+                'Sachsen' : (14, 4078000),
+                'Sachsen-Anhalt' : (15, 2208000),
+                'Thüringen' : (16, 2143000)}
+
+def transform_data(plotter:Plotter, alpha=1/14, state_name='Germany', time_range=1200, plot_name='', plot_title='', sample_rate=1, model='SIR', plot_size=(12,6), yscale_log=False, plot_legend=True):
     """Function to generate the SIR split from the data in the COVID-19-Todesfaelle_in_Deutschland dataset.
 
     Args:
@@ -18,147 +35,80 @@ def transform_general_to_SIR(plotter:Plotter, dataset_path='datasets/COVID-19-To
         plot_legend (bool, optional): Controls if the legend is to be plotted. Defaults to True.
     """
     # read the data
-    df = pd.read_csv(dataset_path + 'COVID-19-Todesfaelle_Deutschland.csv')
-
-    df = df.drop(df.index[1200:])
-    
-    # population of germany at the end of 2019
-    N = 83100000
-    S, I, R = np.zeros(df.shape[0]), np.zeros(df.shape[0]), np.zeros(df.shape[0])
-
-    # S_0 = N - I_0
-    S[0] = N - df['Faelle_gesamt'][0]
-    # I_0 = overall cases at the day - overall death cases at the day
-    I[0] = df['Faelle_gesamt'][0] - df['Todesfaelle_gesamt'][0]
-    # R_0 = overall death cases at the day
-    R[0] = df['Todesfaelle_gesamt'][0]
-
-    # the recovery time is 14 days
-    recovery_queue = np.zeros(14)
-    
-    for day in range(1, df.shape[0]):
-        infections = df['Faelle_gesamt'][day] - df['Faelle_gesamt'][day-1]
-        deaths = df['Todesfaelle_neu'][day]
-        recoveries = recovery_queue[0]
-
-        S[day] = S[day-1] - infections
-        I[day] = I[day-1] + infections - deaths - recoveries
-        R[day] = R[day-1] + deaths + recoveries
-
-        # update recovery queue
-        if I[day] < 0:
-            recovery_queue[-1] -= I[day] 
-            I[day] = 0
-
-        recovery_queue[:-1] = recovery_queue[1:]
-        recovery_queue[-1] = infections
-
-    t = np.arange(0, df.shape[0], 1)
-    if plotter != None:
-        # plot graphs
-        plots = []
-        labels = []
-
-        if 'S' not in exclude:
-            plots.append(S)
-            labels.append('S')
-        
-        if 'I' not in exclude:
-            plots.append(I)
-            labels.append('I')
-
-        if 'R' not in exclude:
-            plots.append(R)
-            labels.append('R')
-
-        plotter.plot(t, plots, labels, plot_name, plot_title, plot_size, y_log_scale=yscale_log, plot_legend=plot_legend, xlabel='time / days', ylabel='amount of poeple')
-
-    COVID_Data = np.asarray([t[0::sample_rate], 
-                             S[0::sample_rate], 
-                             I[0::sample_rate], 
-                             R[0::sample_rate]]) 
-
-    np.savetxt(f"datasets/SIR_RKI_{sample_rate}.csv", COVID_Data, delimiter=",")
-
 
 
-def get_state_cases(county_id, state_id):
-    id = county_id // 1000
-    return id == state_id
-
-def state_based_data(plotter:Plotter, state_name:str, model='SIR', alpha=1/14, time_range=1200, sample_rate=1, dataset_path='datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv'):
-    """Transforms the RKI infection cases dataset to a SIR dataset.
-
-    Args:
-        plotter (Plotter): Plotter object to plot dataset curves.
-        state_name (str): Name of the state that is to be singled out in the new dataset.
-        time_range (int, optional): Number of days that will be looked at in the new dataset. Defaults to 1200.
-        sample_rate (int, optional): Sample rate used to sample the timepoints. Defaults to 1.
-        dataset_path (str, optional): Path to the CSV file, where the data is stored. Defaults to 'datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv'.
-    """
-    df = pd.read_csv(dataset_path)
-
-    state_lookup = {'Schleswig Holstein' : (1, 2897000),
-                    'Hamburg' : (2, 1841000), 
-                    'Niedersachsen' : (3, 7982000), 
-                    'Bremen' : (4, 569352),
-                    'Nordrhein-Westfalen' : (5, 17930000),
-                    'Hessen' : (6, 6266000),
-                    'Rheinland-Pfalz' : (7, 4085000),
-                    'Baden-Württemberg' : (8, 11070000),
-                    'Bayern' : (9, 13080000),
-                    'Saarland' : (10, 990509),
-                    'Berlin' : (11, 3645000),
-                    'Brandenburg' : (12, 2641000),
-                    'Mecklenburg-Vorpommern' : (13, 1610000),
-                    'Sachsen' : (14, 4078000),
-                    'Sachsen-Anhalt' : (15, 2208000),
-                    'Thüringen' : (16, 2143000)}
-    state_ID, N = state_lookup[state_name]
-
-    # single out a state
-    state_IDs = df['IdLandkreis'] // 1000
-    state_df = df.loc[state_IDs == state_ID]
-
-    # sort entries by state
-    state_df = state_df.sort_values('Refdatum')
-    state_df = state_df.reset_index(drop=True)
-
-
-    # collect cases    
     infections = np.zeros(time_range)
-    dead = np.zeros(time_range)
-    recovered = np.zeros(time_range)
-    entry_idx = 0
-    day = 0
-    date = state_df['Refdatum'][entry_idx]
-    # check for each date all entries
-    while day < time_range:
-        # use the date sorted characteristic and take all entries with current date
-        while state_df['Refdatum'][entry_idx] == date:
-            # TODO use further parameters
-            infections[day] += state_df['AnzahlFall'][entry_idx]
-            dead[day] += state_df['AnzahlTodesfall'][entry_idx]
-            recovered[day] += state_df['AnzahlGenesen'][entry_idx]
-            entry_idx += 1
-        # move day index by difference between the current and next date
-        day += (pd.to_datetime(state_df['Refdatum'][entry_idx])-pd.to_datetime(date)).days
-        date = state_df['Refdatum'][entry_idx]
-
-    S = np.zeros(time_range)
-    I = np.zeros(time_range)
-    R = np.zeros(time_range)
-
+    deaths = np.zeros(time_range)
+    recoveries = np.zeros(time_range)
+    if state_name == 'Germany':
+        df = pd.read_csv('datasets/COVID-19-Todesfaelle_in_Deutschland/COVID-19-Todesfaelle_Deutschland.csv')
+        N = 83100000
+        infections[0] = df['Faelle_gesamt'][0]
+        deaths[0] = df['Todesfaelle_neu'][0]
+
+        recovery_queue = np.zeros(14)
+        for i in range(1, time_range):
+            infections[i] = df['Faelle_gesamt'][i] - df['Faelle_gesamt'][i-1]
+            deaths[i] = df['Todesfaelle_neu'][i]
+            recoveries[i] = recovery_queue[0]
+
+            recovery_queue[:-1] = recovery_queue[1:]
+            recovery_queue[-1] = infections[i]
+    else:
+        df = pd.read_csv('datasets/state_data/Aktuell_Deutschland_SarsCov2_Infektionen.csv')
+        state_ID, N = state_lookup[state_name]
+
+        # single out a state
+        state_IDs = df['IdLandkreis'] // 1000
+        df = df.loc[state_IDs == state_ID]
+
+        # sort entries by state
+        df = df.sort_values('Refdatum')
+        df = df.reset_index(drop=True)
+
+        # collect cases    
+        entry_idx = 0
+        day = 0
+        date = df['Refdatum'][entry_idx]
+        # check for each date all entries
+        while day < time_range:
+            # use the date sorted characteristic and take all entries with current date
+            while df['Refdatum'][entry_idx] == date:
+                infections[day] += df['AnzahlFall'][entry_idx]
+                deaths[day] += df['AnzahlTodesfall'][entry_idx]
+                entry_idx += 1
+            # move day index by difference between the current and next date
+            day += (pd.to_datetime(df['Refdatum'][entry_idx])-pd.to_datetime(date)).days
+            date = df['Refdatum'][entry_idx]
+
+        recovery_queue = np.zeros(14)
+        week_counter = 2
+        for i in range(1, time_range):
+            recoveries[i] = recovery_queue[0]
+
+            recovery_queue[:-1] = recovery_queue[1:]
+            recovery_queue[-1] = infections[i]
+            week_counter -= 1
+        
+    df = df.drop(df.index[time_range:])
+    S, I, R = np.zeros(df.shape[0]), np.zeros(df.shape[0]), np.zeros(df.shape[0])
     # generate groups
     S[0] = N - infections[0]
     I[0] = infections[0]
     R[0] = 0
-
-    for day in range(1, time_range):
-        S[day] = S[day-1] - infections[day]
-        I[day] = I[day-1] + infections[day] - I[day-1] * alpha
-        R[day] = R[day-1] + I[day-1] * alpha
-
+    if model == 'I':
+        for day in range(1, time_range):
+            S[day] = S[day-1] - infections[day]
+            I[day] = I[day-1] + infections[day] - I[day-1] * alpha
+            R[day] = R[day-1] + I[day-1] * alpha
+    else:
+        for day in range(1, time_range):
+            S[day] = S[day-1] - infections[day]
+            I[day] = I[day-1] + infections[day] - deaths[day] - recoveries[day]
+            R[day] = R[day-1] + deaths[day] + recoveries[day]
+            if I[day] < 0:
+                I[day] = 0
+    
     t = np.arange(0, time_range, 1)
 
     # select, which group is to be outputted
@@ -175,12 +125,12 @@ def state_based_data(plotter:Plotter, state_name:str, model='SIR', alpha=1/14, t
     plotter.plot(t, 
                  groups, 
                  [*model], 
-                 state_name.replace(' ', '_').replace('-', '_').replace('ü','ue'), 
-                 state_name +' SI', 
+                 state_name.replace(' ', '_').replace('-', '_').replace('ü','ue') + f"_{model}" + f"_{int(1/alpha)}", 
+                 state_name, 
                  (6,6), 
                  xlabel='time / days', 
                  ylabel='amount of people')
 
     COVID_Data = np.asarray([t[0::sample_rate]] + [group[0::sample_rate] for group in groups]) 
 
-    np.savetxt(f"datasets/{model}_RKI_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}.csv", COVID_Data, delimiter=",")
+    np.savetxt(f"datasets/{model}_RKI_{state_name.replace(' ', '_').replace('-', '_').replace('ü','ue')}_{sample_rate}_{int(1/alpha)}.csv", COVID_Data, delimiter=",")