synthetic_data.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rcParams
from scipy.integrate import odeint
  5. FONT_COLOR = '#595959'
  6. SUSCEPTIBLE = '#6399f7'
  7. INFECTIOUS = '#f56262'
  8. REMOVED = '#83eb5e'
  9. rcParams['font.family'] = 'Comfortaa'
  10. rcParams['font.size'] = 12
  11. rcParams['text.color'] = FONT_COLOR
  12. rcParams['axes.labelcolor'] = FONT_COLOR
  13. rcParams['xtick.color'] = FONT_COLOR
  14. rcParams['ytick.color'] = FONT_COLOR
  15. class SyntheticDeseaseData:
  16. def __init__(self, simulation_time, time_points):
  17. """This class is the parent class for every class, that is supposed to generate synthetic data.
  18. Args:
  19. simulation_time (int): Real time for that the synthetic data is supposed to be generated in days.
  20. time_points (int): Number of time sample points.
  21. """
  22. self.t = np.linspace(0, simulation_time, time_points)
  23. self.data = None
  24. self.generated = False
  25. def differential_eq(self):
  26. """In this function the differential equation of the model will be implemented.
  27. """
  28. pass
  29. def generate(self):
  30. """In this function the data generation will be implemented.
  31. """
  32. self.generated = True
  33. def plot(self, labels: tuple, title=''):
  34. """Plot the data which was generated.
  35. Args:
  36. labels (tuple): The names of each curve.
  37. """
  38. if self.generated:
  39. fig = plt.figure(figsize=(6,6))
  40. ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
  41. ax.set_facecolor('xkcd:white')
  42. color = (SUSCEPTIBLE, INFECTIOUS, REMOVED, 'red')
  43. for i in range(len(self.data)):
  44. # plot each group
  45. ax.plot(self.t, self.data[i], color[i], lw=3, label=labels[i])
  46. ax.set_xlabel('Time in days')
  47. ax.set_ylabel('Amount of people')
  48. ax.yaxis.set_tick_params(length=0)
  49. ax.xaxis.set_tick_params(length=0)
  50. ax.grid(which='major', c='black', lw=0.2, ls='-')
  51. legend = ax.legend()
  52. legend.get_frame().set_alpha(0.5)
  53. for spine in ('top', 'right', 'bottom', 'left'):
  54. ax.spines[spine].set_visible(False)
  55. if title == '':
  56. plt.savefig('visualizations/synthetic_dataset.png')
  57. else:
  58. plt.savefig('visualizations/' + title + '.png', transparent=True)
  59. else:
  60. print('Data has to be generated before plotting!') # Fabienne war hier
  61. class SIR(SyntheticDeseaseData):
  62. def __init__(self, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
  63. """This class is able to generate synthetic data for the SIR model.
  64. Args:
  65. N (int, optional): Size of the population. Defaults to 59e6.
  66. I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
  67. R_0 (int, optional): Initial size of the removed group. Defaults to 0.
  68. simulation_time (int, optional): Real time for that the synthetic data is supposed to be generated in days. Defaults to 500.
  69. time_points (int, optional): Number of time sample points. Defaults to 100.
  70. alpha (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Removed'. Defaults to 0.191.
  71. beta (float, optional): Factor dictating how many people per timestep go from 'Susceptible' to 'Infectious'. Defaults to 0.05.
  72. """
  73. self.N = N
  74. self.S_0 = N - I_0 - R_0
  75. self.I_0 = I_0
  76. self.R_0 = R_0
  77. self.alpha = alpha
  78. self.beta = beta
  79. super().__init__(simulation_time, time_points)
  80. def differential_eq(self, y, t, alpha, beta):
  81. """In this function implements the differential equation of the SIR model will be implemented.
  82. Args:
  83. y (tuple): Vector that holds the current state of the three groups.
  84. t (_): not used
  85. alpha (_): not used
  86. beta (_): not used
  87. Returns:
  88. tuple: Change amount for each group.
  89. """
  90. S, I, R = y
  91. dSdt = -self.beta * ((S * I) / self.N) # -self.beta * S * I
  92. dIdt = self.beta * ((S * I) / self.N) - self.alpha * I # self.beta * S * I - self.alpha * I
  93. dRdt = self.alpha * I
  94. return dSdt, dIdt, dRdt
  95. def generate(self):
  96. """This funtion generates the data for this configuration of the SIR model.
  97. """
  98. y_0 = self.S_0, self.I_0, self.R_0
  99. self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
  100. super().generate()
  101. def plot(self, title=''):
  102. """Plot the data which was generated.
  103. """
  104. super().plot(('Susceptible', 'Infectious', 'Removed'), title=title)
  105. def save(self, name=''):
  106. if self.generated:
  107. COVID_Data = np.asarray([self.t, *self.data])
  108. np.savetxt('datasets/SIR_data.csv', COVID_Data, delimiter=",")
  109. else:
  110. print('Data has to be generated before plotting!')
  111. class SIDR(SyntheticDeseaseData):
  112. def __init__(self, N=59e6, I_0=1, D_0=0, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05, gamma=0.0294) -> None:
  113. """This class is able to generate synthetic data for the SIDR model.
  114. Args:
  115. N (int, optional): Size of the population. Defaults to 59e6.
  116. I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
  117. D_0 (int, optional): Initial size of the dead group. Defaults to 0.
  118. R_0 (int, optional): Initial size of the recovered group. Defaults to 0.
  119. simulation_time (int, optional): Real time for that the synthetic data is supposed to be generated in days. Defaults to 500.
  120. time_points (int, optional): Number of time sample points. Defaults to 100.
  121. alpha (float, optional): Factor dictating how many people per timestep go from 'Susceptible' to 'Infectious'. Defaults to 0.191.
  122. beta (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Dead'. Defaults to 0.05.
  123. gamma (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Recovered'. Defaults to 0.0294.
  124. """
  125. self.N = N
  126. self.S_0 = N - I_0 - D_0 - R_0
  127. self.I_0 = I_0
  128. self.D_0 = D_0
  129. self.R_0 = R_0
  130. self.alpha = alpha
  131. self.beta = beta
  132. self.gamma = gamma
  133. super().__init__(simulation_time, time_points)
  134. def differential_eq(self, y, t, alpha, beta, gamma):
  135. """In this function implements the differential equation of the SIDR model will be implemented.
  136. Args:
  137. y (tuple): Vector that holds the current state of the three groups.
  138. t (_): not used
  139. alpha (_): not used
  140. beta (_): not used
  141. gamma (_): not used
  142. Returns:
  143. tuple: Change amount for each group.
  144. """
  145. S, I, D, R = y
  146. dSdt = - (self.alpha / self.N) * S * I
  147. dIdt = (self.alpha / self.N) * S * I - self.beta * I - self.gamma * I
  148. dDdt = self.gamma * I
  149. dRdt = self.beta * I
  150. return dSdt, dIdt, dDdt, dRdt
  151. def generate(self):
  152. """This funtion generates the data for this configuration of the SIR model.
  153. """
  154. y_0 = self.S_0, self.I_0, self.D_0, self.R_0
  155. self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta, self.gamma)).T
  156. super().generate()
  157. def plot(self):
  158. """Plot the data which was generated.
  159. """
  160. super().plot(('Susceptible', 'Infectious', 'Dead', 'Recovered'))
  161. def save(self, name=''):
  162. if self.generated:
  163. COVID_Data = np.asarray([self.t, *self.data])
  164. np.savetxt('datasets/SIDR_data.csv', COVID_Data, delimiter=",")
  165. else:
  166. print('Data has to be generated before plotting!')