Răsfoiți Sursa

add plotter solution

phillip.rothenbeck 1 an în urmă
părinte
comite
10b80b615a
3 a modificat fișierele cu 21 adăugiri și 51 ștergeri
  1. 1 2
      dataset_vis.ipynb
  2. 17 45
      datasets/synthetic_data.py
  3. 3 4
      synth_data_vis.ipynb

Fișier diff suprimat deoarece este prea mare
+ 1 - 2
dataset_vis.ipynb


+ 17 - 45
datasets/synthetic_data.py

@@ -1,34 +1,23 @@
 
 
 import numpy as np
 import numpy as np
-from scipy.integrate import odeint
-import matplotlib.pyplot as plt
-from matplotlib import rcParams
-
-FONT_COLOR = '#595959'
-SUSCEPTIBLE = '#6399f7'
-INFECTIOUS = '#f56262'
-REMOVED = '#83eb5e'
 
 
-rcParams['font.family'] = 'Comfortaa'
-rcParams['font.size'] = 12
-
-rcParams['text.color'] = FONT_COLOR
-rcParams['axes.labelcolor'] = FONT_COLOR
-rcParams['xtick.color'] = FONT_COLOR
-rcParams['ytick.color'] = FONT_COLOR
+from scipy.integrate import odeint
 
 
+from src.plotter import Plotter
 
 
 class SyntheticDeseaseData:
 class SyntheticDeseaseData:
-    def __init__(self, simulation_time, time_points):
+    def __init__(self, simulation_time:int, time_points:int, plotter:Plotter):
         """This class is the parent class for every class, that is supposed to generate synthetic data.
         """This class is the parent class for every class, that is supposed to generate synthetic data.
 
 
         Args:
         Args:
             simulation_time (int): Real time for that the synthetic data is supposed to be generated in days.
             simulation_time (int): Real time for that the synthetic data is supposed to be generated in days.
             time_points (int): Number of time sample points.
             time_points (int): Number of time sample points.
+            plotter (Plotter): Plotter object to plot dataset curves.
         """
         """
         self.t = np.linspace(0, simulation_time, time_points)
         self.t = np.linspace(0, simulation_time, time_points)
         self.data = None
         self.data = None
         self.generated = False
         self.generated = False
+        self.plotter = plotter
 
 
     def differential_eq(self):
     def differential_eq(self):
         """In this function the differential equation of the model will be implemented.
         """In this function the differential equation of the model will be implemented.
@@ -40,45 +29,27 @@ class SyntheticDeseaseData:
         """
         """
         self.generated = True
         self.generated = True
 
 
-    def plot(self, labels: tuple, title=''):
+    def plot(self, labels: tuple, title:str):
         """Plot the data which was generated.
         """Plot the data which was generated.
 
 
         Args:
         Args:
             labels (tuple): The names of each curve.
             labels (tuple): The names of each curve.
+            title (str): The name of the plot.
         """
         """
+        assert len(labels) == len(self.data), 'The number labels needs to be the same as the number of plots.'
         if self.generated:
         if self.generated:
-            fig = plt.figure(figsize=(6,6))
-            ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
-            ax.set_facecolor('xkcd:white')
-
-            color = (SUSCEPTIBLE, INFECTIOUS, REMOVED, 'red')
-            for i in range(len(self.data)):
-                # plot each group
-                ax.plot(self.t, self.data[i], color[i], lw=3, label=labels[i])
-
-            ax.set_xlabel('Time in days')
-            ax.set_ylabel('Amount of people')
-            ax.yaxis.set_tick_params(length=0)
-            ax.xaxis.set_tick_params(length=0)
-            ax.grid(which='major', c='black', lw=0.2, ls='-')
-            legend = ax.legend()
-            legend.get_frame().set_alpha(0.5)
-            for spine in ('top', 'right', 'bottom', 'left'):
-                ax.spines[spine].set_visible(False)
-            if title == '':
-                plt.savefig('visualizations/synthetic_dataset.png')
-            else:
-                plt.savefig('visualizations/' + title + '.png', transparent=True)
+            self.plotter.plot(self.t, self.data, labels, title, title, (6, 6), xlabel='time / days', ylabel='amount of people')
         else: 
         else: 
             print('Data has to be generated before plotting!') # Fabienne war hier
             print('Data has to be generated before plotting!') # Fabienne war hier
 
 
         
         
 
 
 class SIR(SyntheticDeseaseData):
 class SIR(SyntheticDeseaseData):
-    def __init__(self, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
+    def __init__(self, plotter:Plotter, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
         """This class is able to generate synthetic data for the SIR model.
         """This class is able to generate synthetic data for the SIR model.
 
 
         Args:
         Args:
+            plotter (Plotter): Plotter object to plot dataset curves.
             N (int, optional): Size of the population. Defaults to 59e6.
             N (int, optional): Size of the population. Defaults to 59e6.
             I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
             I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
             R_0 (int, optional): Initial size of the removed group. Defaults to 0.
             R_0 (int, optional): Initial size of the removed group. Defaults to 0.
@@ -95,7 +66,7 @@ class SIR(SyntheticDeseaseData):
         self.alpha = alpha
         self.alpha = alpha
         self.beta = beta
         self.beta = beta
 
 
-        super().__init__(simulation_time, time_points)
+        super().__init__(simulation_time, time_points, plotter)
 
 
     def differential_eq(self, y, t, alpha, beta):
     def differential_eq(self, y, t, alpha, beta):
         """In this function implements the differential equation of the SIR model will be implemented.
         """In this function implements the differential equation of the SIR model will be implemented.
@@ -137,10 +108,11 @@ class SIR(SyntheticDeseaseData):
         
         
 
 
 class SIDR(SyntheticDeseaseData):
 class SIDR(SyntheticDeseaseData):
-    def __init__(self, N=59e6, I_0=1, D_0=0, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05, gamma=0.0294) -> None:
+    def __init__(self, plotter:Plotter, N=59e6, I_0=1, D_0=0, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05, gamma=0.0294) -> None:
         """This class is able to generate synthetic data for the SIDR model.
         """This class is able to generate synthetic data for the SIDR model.
 
 
         Args:
         Args:
+            plotter (Plotter): Plotter object to plot dataset curves.
             N (int, optional): Size of the population. Defaults to 59e6.
             N (int, optional): Size of the population. Defaults to 59e6.
             I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
             I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
             D_0 (int, optional): Initial size of the dead group. Defaults to 0.
             D_0 (int, optional): Initial size of the dead group. Defaults to 0.
@@ -161,7 +133,7 @@ class SIDR(SyntheticDeseaseData):
         self.beta = beta
         self.beta = beta
         self.gamma = gamma
         self.gamma = gamma
 
 
-        super().__init__(simulation_time, time_points)
+        super().__init__(simulation_time, time_points, plotter)
     
     
     def differential_eq(self, y, t, alpha, beta, gamma):
     def differential_eq(self, y, t, alpha, beta, gamma):
         """In this function implements the differential equation of the SIDR model will be implemented.
         """In this function implements the differential equation of the SIDR model will be implemented.
@@ -190,10 +162,10 @@ class SIDR(SyntheticDeseaseData):
         self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta, self.gamma)).T
         self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta, self.gamma)).T
         super().generate()
         super().generate()
 
 
-    def plot(self):
+    def plot(self, title):
         """Plot the data which was generated.
         """Plot the data which was generated.
         """
         """
-        super().plot(('Susceptible', 'Infectious', 'Dead', 'Recovered'))
+        super().plot(('Susceptible', 'Infectious', 'Dead', 'Recovered'), title=title)
 
 
     def save(self, name=''):
     def save(self, name=''):
         if self.generated:
         if self.generated:

Fișier diff suprimat deoarece este prea mare
+ 3 - 4
synth_data_vis.ipynb


Unele fișiere nu au fost afișate deoarece prea multe fișiere au fost modificate în acest diff