Browse Source

add plotter solution

phillip.rothenbeck 1 year ago
parent
commit
293db24a14
3 changed files with 36 additions and 100 deletions
  1. 2 6
      dataset_vis.ipynb
  2. 30 45
      datasets/transform_SIR.py
  3. 4 49
      synth_dinn_sir.ipynb

File diff suppressed because it is too large
+ 2 - 6
dataset_vis.ipynb


+ 30 - 45
datasets/transform_SIR.py

@@ -1,7 +1,7 @@
 import numpy as np
 import pandas as pd
-import matplotlib.pylab as plt
-from matplotlib import rcParams
+
+from src.plotter import Plotter
 
 dataset_path = 'datasets/COVID-19-Todesfaelle_in_Deutschland/'
 
@@ -10,14 +10,14 @@ SUSCEPTIBLE = '#6399f7'
 INFECTIOUS = '#f56262'
 REMOVED = '#83eb5e'
 
-def transform_general_data(plot_name: str, plot_title: str, exclude=[], plot_size=(12,6), yscale_log=False, plot_legend=True):
+def transform_general_data(plotter:Plotter, plot_name='', plot_title='', sample_rate=1, exclude=[], plot_size=(12,6), yscale_log=False, plot_legend=True):
     """Function to generate the SIR split from the data in the COVID-19-Todesfaelle_in_Deutschland dataset.
     """
     # read the data
     df = pd.read_csv(dataset_path + 'COVID-19-Todesfaelle_Deutschland.csv')
 
     df = df.drop(df.index[1200:])
-
+    
     # population of germany at the end of 2019
     N = 83100000
     S, I, R = np.zeros(df.shape[0]), np.zeros(df.shape[0]), np.zeros(df.shape[0])
@@ -49,45 +49,30 @@ def transform_general_data(plot_name: str, plot_title: str, exclude=[], plot_siz
         recovery_queue[:-1] = recovery_queue[1:]
         recovery_queue[-1] = infections
 
-    # plot graphs
     t = np.arange(0, df.shape[0], 1)
-    rcParams['font.family'] = 'Comfortaa'
-    rcParams['font.size'] = 12
-
-    rcParams['text.color'] = FONT_COLOR
-    rcParams['axes.labelcolor'] = FONT_COLOR
-    rcParams['xtick.color'] = FONT_COLOR
-    rcParams['ytick.color'] = FONT_COLOR
-
-    slide3 = plt.figure(figsize=plot_size)
-    ax = slide3.add_subplot(111, facecolor='#dddddd', axisbelow=True)
-    ax.set_facecolor('xkcd:white')
-
-    if 'S' not in exclude:
-        ax.plot(t, S, label='Susceptible', c=SUSCEPTIBLE, lw=3)
-    
-    if 'I' not in exclude:
-        ax.plot(t, I, label='Infectious', c=INFECTIOUS, lw=3)
-
-    if 'R' not in exclude:
-        ax.plot(t, R, label='Removed', c=REMOVED, lw=3)
-
-    if yscale_log:
-        plt.yscale('log')
-
-    plt.ylabel('amount of poeple')
-    plt.xlabel('time / days')
-    plt.title(plot_title)
-    ax.yaxis.set_tick_params(length=0)
-
-    if plot_legend:
-        plt.legend()
-
-    ax.yaxis.set_tick_params(length=0, which='both')
-    ax.xaxis.set_tick_params(length=0, which='both')
-    ax.grid(which='major', c='black', lw=0.2, ls='-')
-
-    for spine in ('top', 'right', 'bottom', 'left'):
-        ax.spines[spine].set_visible(False)
-
-    slide3.savefig(f'visualizations/{plot_name}.png', transparent=True)
+    if plotter != None:
+        # plot graphs
+        plots = []
+        labels = []
+
+        if 'S' not in exclude:
+            plots.append(S)
+            labels.append('S')
+        
+        if 'I' not in exclude:
+            plots.append(I)
+            labels.append('I')
+
+        if 'R' not in exclude:
+            plots.append(R)
+            labels.append('R')
+
+        plotter.plot(t, plots, labels, plot_name, plot_title, plot_size, y_log_scale=yscale_log, plot_legend=plot_legend, xlabel='time / days', ylabel='amount of poeple')
+
+    COVID_Data = np.asarray([t[0::sample_rate], 
+                             S[0::sample_rate], 
+                             I[0::sample_rate], 
+                             R[0::sample_rate]]) 
+
+    np.savetxt(f"datasets/SIR_RKI_{sample_rate}.csv", COVID_Data, delimiter=",")
+     

File diff suppressed because it is too large
+ 4 - 49
synth_dinn_sir.ipynb


Some files were not shown because too many files changed in this diff