Browse Source

add cuda support + add SIRProblem

phillip.rothenbeck 1 year ago
parent
commit
0798c031c4
1 changed files with 43 additions and 6 deletions
  1. 43 6
      src/problem.py

+ 43 - 6
src/problem.py

@@ -8,17 +8,54 @@ class PandemicProblem:
         Args:
             data (PandemicDataset): Dataset holding the time values used.
         """
+
+        self.data = data
+        self.device_name = data.device_name
+
         #store the gradients for each group
-        self.gradients = [torch.zeros((len(data.t_raw), data.number_groups)) for _ in range(data.number_groups)]
+        self.gradients = [torch.zeros((len(data.t_raw), data.number_groups), device=self.device_name) for _ in range(data.number_groups)]
 
         for i in range(data.number_groups):
             self.gradients[i][:, i] = 1
 
-    def to_device(self, device):
-        for i in range(len(self.gradients)):
-            self.gradients[i] = self.gradients[i].to(device)
+    def residual(self):
+        """NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS
+        """
+        pass
 
-    def residual():
+    def denormalization(self):
         """NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS
         """
-        pass
+        pass
+
+class SIRProblem(PandemicProblem):
+    def __init__(self, data: PandemicDataset):
+        super().__init__(data)
+
+    def residual(self, SIR_pred, alpha, beta):
+        SIR_pred.backward(self.gradients[0], retain_graph=True)
+        dSdt = self.data.t_raw.grad.clone()
+        self.data.t_raw.grad.zero_()
+
+        SIR_pred.backward(self.gradients[1], retain_graph=True)
+        dIdt = self.data.t_raw.grad.clone()
+        self.data.t_raw.grad.zero_()
+
+        SIR_pred.backward(self.gradients[2], retain_graph=True)
+        dRdt = self.data.t_raw.grad.clone()
+        self.data.t_raw.grad.zero_()
+        
+        S, I, R = self.denormalization(SIR_pred)
+
+        S_residual = dSdt - (-beta * ((S * I) / self.data.N)) / (self.data.get_max('S') - self.data.get_min('S'))
+        I_residual = dIdt - (beta * ((S * I) / self.data.N) - alpha * I) / (self.data.get_max('I') - self.data.get_min('I'))
+        R_residual = dRdt - (alpha * I) / (self.data.get_max('R') - self.data.get_min('R'))
+
+        return S_residual, I_residual, R_residual
+    
+    def denormalization(self, predictions):
+        S_pred, I_pred, R_pred = predictions[:, 0], predictions[:, 1], predictions[:, 2]
+        S = self.data.get_min('S') + (self.data.get_max('S') - self.data.get_min('S')) * S_pred
+        I = self.data.get_min('I') + (self.data.get_max('I') - self.data.get_min('I')) * I_pred
+        R = self.data.get_min('R') + (self.data.get_max('R') - self.data.get_min('R')) * R_pred
+        return S, I, R