synthetic_data.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import numpy as np
  2. from scipy.integrate import odeint
  3. from src.plotter import Plotter
  class SyntheticDeseaseData:
      """Base class for every class that generates synthetic disease data.

      Subclasses are expected to override ``differential_eq`` and ``generate``.
      An overriding ``generate`` should fill ``self.data`` and then call
      ``super().generate()`` so that the ``generated`` flag is set.
      """

      def __init__(self, simulation_time: int, time_points: int, plotter: "Plotter"):
          """Set up the time grid and the state shared by all models.

          Args:
              simulation_time (int): Real time for which the synthetic data is
                  supposed to be generated, in days.
              time_points (int): Number of time sample points.
              plotter (Plotter): Plotter object used to plot the dataset curves.
          """
          # Evenly spaced sample times from 0 to simulation_time (both included).
          self.t = np.linspace(0, simulation_time, time_points)
          self.data = None
          self.generated = False
          self.plotter = plotter

      def differential_eq(self):
          """Placeholder for the differential equation of the model.

          The base implementation does nothing; subclasses provide the model.
          """
          pass

      def generate(self):
          """Mark the data as generated.

          The base implementation only sets the ``generated`` flag; it does not
          fill ``self.data``.
          """
          self.generated = True

      def plot(self, labels: tuple, title: str):
          """Plot the data which was generated.

          If no data has been generated yet, a hint is printed and nothing is
          plotted.

          Args:
              labels (tuple): The names of each curve.
              title (str): The name of the plot.

          Raises:
              AssertionError: If the number of labels differs from the number
                  of curves in ``self.data``.
          """
          # Check the flag first: before generation ``self.data`` is None and
          # ``len(self.data)`` would raise a TypeError instead of reaching the
          # user-facing hint below.
          if not self.generated:
              print('Data has to be generated before plotting!')
              return

          assert len(labels) == len(self.data), 'The number labels needs to be the same as the number of plots.'
          # ``title`` is intentionally passed twice, as in the plotter call this
          # class has always made.
          self.plotter.plot(
              self.t,
              self.data,
              labels,
              title,
              title,
              (6, 6),
              xlabel='time / days',
              ylabel='amount of people',
          )
  class SIR(SyntheticDeseaseData):
      """Generator of synthetic data for the SIR model.

      The three groups are Susceptible, Infectious and Removed.
      """

      def __init__(self, plotter: "Plotter", N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
          """Store the model parameters and the initial state.

          Args:
              plotter (Plotter): Plotter object to plot dataset curves.
              N (float, optional): Size of the population. Defaults to 59e6.
              I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
              R_0 (int, optional): Initial size of the removed group. Defaults to 0.
              simulation_time (int, optional): Real time for which the synthetic
                  data is supposed to be generated, in days. Defaults to 500.
              time_points (int, optional): Number of time sample points. Defaults to 100.
              alpha (float, optional): Factor dictating how many people per timestep
                  go from 'Infectious' to 'Removed'. Defaults to 0.191.
              beta (float, optional): Factor dictating how many people per timestep
                  go from 'Susceptible' to 'Infectious'. Defaults to 0.05.
          """
          # NOTE(review): with these defaults beta < alpha, so dI/dt is negative
          # from the start and the infectious group only shrinks. Confirm the
          # defaults are intended.
          self.N = N
          # Everyone who is neither infectious nor removed starts as susceptible.
          self.S_0 = N - I_0 - R_0
          self.I_0 = I_0
          self.R_0 = R_0
          self.alpha = alpha
          self.beta = beta
          super().__init__(simulation_time, time_points, plotter)

      def differential_eq(self, y, t, alpha, beta):
          """Implement the differential equation of the SIR model.

          The rates are read from ``self.alpha`` and ``self.beta``; the extra
          positional parameters are only there to accept what ``odeint`` passes.

          Args:
              y (tuple): Vector that holds the current state of the three groups.
              t (_): not used
              alpha (_): not used
              beta (_): not used

          Returns:
              tuple: Change amount for each group.
          """
          S, I, R = y
          # Flow from 'Susceptible' to 'Infectious'.
          infection_flow = self.beta * ((S * I) / self.N)
          # Flow from 'Infectious' to 'Removed'.
          removal_flow = self.alpha * I
          dSdt = -infection_flow
          dIdt = infection_flow - removal_flow
          dRdt = removal_flow
          return dSdt, dIdt, dRdt

      def generate(self):
          """Generate the data for this configuration of the SIR model.

          The result is stored in ``self.data`` with one row per group
          (S, I, R) and one column per time sample.
          """
          y_0 = self.S_0, self.I_0, self.R_0
          # odeint returns one row per time sample; transpose to one row per group.
          self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
          super().generate()

      def plot(self, title=''):
          """Plot the data which was generated.

          Args:
              title (str, optional): The name of the plot. Defaults to ''.
          """
          super().plot(('Susceptible', 'Infectious', 'Removed'), title=title)

      def save(self, name=''):
          """Write the time grid and the generated curves to 'datasets/SIR_data.csv'.

          The first row of the file is the time grid, followed by one row per
          group. If no data has been generated yet, a hint is printed and
          nothing is written.

          Args:
              name (str, optional): Currently unused; the output path is fixed.
          """
          if not self.generated:
              print('Data has to be generated before saving!')
              return

          table = np.asarray([self.t, *self.data])
          np.savetxt('datasets/SIR_data.csv', table, delimiter=",")
  class SIDR(SyntheticDeseaseData):
      """Generator of synthetic data for the SIDR model.

      The four groups are Susceptible, Infectious, Dead and Recovered.
      """

      def __init__(self, plotter: "Plotter", N=59e6, I_0=1, D_0=0, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05, gamma=0.0294) -> None:
          """Store the model parameters and the initial state.

          Args:
              plotter (Plotter): Plotter object to plot dataset curves.
              N (float, optional): Size of the population. Defaults to 59e6.
              I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
              D_0 (int, optional): Initial size of the dead group. Defaults to 0.
              R_0 (int, optional): Initial size of the recovered group. Defaults to 0.
              simulation_time (int, optional): Real time for which the synthetic
                  data is supposed to be generated, in days. Defaults to 500.
              time_points (int, optional): Number of time sample points. Defaults to 100.
              alpha (float, optional): Factor dictating how many people per timestep
                  go from 'Susceptible' to 'Infectious'. Defaults to 0.191.
              beta (float, optional): Factor dictating how many people per timestep
                  go from 'Infectious' to 'Recovered'. Defaults to 0.05.
              gamma (float, optional): Factor dictating how many people per timestep
                  go from 'Infectious' to 'Dead'. Defaults to 0.0294.
          """
          self.N = N
          # Everyone who is not infectious, dead or recovered starts as susceptible.
          self.S_0 = N - I_0 - D_0 - R_0
          self.I_0 = I_0
          self.D_0 = D_0
          self.R_0 = R_0
          self.alpha = alpha
          self.beta = beta
          self.gamma = gamma
          super().__init__(simulation_time, time_points, plotter)

      def differential_eq(self, y, t, alpha, beta, gamma):
          """Implement the differential equation of the SIDR model.

          The rates are read from ``self.alpha``, ``self.beta`` and
          ``self.gamma``; the extra positional parameters are only there to
          accept what ``odeint`` passes.

          Args:
              y (tuple): Vector that holds the current state of the four groups.
              t (_): not used
              alpha (_): not used
              beta (_): not used
              gamma (_): not used

          Returns:
              tuple: Change amount for each group.
          """
          S, I, D, R = y
          # Flow from 'Susceptible' to 'Infectious'.
          infection_flow = (self.alpha / self.N) * S * I
          # Flows out of 'Infectious': beta feeds 'Recovered', gamma feeds 'Dead'.
          recovery_flow = self.beta * I
          death_flow = self.gamma * I
          dSdt = -infection_flow
          dIdt = infection_flow - recovery_flow - death_flow
          dDdt = death_flow
          dRdt = recovery_flow
          return dSdt, dIdt, dDdt, dRdt

      def generate(self):
          """Generate the data for this configuration of the SIDR model.

          The result is stored in ``self.data`` with one row per group
          (S, I, D, R) and one column per time sample.
          """
          y_0 = self.S_0, self.I_0, self.D_0, self.R_0
          # odeint returns one row per time sample; transpose to one row per group.
          self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta, self.gamma)).T
          super().generate()

      def plot(self, title=''):
          """Plot the data which was generated.

          Args:
              title (str, optional): The name of the plot. Defaults to ''.
          """
          super().plot(('Susceptible', 'Infectious', 'Dead', 'Recovered'), title=title)

      def save(self, name=''):
          """Write the time grid and the generated curves to 'datasets/SIDR_data.csv'.

          The first row of the file is the time grid, followed by one row per
          group. If no data has been generated yet, a hint is printed and
          nothing is written.

          Args:
              name (str, optional): Currently unused; the output path is fixed.
          """
          if not self.generated:
              print('Data has to be generated before saving!')
              return

          table = np.asarray([self.t, *self.data])
          np.savetxt('datasets/SIDR_data.csv', table, delimiter=",")