|
|
@@ -2,6 +2,21 @@
|
|
|
import numpy as np
|
|
|
from scipy.integrate import odeint
|
|
|
import matplotlib.pyplot as plt
|
|
|
+from matplotlib import rcParams
|
|
|
+
|
|
|
+FONT_COLOR = '#595959'
|
|
|
+SUSCEPTIBLE = '#6399f7'
|
|
|
+INFECTIOUS = '#f56262'
|
|
|
+REMOVED = '#83eb5e'
|
|
|
+
|
|
|
+rcParams['font.family'] = 'Comfortaa'
|
|
|
+rcParams['font.size'] = 12
|
|
|
+
|
|
|
+rcParams['text.color'] = FONT_COLOR
|
|
|
+rcParams['axes.labelcolor'] = FONT_COLOR
|
|
|
+rcParams['xtick.color'] = FONT_COLOR
|
|
|
+rcParams['ytick.color'] = FONT_COLOR
|
|
|
+
|
|
|
|
|
|
class SyntheticDeseaseData:
|
|
|
def __init__(self, simulation_time, time_points):
|
|
|
@@ -25,24 +40,24 @@ class SyntheticDeseaseData:
|
|
|
"""
|
|
|
self.generated = True
|
|
|
|
|
|
- def plot(self, labels: tuple):
|
|
|
+ def plot(self, labels: tuple, title=''):
|
|
|
"""Plot the data which was generated.
|
|
|
|
|
|
Args:
|
|
|
labels (tuple): The names of each curve.
|
|
|
"""
|
|
|
if self.generated:
|
|
|
- fig = plt.figure(figsize=(12,12))
|
|
|
+ fig = plt.figure(figsize=(6,6))
|
|
|
ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
|
|
|
ax.set_facecolor('xkcd:white')
|
|
|
|
|
|
- color = ('violet', 'darkgreen', 'blue', 'red')
|
|
|
+ color = (SUSCEPTIBLE, INFECTIOUS, REMOVED, 'red')
|
|
|
for i in range(len(self.data)):
|
|
|
# plot each group
|
|
|
- ax.plot(self.t, self.data[i], color[i], alpha=0.5, lw=2, label=labels[i], linestyle='dashed')
|
|
|
+ ax.plot(self.t, self.data[i], color[i], lw=3, label=labels[i])
|
|
|
|
|
|
- ax.set_xlabel('Time per days')
|
|
|
- ax.set_ylabel('Number')
|
|
|
+ ax.set_xlabel('Time in days')
|
|
|
+ ax.set_ylabel('Amount of people')
|
|
|
ax.yaxis.set_tick_params(length=0)
|
|
|
ax.xaxis.set_tick_params(length=0)
|
|
|
ax.grid(which='major', c='black', lw=0.2, ls='-')
|
|
|
@@ -50,7 +65,10 @@ class SyntheticDeseaseData:
|
|
|
legend.get_frame().set_alpha(0.5)
|
|
|
for spine in ('top', 'right', 'bottom', 'left'):
|
|
|
ax.spines[spine].set_visible(False)
|
|
|
- plt.savefig('visualizations/synthetic_dataset.png')
|
|
|
+ if title == '':
|
|
|
+ plt.savefig('visualizations/synthetic_dataset.png')
|
|
|
+ else:
|
|
|
+ plt.savefig('visualizations/' + title + '.png', transparent=True)
|
|
|
else:
|
|
|
print('Data has to be generated before plotting!') # Fabienne war hier
|
|
|
|
|
|
@@ -104,10 +122,10 @@ class SIR(SyntheticDeseaseData):
|
|
|
self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
|
|
|
super().generate()
|
|
|
|
|
|
- def plot(self):
|
|
|
+ def plot(self, title=''):
|
|
|
"""Plot the data which was generated.
|
|
|
"""
|
|
|
- super().plot(('Susceptible', 'Infectious', 'Removed'))
|
|
|
+ super().plot(('Susceptible', 'Infectious', 'Removed'), title=title)
|
|
|
|
|
|
def save(self, name=''):
|
|
|
if self.generated:
|