phillip.rothenbeck преди 6 месеца
родител
ревизия
7555f1b41e
променени са 1 файла, в които са добавени 55 реда и са изтрити 80 реда
  1. 55 80
      src/preprocessing/synthetic_data.py

+ 55 - 80
src/preprocessing/synthetic_data.py

@@ -29,7 +29,7 @@ class SyntheticDeseaseData:
         """
         self.generated = True
 
-    def plot(self, labels: tuple, title:str):
+    def plot(self, labels: tuple, title:str, file_name:str):
         """Plot the data which was generated.
 
         Args:
@@ -38,31 +38,31 @@ class SyntheticDeseaseData:
         """
         assert len(labels) == len(self.data), 'The number labels needs to be the same as the number of plots.'
         if self.generated:
-            self.plotter.plot(self.t, self.data, labels, title, title, (6, 6), xlabel='time / days', ylabel='amount of people')
+            self.plotter.plot(self.t, self.data, labels, file_name, title, (6, 6), xlabel='time / days', ylabel='amount of people')
         else: 
             print('Data has to be generated before plotting!')
 
-class SI(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N=59e6, I_0=1, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
-        """This class is able to generate synthetic data of the SI groups for the reduced SIR model. This is done by utiling the SIR model.
+class SIR(SyntheticDeseaseData):
+    def __init__(self, plotter:Plotter, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
+        """This class is able to generate synthetic data for the SIR model.
 
         Args:
             plotter (Plotter): Plotter object to plot dataset curves.
             N (int, optional): Size of the population. Defaults to 59e6.
             I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
+            R_0 (int, optional): Initial size of the removed group. Defaults to 0.
             simulation_time (int, optional): Real time for that the synthetic data is supposed to be generated in days. Defaults to 500.
             time_points (int, optional): Number of time sample points. Defaults to 100.
             alpha (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Removed'. Defaults to 0.191.
             beta (float, optional): Factor dictating how many people per timestep go from 'Susceptible' to 'Infectious'. Defaults to 0.05.
         """
-
         self.N = N
-        self.S_0 = N - I_0
+        self.S_0 = N - I_0 - R_0
         self.I_0 = I_0
+        self.R_0 = R_0
 
         self.alpha = alpha
         self.beta = beta
-
         super().__init__(simulation_time, time_points, plotter)
 
     def differential_eq(self, y, t, alpha, beta):
@@ -77,33 +77,34 @@ class SI(SyntheticDeseaseData):
         Returns:
             tuple: Change amount for each group.
         """
-        S, I = y
-        dSdt = -self.beta * ((S * I) / self.N)
-        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I
-        return dSdt, dIdt
-    
+        S, I, _ = y
+        dSdt = -self.beta * ((S * I) / self.N) # -self.beta * S * I
+        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I # self.beta * S * I - self.alpha * I
+        dRdt = self.alpha * I
+        return dSdt, dIdt, dRdt
+
     def generate(self):
         """This function generates the data for this configuration of the SIR model.
         """
-        y_0 = self.S_0, self.I_0
+        y_0 = self.S_0, self.I_0, self.R_0
         self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
         super().generate()
 
-    def plot(self, title=''):
+    def plot(self, title='', file_name='SIR_plot'):
         """Plot the data which was generated.
         """
-        super().plot(('Susceptible', 'Infectious'), title=title)
+        super().plot(('Susceptible', 'Infectious', 'Removed'), title=title, file_name=file_name)
 
     def save(self, name=''):
         if self.generated:
             COVID_Data = np.asarray([self.t, *self.data]) 
 
-            np.savetxt('datasets/SI_data.csv', COVID_Data, delimiter=",")
+            np.savetxt('datasets/SIR_data.csv', COVID_Data, delimiter=",")
         else: 
             print('Data has to be generated before plotting!')
 
-class I(SI):
-    def __init__(self, plotter:Plotter, N=59e6, I_0=1, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
+class I(SyntheticDeseaseData):
+    def __init__(self, plotter:Plotter, N:int, C:int, I_0=1, time_points=100, alpha=1/3) -> None:
         """This class is able to generate synthetic data of the SI groups for the reduced SIR model. This is done by utilizing the SIR model.
 
         Args:
@@ -115,94 +116,68 @@ class I(SI):
             alpha (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Removed'. Defaults to 1/3.
             beta (float, optional): Factor dictating how many people per timestep go from 'Susceptible' to 'Infectious'. Defaults to 0.05.
         """
-
-        super().__init__(plotter, N=N, I_0=I_0, simulation_time=simulation_time, time_points=time_points, alpha=alpha, beta=beta)
-
-    def generate(self):
-        """This funtion generates the data for this configuration of the SIR model.
-        """
-        super().generate()
-        self.data = self.data[1]
-        print(self.data.shape)
-
-    def plot(self, title=''):
-        """Plot the data which was generated.
-        """
-        if self.generated:
-            self.plotter.plot(self.t, [self.data], ['Infectious'], title, title, (6, 6), xlabel='time / days', ylabel='amount of people')
-        else: 
-            print('Data has to be generated before plotting!')
-
-    def save(self, name=''):
-        if self.generated:
-            COVID_Data = np.asarray([self.t, self.data]) 
-
-            np.savetxt('datasets/I_data.csv', COVID_Data, delimiter=",")
-        else: 
-            print('Data has to be generated before plotting!')
-
-
-class SIR(SyntheticDeseaseData):
-    def __init__(self, plotter:Plotter, N=59e6, I_0=1, R_0=0, simulation_time=500, time_points=100, alpha=0.191, beta=0.05) -> None:
-        """This class is able to generate synthetic data for the SIR model.
-
-        Args:
-            plotter (Plotter): Plotter object to plot dataset curves.
-            N (int, optional): Size of the population. Defaults to 59e6.
-            I_0 (int, optional): Initial size of the infectious group. Defaults to 1.
-            R_0 (int, optional): Initial size of the removed group. Defaults to 0.
-            simulation_time (int, optional): Real time for that the synthetic data is supposed to be generated in days. Defaults to 500.
-            time_points (int, optional): Number of time sample points. Defaults to 100.
-            alpha (float, optional): Factor dictating how many people per timestep go from 'Infectious' to 'Removed'. Defaults to 0.191.
-            beta (float, optional): Factor dictating how many people per timestep go from 'Susceptible' to 'Infectious'. Defaults to 0.05.
-        """
         self.N = N
-        self.S_0 = N - I_0 - R_0
+        self.C = C
         self.I_0 = I_0
-        self.R_0 = R_0
-
+ 
         self.alpha = alpha
-        self.beta = beta
 
-        super().__init__(simulation_time, time_points, plotter)
+        self.t = np.linspace(0, 1, time_points)
+        self.t_save = np.linspace(1, time_points, time_points)
+        self.t_f = time_points
+        self.reproduction_value = []
+        self.data = None
+        self.generated = False
+        self.plotter = plotter
+        
+    def R_t(self, t):
+        descaled_t = t * self.t_f
+        # if descaled_t < threshold1:
+        return -np.tanh(descaled_t * 0.05 - 2) * 0.4 + 1.35
 
-    def differential_eq(self, y, t, alpha, beta):
+
+            
+    def differential_eq(self, I, t):
         """This function implements the differential equation for the infectious group of the reduced SIR model.
 
         Args:
             I (float): Current state of the infectious group.
             t (float): Current time, scaled to [0, 1]; used to evaluate R_t.
-            alpha (_): not used
-            beta (_): not used
 
         Returns:
             float: Change amount of the infectious group.
         """
-        S, I, R = y
-        dSdt = -self.beta * ((S * I) / self.N) # -self.beta * S * I
-        dIdt = self.beta * ((S * I) / self.N) - self.alpha * I # self.beta * S * I - self.alpha * I
-        dRdt = self.alpha * I
-        return dSdt, dIdt, dRdt
+        dIdt = self.alpha * self.t_f * (self.R_t(t) - 1) * I
+        return dIdt
 
     def generate(self):
         """This function generates the data for this configuration of the SIR model.
         """
-        y_0 = self.S_0, self.I_0, self.R_0
-        self.data = odeint(self.differential_eq, y_0, self.t, args=(self.alpha, self.beta)).T
-        super().generate()
+        self.data = odeint(self.differential_eq, self.I_0/self.C, self.t).T
+        self.data = self.data[0] * self.C
+        self.t_counter = 0
+        self.generated =True
 
-    def plot(self, title=''):
+    def plot(self, title='', file_name=''):
         """Plot the data which was generated.
         """
-        super().plot(('Susceptible', 'Infectious', 'Removed'), title=title)
+        if self.generated:
+            t = np.linspace(0, len(self.t), len(self.t))
+            self.plotter.plot(t, [self.data], ['Infectious'], file_name, title, (6, 6), xlabel='time / days', ylabel='amount of people')
+            for time in self.t:
+                self.reproduction_value.append(self.R_t(time))
+            self.plotter.plot(t, [np.array(self.reproduction_value)], [r'$\mathcal{R}_t$'], file_name + '_r_t', title + r' $\mathcal{R}_t$', (6, 6), xlabel='time / days')
+        else: 
+            print('Data has to be generated before plotting!')
 
     def save(self, name=''):
         if self.generated:
-            COVID_Data = np.asarray([self.t, *self.data]) 
+            COVID_Data = np.asarray([self.t_save, self.data]) 
 
-            np.savetxt('datasets/SIR_data.csv', COVID_Data, delimiter=",")
+            np.savetxt('datasets/I_data.csv', COVID_Data, delimiter=",")
         else: 
             print('Data has to be generated before plotting!')
+
         
 
 class SIDR(SyntheticDeseaseData):