|
@@ -1,34 +1,23 @@
|
|
|
|
|
|
import numpy as np
|
|
|
-from scipy.integrate import odeint
|
|
|
-import matplotlib.pyplot as plt
|
|
|
-from matplotlib import rcParams
|
|
|
-
|
|
|
-FONT_COLOR = '#595959'
|
|
|
-SUSCEPTIBLE = '#6399f7'
|
|
|
-INFECTIOUS = '#f56262'
|
|
|
-REMOVED = '#83eb5e'
|
|
|
|
|
|
-rcParams['font.family'] = 'Comfortaa'
|
|
|
-rcParams['font.size'] = 12
|
|
|
-
|
|
|
-rcParams['text.color'] = FONT_COLOR
|
|
|
-rcParams['axes.labelcolor'] = FONT_COLOR
|
|
|
-rcParams['xtick.color'] = FONT_COLOR
|
|
|
-rcParams['ytick.color'] = FONT_COLOR
|
|
|
+from scipy.integrate import odeint
|
|
|
|
|
|
+from src.plotter import Plotter
|
|
|
|
|
|
class SyntheticDeseaseData:
|
|
|
- def __init__(self, simulation_time, time_points):
|
|
|
+ def __init__(self, simulation_time:int, time_points:int, plotter:Plotter):
|
|
|
"""This class is the parent class for every class, that is supposed to generate synthetic data.
|
|
|
|
|
|
Args:
|
|
|
simulation_time (int): Real time for that the synthetic data is supposed to be generated in days.
|
|
|
time_points (int): Number of time sample points.
|
|
|
+ plotter (Plotter): Plotter object to plot dataset curves.
|
|
|
"""
|
|
|
self.t = np.linspace(0, simulation_time, time_points)
|
|
|
self.data = None
|
|
|
self.generated = False
|
|
|
+ self.plotter = plotter
|
|
|
|
|
|
def differential_eq(self):
|
|
|
"""In this function the differential equation of the model will be implemented.
|
|
@@ -40,45 +29,27 @@ class SyntheticDeseaseData:
|
|
|
"""
|
|
|
self.generated = True
|
|
|
|
|
|
- def plot(self, labels: tuple, title=''):
|
|
|
+ def plot(self, labels: tuple, title:str):
|
|
|
"""Plot the data which was generated.
|
|
|
|
|
|
Args:
|
|
|
labels (tuple): The names of each curve.
|
|
|
+ title (str): The name of the plot.
|
|
|
"""
|
|
|
+ assert len(labels) == len(self.data), 'The number labels needs to be the same as the number of plots.'
|
|
|
if self.generated:
|
|
|
- fig = plt.figure(figsize=(6,6))
|
|
|
- ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
|
|
|
- ax.set_facecolor('xkcd:white')
|
|
|
-
|
|
|
- color = (SUSCEPTIBLE, INFECTIOUS, REMOVED, 'red')
|
|
|
- for i in range(len(self.data)):
|
|
|
- # plot each group
|
|
|
- ax.plot(self.t, self.data[i], color[i], lw=3, label=labels[i])
|
|
|
-
|
|
|
- ax.set_xlabel('Time in days')
|
|
|
- ax.set_ylabel('Amount of people')
|
|
|
- ax.yaxis.set_tick_params(length=0)
|
|
|
- ax.xaxis.set_tick_params(length=0)
|
|
|
- ax.grid(which='major', c='black', lw=0.2, ls='-')
|
|
|
- legend = ax.legend()
|
|
|
- legend.get_frame().set_alpha(0.5)
|
|
|
- for spine in ('top', 'right', 'bottom', 'left'):
|
|
|
- ax.spines[spine].set_visible(False)
|
|
|
- if title == '':
|
|
|
- plt.savefig('visualizations/synthetic_dataset.png')
|
|
|
- else:
|
|
|
- plt.savefig('visualizations/' + title + '.png', transparent=True)
|
|
|
+ self.plotter.plot(self.t, self.data, labels, title, title, (6, 6), xlabel='time / days', ylabel='amount of people')
|
|
|
else:
|
|
|
print('Data has to be generated before plotting!') # Fabienne war hier
|
|
|
|
|
|
|
|
|
|
|
|
class SIR(SyntheticDeseaseData):
|
|
|
- def __init__(self, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
|
|
|
+ def __init__(self, plotter:Plotter, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
|
|
|
"""This class is able to generate synthetic data for the SIR model.
|
|
|
|
|
|
Args:
|
|
|
+ plotter (Plotter): Plotter object to plot dataset curves.
|
|
|
N (int, optional): Size of the population. Defaults to 59e6.
|
|
|
I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
|
|
|
R_0 (int, optional): Initial size of the removed group. Defaults to 0.
|
|
@@ -95,7 +66,7 @@ class SIR(SyntheticDeseaseData):
|
|
|
self.alpha = alpha
|
|
|
self.beta = beta
|
|
|
|
|
|
- super().__init__(simulation_time, time_points)
|
|
|
+ super().__init__(simulation_time, time_points, plotter)
|
|
|
|
|
|
def differential_eq(self, y, t, alpha, beta):
|
|
|
"""In this function implements the differential equation of the SIR model will be implemented.
|
|
@@ -137,10 +108,11 @@ class SIR(SyntheticDeseaseData):
|
|
|
|
|
|
|
|
|
class SIDR(SyntheticDeseaseData):
|
|
|
- def __init__(self, N=59e6, I_0=1, D_0=0, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05, gamma=0.0294) -> None:
|
|
|
+ def __init__(self, plotter:Plotter, N=59e6, I_0=1, D_0=0, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05, gamma=0.0294) -> None:
|
|
|
"""This class is able to generate synthetic data for the SIDR model.
|
|
|
|
|
|
Args:
|
|
|
+ plotter (Plotter): Plotter object to plot dataset curves.
|
|
|
N (int, optional): Size of the population. Defaults to 59e6.
|
|
|
I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
|
|
|
D_0 (int, optional): Initial size of the dead group. Defaults to 0.
|
|
@@ -161,7 +133,7 @@ class SIDR(SyntheticDeseaseData):
|
|
|
self.beta = beta
|
|
|
self.gamma = gamma
|
|
|
|
|
|
- super().__init__(simulation_time, time_points)
|
|
|
+ super().__init__(simulation_time, time_points, plotter)
|
|
|
|
|
|
def differential_eq(self, y, t, alpha, beta, gamma):
|
|
|
"""In this function implements the differential equation of the SIDR model will be implemented.
|
|
@@ -190,10 +162,10 @@ class SIDR(SyntheticDeseaseData):
|
|
|
self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta, self.gamma)).T
|
|
|
super().generate()
|
|
|
|
|
|
- def plot(self):
|
|
|
+ def plot(self, title):
|
|
|
"""Plot the data which was generated.
|
|
|
"""
|
|
|
- super().plot(('Susceptible', 'Infectious', 'Dead', 'Recovered'))
|
|
|
+ super().plot(('Susceptible', 'Infectious', 'Dead', 'Recovered'), title=title)
|
|
|
|
|
|
def save(self, name=''):
|
|
|
if self.generated:
|