ソースを参照

clean up reduced

phillip.rothenbeck 4 ヶ月 前
コミット
5f34dd8418
1 ファイル変更9 行追加8 行削除
  1. 9 8
      src/problem.py

+ 9 - 8
src/problem.py

@@ -58,17 +58,18 @@ class ReducedSIRProblem(PandemicProblem):
         super().__init__(data)
         self.alpha = alpha
 
-    def residual(self, SI_pred):
+    def residual(self, I_pred):
         super().residual()
 
-        SI_pred.backward(self._gradients[0], retain_graph=True)
-        dIdt = self._data.t_raw.grad.clone()
-        self._data.t_raw.grad.zero_()
-        
-        I = SI_pred[:, 0]
-        R_t = SI_pred[:, 1]
+        I_pred.backward(self._gradients[0], retain_graph=True)
+        dIdt = self._data.t_scaled.grad.clone()
+        self._data.t_scaled.grad.zero_()
 
-        I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
+        I = I_pred[:, 0]
+        R_t = I_pred[:, 1]
+
+        # dIdt = torch.autograd.grad(I, self._data.t_scaled, torch.ones_like(I), create_graph=True)[0]
 
+        I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
         return I_residual