@@ -58,17 +58,18 @@ class ReducedSIRProblem(PandemicProblem):
         super().__init__(data)
         self.alpha = alpha
 
-    def residual(self, SI_pred):
+    def residual(self, I_pred):
         super().residual()
 
-        SI_pred.backward(self._gradients[0], retain_graph=True)
-        dIdt = self._data.t_raw.grad.clone()
-        self._data.t_raw.grad.zero_()
-
-        I = SI_pred[:, 0]
-        R_t = SI_pred[:, 1]
+        I_pred.backward(self._gradients[0], retain_graph=True)
+        dIdt = self._data.t_scaled.grad.clone()
+        self._data.t_scaled.grad.zero_()
 
-        I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
+        I = I_pred[:, 0]
+        R_t = I_pred[:, 1]
+
+        # dIdt = torch.autograd.grad(I, self._data.t_scaled, torch.ones_like(I), create_graph=True)[0]
 
+        I_residual = dIdt - (self.alpha * (self._data.t_final - self._data.t_init) * (R_t - 1) * I)
         return I_residual
 
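For context: the residual enforces the reduced SIR equation dI/dt = alpha * (R_t - 1) * I. Because the network takes time rescaled to [0, 1], the chain rule introduces the (t_final - t_init) factor seen in I_residual. Below is a minimal, self-contained sketch of the same computation; the toy network, grid size, and the values of t_init, t_final, and alpha are hypothetical stand-ins for the class internals (self._data, self._gradients), and it uses the torch.autograd.grad form from the commented-out line instead of backward() with a prestored seed.

import torch

# Hypothetical stand-ins for what the class provides: the time span,
# the rate alpha, and a scaled time grid (self._data.t_scaled).
t_init, t_final, alpha = 0.0, 90.0, 0.2
t_scaled = torch.linspace(0.0, 1.0, 64).unsqueeze(1).requires_grad_(True)

# Toy surrogate for the PINN: maps scaled time to [I, R_t] per time point.
net = torch.nn.Sequential(
    torch.nn.Linear(1, 16), torch.nn.Tanh(), torch.nn.Linear(16, 2)
)

I_pred = net(t_scaled)
I, R_t = I_pred[:, 0], I_pred[:, 1]

# Same derivative as the diff's backward()/.grad sequence, via the
# torch.autograd.grad form shown in the commented-out line.
dIdt = torch.autograd.grad(
    I, t_scaled, torch.ones_like(I), create_graph=True
)[0].squeeze(1)

# Residual of dI/dt = alpha * (R_t - 1) * I on scaled time.
I_residual = dIdt - alpha * (t_final - t_init) * (R_t - 1) * I
loss = (I_residual ** 2).mean()
loss.backward()

Compared with the backward() route, torch.autograd.grad needs no output-gradient seed (self._gradients[0]) and does not mutate t_scaled.grad, so there is nothing to clone and re-zero afterwards; that is presumably why the alternative was left in as a comment.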