phillip.rothenbeck 5 сар өмнө
parent
commit
c97674cbe1
1 өөрчлөгдсөн 32 нэмэгдсэн , 4 устгасан
  1. 32 4
      src/problem.py

+ 32 - 4
src/problem.py

@@ -1,6 +1,7 @@
 import torch
 import torch
 from .dataset import PandemicDataset
 from .dataset import PandemicDataset
 
 
+
 class PandemicProblem:
 class PandemicProblem:
     def __init__(self, data: PandemicDataset) -> None:
     def __init__(self, data: PandemicDataset) -> None:
         """Parent class for all pandemic problem classes. Holding the function, that calculates the residuals of the differential system.
         """Parent class for all pandemic problem classes. Holding the function, that calculates the residuals of the differential system.
@@ -18,14 +19,14 @@ class PandemicProblem:
         """NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS
         """NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS
         """
         """
         assert self._gradients != None, 'Gradientmatrix need to be defined'
         assert self._gradients != None, 'Gradientmatrix need to be defined'
-        
 
 
-    def def_grad_matrix(self, number:int):
+    def def_grad_matrix(self, number: int):
         assert self._gradients == None, 'Gradientmatrix is already defined'
         assert self._gradients == None, 'Gradientmatrix is already defined'
         self._gradients = [torch.zeros((len(self._data.t_raw), number), device=self._device_name) for _ in range(number)]
         self._gradients = [torch.zeros((len(self._data.t_raw), number), device=self._device_name) for _ in range(number)]
         for i in range(number):
         for i in range(number):
             self._gradients[i][:, i] = 1
             self._gradients[i][:, i] = 1
 
 
+
 class SIRProblem(PandemicProblem):
 class SIRProblem(PandemicProblem):
     def __init__(self, data: PandemicDataset):
     def __init__(self, data: PandemicDataset):
         super().__init__(data)
         super().__init__(data)
@@ -53,8 +54,36 @@ class SIRProblem(PandemicProblem):
         return S_residual, I_residual, R_residual
         return S_residual, I_residual, R_residual
 
 
 
 
+class SIRAlphaProblem(PandemicProblem):
+    def __init__(self, data: PandemicDataset, alpha):
+        super().__init__(data)
+        self.alpha = alpha
+
+    def residual(self, SIR_pred, beta):
+        super().residual()
+        SIR_pred.backward(self._gradients[0], retain_graph=True)
+        dSdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
+
+        SIR_pred.backward(self._gradients[1], retain_graph=True)
+        dIdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
+
+        SIR_pred.backward(self._gradients[2], retain_graph=True)
+        dRdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
+
+        S, I, _ = self._data.get_denormalized_data([SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]])
+
+        S_residual = dSdt - (-beta * ((S * I) / self._data.N)) / (self._data.get_max('S') - self._data.get_min('S'))
+        I_residual = dIdt - (beta * ((S * I) / self._data.N) - self.alpha * I) / (self._data.get_max('I') - self._data.get_min('I'))
+        R_residual = dRdt - (self.alpha * I) / (self._data.get_max('R') - self._data.get_min('R'))
+
+        return S_residual, I_residual, R_residual
+
+
 class ReducedSIRProblem(PandemicProblem):
 class ReducedSIRProblem(PandemicProblem):
-    def __init__(self, data: PandemicDataset, alpha:float):
+    def __init__(self, data: PandemicDataset, alpha: float):
         super().__init__(data)
         super().__init__(data)
         self.alpha = alpha
         self.alpha = alpha
 
 
@@ -72,4 +101,3 @@ class ReducedSIRProblem(PandemicProblem):
 
 
         I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
         I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
         return I_residual
         return I_residual
-