|
|
@@ -32,6 +32,10 @@ class DINN:
|
|
|
|
|
|
self.epochs = None
|
|
|
|
|
|
+ self.S_grads = torch.zeros((len(self.data.t_raw), 3)); self.S_grads[:, 0] = 1
|
|
|
+ self.I_grads = torch.zeros((len(self.data.t_raw), 3)); self.I_grads[:, 1] = 1
|
|
|
+ self.R_grads = torch.zeros((len(self.data.t_raw), 3)); self.R_grads[:, 2] = 1
|
|
|
+
|
|
|
@property
|
|
|
def alpha(self):
|
|
|
return torch.tanh(self.alpha_tilda)
|
|
|
@@ -51,14 +55,21 @@ class DINN:
|
|
|
|
|
|
self.epochs = epochs
|
|
|
for epoch in range(epochs):
|
|
|
- optimizer.zero_grad()
|
|
|
|
|
|
- SIR_pred = self.pinn(self.data.t)
|
|
|
+ SIR_pred = self.pinn(self.data.t_batch)
|
|
|
S_pred, I_pred, R_pred = SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]
|
|
|
|
|
|
- dSdt = torch.autograd.grad(S_pred, self.data.t, torch.zeros_like(S_pred), create_graph=True)[0]
|
|
|
- dIdt = torch.autograd.grad(I_pred, self.data.t, torch.zeros_like(I_pred), create_graph=True)[0]
|
|
|
- dRdt = torch.autograd.grad(R_pred, self.data.t, torch.zeros_like(R_pred), create_graph=True)[0]
|
|
|
+ SIR_pred.backward(self.S_grads, retain_graph=True)
|
|
|
+ dSdt = self.data.t_raw.grad.clone()
|
|
|
+ self.data.t_raw.grad.zero_()
|
|
|
+
|
|
|
+ SIR_pred.backward(self.I_grads, retain_graph=True)
|
|
|
+ dIdt = self.data.t_raw.grad.clone()
|
|
|
+ self.data.t_raw.grad.zero_()
|
|
|
+
|
|
|
+ SIR_pred.backward(self.R_grads, retain_graph=True)
|
|
|
+ dRdt = self.data.t_raw.grad.clone()
|
|
|
+ self.data.t_raw.grad.zero_()
|
|
|
|
|
|
S = self.data.S_min + (self.data.S_max - self.data.S_min) * S_pred
|
|
|
I = self.data.I_min + (self.data.I_max - self.data.I_min) * I_pred
|
|
|
@@ -68,6 +79,8 @@ class DINN:
|
|
|
I_residual = dIdt - (self.beta * ((S * I) / self.data.N) - self.alpha * I) / (self.data.I_max - self.data.I_min)
|
|
|
R_residual = dRdt - (self.alpha * I) / (self.data.R_max - self.data.R_min)
|
|
|
|
|
|
+ optimizer.zero_grad()
|
|
|
+
|
|
|
loss_physics = (torch.mean(torch.square(S_residual)) +
|
|
|
torch.mean(torch.square(I_residual)) +
|
|
|
torch.mean(torch.square(R_residual)))
|
|
|
@@ -104,7 +117,7 @@ class DINN:
|
|
|
plt.plot(epochs, np.ones(self.epochs) * 0.333, label='true alpha')
|
|
|
plt.legend()
|
|
|
plt.show()
|
|
|
- plt.plot(epochs, self.betas, label='parameter bate')
|
|
|
+ plt.plot(epochs, self.betas, label='parameter beta')
|
|
|
plt.plot(epochs, np.ones(self.epochs) * 0.5, label='true beta')
|
|
|
plt.legend()
|
|
|
plt.show()
|