Browse Source

revert gradient calculation

phillip.rothenbeck 1 năm trước cách đây
mục cha
commit
8aee177872
1 tập tin đã thay đổi với 7 bổ sung17 xóa
  1. 7 17
      src/problem.py

+ 7 - 17
src/problem.py

@@ -28,29 +28,19 @@ class SIRProblem(PandemicProblem):
         super().__init__(data)
 
     def residual(self, SIR_pred, alpha, beta):
-        S_pred, I_pred, R_pred = SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]
-
-        # dSdt = torch.autograd.grad(S_pred, self.data.t_raw, torch.ones_like(S_pred), create_graph=True)[0]
         SIR_pred.backward(self.gradients[0], retain_graph=True)
-        dSdt_norm = self.data.t_norm.grad.clone()
-        self.data.t_norm.grad.zero_()
+        dSdt = self.data.t_raw.grad.clone()
+        self.data.t_raw.grad.zero_()
 
-        # dIdt = torch.autograd.grad(I_pred, self.data.t_raw, torch.ones_like(I_pred), create_graph=True)[0]
         SIR_pred.backward(self.gradients[1], retain_graph=True)
-        dIdt_norm = self.data.t_norm.grad.clone()
-        self.data.t_norm.grad.zero_()
+        dIdt = self.data.t_raw.grad.clone()
+        self.data.t_raw.grad.zero_()
 
-        # dRdt = torch.autograd.grad(R_pred, self.data.t_raw, torch.ones_like(R_pred), create_graph=True)[0]
         SIR_pred.backward(self.gradients[2], retain_graph=True)
-        dRdt_norm = self.data.t_norm.grad.clone()
-        self.data.t_norm.grad.zero_()
-        
+        dRdt = self.data.t_raw.grad.clone()
+        self.data.t_raw.grad.zero_()
+
         S, I, R = self.data.get_denormalized_data([SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]])
-        # print(f'dSdt: {dSdt}, dIdt: {dIdt}, dRdt: {dRdt}')
-        
-        dSdt = dSdt_norm * self.data.normalization_differantial
-        dIdt = dRdt_norm * self.data.normalization_differantial
-        dRdt = dIdt_norm * self.data.normalization_differantial
 
         S_residual = dSdt - (-beta * ((S * I) / self.data.N)) / (self.data.get_max('S') - self.data.get_min('S'))
         I_residual = dIdt - (beta * ((S * I) / self.data.N) - alpha * I) / (self.data.get_max('I') - self.data.get_min('I'))