Browse Source

modify plotting

phillip.rothenbeck 1 year ago
parent
commit
eb60653852
3 changed files with 121 additions and 12 deletions
  1. 27 9
      datasets/synthetic_data.py
  2. 54 0
      generate_presi_graphs.py
  3. 40 3
      src/dinn.py

+ 27 - 9
datasets/synthetic_data.py

@@ -2,6 +2,21 @@
 import numpy as np
 from scipy.integrate import odeint
 import matplotlib.pyplot as plt
+from matplotlib import rcParams
+
+FONT_COLOR = '#595959'
+SUSCEPTIBLE = '#6399f7'
+INFECTIOUS = '#f56262'
+REMOVED = '#83eb5e'
+
+rcParams['font.family'] = 'Comfortaa'
+rcParams['font.size'] = 12
+
+rcParams['text.color'] = FONT_COLOR
+rcParams['axes.labelcolor'] = FONT_COLOR
+rcParams['xtick.color'] = FONT_COLOR
+rcParams['ytick.color'] = FONT_COLOR
+
 
 class SyntheticDeseaseData:
     def __init__(self, simulation_time, time_points):
@@ -25,24 +40,24 @@ class SyntheticDeseaseData:
         """
         self.generated = True
 
-    def plot(self, labels: tuple):
+    def plot(self, labels: tuple, title=''):
         """Plot the data which was generated.
 
         Args:
             labels (tuple): The names of each curve.
         """
         if self.generated:
-            fig = plt.figure(figsize=(12,12))
+            fig = plt.figure(figsize=(6,6))
             ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
             ax.set_facecolor('xkcd:white')
 
-            color = ('violet', 'darkgreen', 'blue', 'red')
+            color = (SUSCEPTIBLE, INFECTIOUS, REMOVED, 'red')
             for i in range(len(self.data)):
                 # plot each group
-                ax.plot(self.t, self.data[i], color[i], alpha=0.5, lw=2, label=labels[i], linestyle='dashed')
+                ax.plot(self.t, self.data[i], color[i], lw=3, label=labels[i])
 
-            ax.set_xlabel('Time per days')
-            ax.set_ylabel('Number')
+            ax.set_xlabel('Time in days')
+            ax.set_ylabel('Amount of people')
             ax.yaxis.set_tick_params(length=0)
             ax.xaxis.set_tick_params(length=0)
             ax.grid(which='major', c='black', lw=0.2, ls='-')
@@ -50,7 +65,10 @@ class SyntheticDeseaseData:
             legend.get_frame().set_alpha(0.5)
             for spine in ('top', 'right', 'bottom', 'left'):
                 ax.spines[spine].set_visible(False)
-            plt.savefig('visualizations/synthetic_dataset.png')
+            if title == '':
+                plt.savefig('visualizations/synthetic_dataset.png')
+            else:
+                plt.savefig('visualizations/' + title + '.png', transparent=True)
         else: 
             print('Data has to be generated before plotting!') # Fabienne war hier
 
@@ -104,10 +122,10 @@ class SIR(SyntheticDeseaseData):
         self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
         super().generate()
 
-    def plot(self):
+    def plot(self, title=''):
         """Plot the data which was generated.
         """
-        super().plot(('Susceptible', 'Infectious', 'Removed'))
+        super().plot(('Susceptible', 'Infectious', 'Removed'), title=title)
 
     def save(self, name=''):
         if self.generated:

+ 54 - 0
generate_presi_graphs.py

@@ -0,0 +1,54 @@
+import pandas as pd
+import matplotlib.pyplot as plt
+import matplotlib.dates as mdates
+from matplotlib import rcParams
+
+FONT_COLOR = '#595959'
+SUSCEPTIBLE = '#6399f7'
+INFECTIOUS = '#f56262'
+REMOVED = '#83eb5e'
+
+# rki data
+
+rki_data_path = 'datasets/COVID-19-Todesfaelle_in_Deutschland/COVID-19-Todesfaelle_Deutschland.csv'
+
+rki_data = pd.read_csv(rki_data_path)
+rki_data['Berichtsdatum'] = pd.to_datetime(rki_data['Berichtsdatum'], errors='coerce')
+specific_dates = rki_data[rki_data['Berichtsdatum'].dt.is_quarter_start]['Berichtsdatum']
+
+rcParams['font.family'] = 'Comfortaa'
+rcParams['font.size'] = 12
+
+rcParams['text.color'] = FONT_COLOR
+rcParams['axes.labelcolor'] = FONT_COLOR
+rcParams['xtick.color'] = FONT_COLOR
+rcParams['ytick.color'] = FONT_COLOR
+
+slide3 = plt.figure(figsize=(12,6))
+ax = slide3.add_subplot(111, facecolor='#dddddd', axisbelow=True)
+ax.set_facecolor('xkcd:white')
+
+ax.plot(rki_data['Berichtsdatum'], rki_data['Faelle_gesamt'], label='infections', c=INFECTIOUS, lw=3)
+ax.plot(rki_data['Berichtsdatum'], rki_data['Todesfaelle_gesamt'], label='death cases', c=REMOVED, lw=3)
+
+plt.yscale('log')
+plt.ylabel('amount of poeple')
+plt.xlabel('time')
+plt.title('Accumulated cases (RKI Data)')
+ax.yaxis.set_tick_params(length=0)
+
+leg = plt.legend()
+
+ax.yaxis.set_tick_params(length=0, which='both')
+ax.xaxis.set_tick_params(length=0, which='both')
+ax.grid(which='major', c='black', lw=0.2, ls='-')
+
+for spine in ('top', 'right', 'bottom', 'left'):
+    ax.spines[spine].set_visible(False)
+
+plt.gca().set_xticks(specific_dates)
+plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
+
+plt.gcf().autofmt_xdate(rotation=45, ha='center')
+
+slide3.savefig('visualizations/slide3.png', transparent=True)

+ 40 - 3
src/dinn.py

@@ -1,10 +1,16 @@
 import torch
 import numpy as np
 import matplotlib.pyplot as plt
+from matplotlib import rcParams
 
 from .dataset import PandemicDataset
 from .problem import PandemicProblem
 
+FONT_COLOR = '#595959'
+SUSCEPTIBLE = '#6399f7'
+INFECTIOUS = '#f56262'
+REMOVED = '#83eb5e'
+
 class DINN:
     class NN(torch.nn.Module):
         def __init__(self, 
@@ -150,7 +156,7 @@ class DINN:
             # append values for plotting
             self.losses[epoch] = loss.item()
             for i, parameter in enumerate(self.parameters_tilda.items()):
-                self.parameters[i][epoch] = parameter[1].item()
+                self.parameters[i][epoch] = self.get_regulated_param(parameter[0]).item()
 
             # print training advancements
             if epoch % 1000 == 0:          
@@ -163,12 +169,23 @@ class DINN:
                     print(f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
                 print('#################################') 
 
-    def plot_training_graphs(self):
+    def plot_training_graphs(self, ground_truth=[]):
         """Plot the loss graph and the graphs of the advancements of the parameters.
+
+        Args:
+            ground_truth (list): List of the ground truth parameters
         """
         assert self.epochs != None
         epochs = np.arange(0, self.epochs, 1)
 
+        rcParams['font.family'] = 'Comfortaa'
+        rcParams['font.size'] = 12
+
+        rcParams['text.color'] = FONT_COLOR
+        rcParams['axes.labelcolor'] = FONT_COLOR
+        rcParams['xtick.color'] = FONT_COLOR
+        rcParams['ytick.color'] = FONT_COLOR
+
         # plot loss
         plt.plot(epochs, self.losses)
         plt.title('Loss')
@@ -177,8 +194,28 @@ class DINN:
 
         # plot parameters
         for i, parameter in enumerate(self.parameters):
-            plt.plot(epochs, parameter)
+            figure = plt.figure(figsize=(6,6))
+            ax = figure.add_subplot(111, facecolor='#dddddd', axisbelow=True)
+            ax.set_facecolor('xkcd:white')
+
+            ax.plot(epochs, parameter, c=FONT_COLOR, lw=3, label='prediction')
+            if len(ground_truth) > i:
+                ax.axhline(y=ground_truth[i], color=INFECTIOUS, linestyle='-', lw=3, label='ground truth')
+        
+            plt.xlabel('epochs')
             plt.title(list(self.parameters_tilda.items())[i][0])
+            ax.yaxis.set_tick_params(length=0)
+
+            ax.yaxis.set_tick_params(length=0, which='both')
+            ax.xaxis.set_tick_params(length=0, which='both')
+            ax.grid(which='major', c='black', lw=0.2, ls='-')
+
+            plt.legend()
+
+            for spine in ('top', 'right', 'bottom', 'left'):
+                ax.spines[spine].set_visible(False)
+
+            figure.savefig(f'visualizations/{list(self.parameters_tilda.items())[i][0]}.png', transparent=True)
             plt.show()
 
     def to_cuda(self):