synthetic_data.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
import os

import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import odeint
  4. class SyntheticDeseaseData:
  5. def __init__(self, simulation_time, time_points):
  6. """This class is the parent class for every class, that is supposed to generate synthetic data.
  7. Args:
  8. simulation_time (int): Real time for that the synthetic data is supposed to be generated in days.
  9. time_points (int): Number of time sample points.
  10. """
  11. self.t = np.linspace(0, simulation_time, time_points)
  12. self.data = None
  13. self.generated = False
  14. def differential_eq(self):
  15. """In this function the differential equation of the model will be implemented.
  16. """
  17. pass
  18. def generate(self):
  19. """In this function the data generation will be implemented.
  20. """
  21. self.generated = True
  22. def plot(self, labels: tuple):
  23. """Plot the data which was generated.
  24. Args:
  25. labels (tuple): The names of each curve.
  26. """
  27. if self.generated:
  28. fig = plt.figure(figsize=(12,12))
  29. ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
  30. ax.set_facecolor('xkcd:white')
  31. color = ('violet', 'darkgreen', 'blue', 'red')
  32. for i in range(len(self.data)):
  33. # plot each group
  34. ax.plot(self.t, self.data[i], color[i], alpha=0.5, lw=2, label=labels[i], linestyle='dashed')
  35. ax.set_xlabel('Time per days')
  36. ax.set_ylabel('Number')
  37. ax.yaxis.set_tick_params(length=0)
  38. ax.xaxis.set_tick_params(length=0)
  39. ax.grid(which='major', c='black', lw=0.2, ls='-')
  40. legend = ax.legend()
  41. legend.get_frame().set_alpha(0.5)
  42. for spine in ('top', 'right', 'bottom', 'left'):
  43. ax.spines[spine].set_visible(False)
  44. plt.savefig('visualizations/synthetic_dataset.png')
  45. else:
  46. print('Data has to be generated before plotting!') # Fabienne war hier
  47. class SIR(SyntheticDeseaseData):
  48. def __init__(self, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
  49. """This class is able to generate synthetic data for the SIR model.
  50. Args:
  51. N (int, optional): Size of the population. Defaults to 59e6.
  52. I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
  53. R_0 (int, optional): Initial size of the removed group. Defaults to 0.
  54. simulation_time (int, optional): Real time for that the synthetic data is supposed to be generated in days. Defaults to 500.
  55. time_points (int, optional): Number of time sample points. Defaults to 100.
  56. alpha (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Removed'. Defaults to 0.191.
  57. beta (float, optional): Factor dictating how many people per timestep go from 'Susceptible' to 'Infectious'. Defaults to 0.05.
  58. """
  59. self.N = N
  60. self.S_0 = N - I_0 - R_0
  61. self.I_0 = I_0
  62. self.R_0 = R_0
  63. self.alpha = alpha
  64. self.beta = beta
  65. super().__init__(simulation_time, time_points)
  66. def differential_eq(self, y, t, alpha, beta):
  67. """In this function implements the differential equation of the SIR model will be implemented.
  68. Args:
  69. y (tuple): Vector that holds the current state of the three groups.
  70. t (_): not used
  71. alpha (_): not used
  72. beta (_): not used
  73. Returns:
  74. tuple: Change amount for each group.
  75. """
  76. S, I, R = y
  77. dSdt = -self.beta * ((S * I) / self.N) # -self.beta * S * I
  78. dIdt = self.beta * ((S * I) / self.N) - self.alpha * I # self.beta * S * I - self.alpha * I
  79. dRdt = self.alpha * I
  80. return dSdt, dIdt, dRdt
  81. def generate(self):
  82. """This funtion generates the data for this configuration of the SIR model.
  83. """
  84. y_0 = self.S_0, self.I_0, self.R_0
  85. self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
  86. super().generate()
  87. def plot(self):
  88. """Plot the data which was generated.
  89. """
  90. super().plot(('Susceptible', 'Infectious', 'Removed'))
  91. def save(self, name=''):
  92. if self.generated:
  93. COVID_Data = np.asarray([self.t, *self.data])
  94. np.savetxt('datasets/SIR_data.csv', COVID_Data, delimiter=",")
  95. else:
  96. print('Data has to be generated before plotting!')
  97. class SIDR(SyntheticDeseaseData):
  98. def __init__(self, N=59e6, I_0=1, D_0=0, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05, gamma=0.0294) -> None:
  99. """This class is able to generate synthetic data for the SIDR model.
  100. Args:
  101. N (int, optional): Size of the population. Defaults to 59e6.
  102. I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
  103. D_0 (int, optional): Initial size of the dead group. Defaults to 0.
  104. R_0 (int, optional): Initial size of the recovered group. Defaults to 0.
  105. simulation_time (int, optional): Real time for that the synthetic data is supposed to be generated in days. Defaults to 500.
  106. time_points (int, optional): Number of time sample points. Defaults to 100.
  107. alpha (float, optional): Factor dictating how many people per timestep go from 'Susceptible' to 'Infectious'. Defaults to 0.191.
  108. beta (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Dead'. Defaults to 0.05.
  109. gamma (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Recovered'. Defaults to 0.0294.
  110. """
  111. self.N = N
  112. self.S_0 = N - I_0 - D_0 - R_0
  113. self.I_0 = I_0
  114. self.D_0 = D_0
  115. self.R_0 = R_0
  116. self.alpha = alpha
  117. self.beta = beta
  118. self.gamma = gamma
  119. super().__init__(simulation_time, time_points)
  120. def differential_eq(self, y, t, alpha, beta, gamma):
  121. """In this function implements the differential equation of the SIDR model will be implemented.
  122. Args:
  123. y (tuple): Vector that holds the current state of the three groups.
  124. t (_): not used
  125. alpha (_): not used
  126. beta (_): not used
  127. gamma (_): not used
  128. Returns:
  129. tuple: Change amount for each group.
  130. """
  131. S, I, D, R = y
  132. dSdt = - (self.alpha / self.N) * S * I
  133. dIdt = (self.alpha / self.N) * S * I - self.beta * I - self.gamma * I
  134. dDdt = self.gamma * I
  135. dRdt = self.beta * I
  136. return dSdt, dIdt, dDdt, dRdt
  137. def generate(self):
  138. """This funtion generates the data for this configuration of the SIR model.
  139. """
  140. y_0 = self.S_0, self.I_0, self.D_0, self.R_0
  141. self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta, self.gamma)).T
  142. super().generate()
  143. def plot(self):
  144. """Plot the data which was generated.
  145. """
  146. super().plot(('Susceptible', 'Infectious', 'Dead', 'Recovered'))
  147. def save(self, name=''):
  148. if self.generated:
  149. COVID_Data = np.asarray([self.t, *self.data])
  150. np.savetxt('datasets/SIDR_data.csv', COVID_Data, delimiter=",")
  151. else:
  152. print('Data has to be generated before plotting!')