
add backward grad method and move zero grad

phillip.rothenbeck 1 year ago
parent
commit
c6fac50a96
1 changed file with 19 additions and 6 deletions
  1. src/sir_dinn/dinn.py  +19 -6

+19 -6
src/sir_dinn/dinn.py

@@ -32,6 +32,10 @@ class DINN:
 
         self.epochs = None
 
+        self.S_grads = torch.zeros((len(self.data.t_raw), 3)); self.S_grads[:, 0] = 1
+        self.I_grads = torch.zeros((len(self.data.t_raw), 3)); self.I_grads[:, 1] = 1
+        self.R_grads = torch.zeros((len(self.data.t_raw), 3)); self.R_grads[:, 2] = 1
+
     @property
     def alpha(self):
         return torch.tanh(self.alpha_tilda)
@@ -51,14 +55,21 @@ class DINN:
 
         self.epochs = epochs
         for epoch in range(epochs):
-            optimizer.zero_grad()
 
-            SIR_pred = self.pinn(self.data.t)
+            SIR_pred = self.pinn(self.data.t_batch)
             S_pred, I_pred, R_pred = SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]
             
-            dSdt = torch.autograd.grad(S_pred, self.data.t, torch.zeros_like(S_pred), create_graph=True)[0]
-            dIdt = torch.autograd.grad(I_pred, self.data.t, torch.zeros_like(I_pred), create_graph=True)[0]
-            dRdt = torch.autograd.grad(R_pred, self.data.t, torch.zeros_like(R_pred), create_graph=True)[0]
+            SIR_pred.backward(self.S_grads, retain_graph=True)
+            dSdt = self.data.t_raw.grad.clone()
+            self.data.t_raw.grad.zero_()
+
+            SIR_pred.backward(self.I_grads, retain_graph=True)
+            dIdt = self.data.t_raw.grad.clone()
+            self.data.t_raw.grad.zero_()
+
+            SIR_pred.backward(self.R_grads, retain_graph=True)
+            dRdt = self.data.t_raw.grad.clone()
+            self.data.t_raw.grad.zero_()
             
             S = self.data.S_min + (self.data.S_max - self.data.S_min) * S_pred
             I = self.data.I_min + (self.data.I_max - self.data.I_min) * I_pred
@@ -68,6 +79,8 @@ class DINN:
             I_residual = dIdt - (self.beta * ((S * I) / self.data.N) - self.alpha * I) / (self.data.I_max - self.data.I_min)
             R_residual = dRdt - (self.alpha * I) / (self.data.R_max - self.data.R_min)
 
+            optimizer.zero_grad()
+
             loss_physics = (torch.mean(torch.square(S_residual)) + 
                             torch.mean(torch.square(I_residual)) +
                             torch.mean(torch.square(R_residual)))
@@ -104,7 +117,7 @@ class DINN:
         plt.plot(epochs, np.ones(self.epochs) * 0.333, label='true alpha')
         plt.legend()
         plt.show()
-        plt.plot(epochs, self.betas, label='parameter bate')
+        plt.plot(epochs, self.betas, label='parameter beta')
         plt.plot(epochs, np.ones(self.epochs) * 0.5, label='true beta')
         plt.legend()
         plt.show()
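
The core change replaces three `torch.autograd.grad` calls with `Tensor.backward(gradient=...)` using the pre-built one-hot matrices `S_grads`, `I_grads`, and `R_grads`: passing a one-hot matrix as the `gradient` argument computes the vector-Jacobian product that isolates the time-derivative of a single output column. A minimal self-contained sketch of the technique (not the repository's code; `net` and `t` below are illustrative stand-ins):

    import torch
    import torch.nn as nn

    # Toy network mapping t -> (S, I, R), standing in for self.pinn.
    net = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 3))
    t = torch.linspace(0, 1, 50).reshape(-1, 1).requires_grad_(True)

    pred = net(t)  # shape (50, 3); columns correspond to S, I, R

    one_hot = torch.zeros_like(pred)
    one_hot[:, 0] = 1  # select the S column, like self.S_grads

    # backward(v) accumulates v^T J into t.grad, i.e. d(pred[:, 0])/dt here.
    pred.backward(one_hot, retain_graph=True)
    dSdt = t.grad.clone()
    t.grad.zero_()  # clear before extracting the next column (I, then R)

    # Cross-check against torch.autograd.grad. The grad_outputs must be
    # ones, not zeros -- the removed code passed torch.zeros_like, which
    # makes every derivative identically zero.
    dSdt_check = torch.autograd.grad(
        pred[:, 0], t, torch.ones_like(pred[:, 0]), retain_graph=True
    )[0]
    assert torch.allclose(dSdt, dSdt_check)

Each of these `backward()` calls also accumulates gradients into the network parameters, which is presumably why `optimizer.zero_grad()` moves below the derivative extraction in this commit: it clears that residue just before the loss is backpropagated.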