|
|
@@ -0,0 +1,119 @@
|
|
|
+import torch
|
|
|
+import numpy as np
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+
|
|
|
+from .dataset.dataset import SIR_Dataset
|
|
|
+
|
|
|
+class DINN:
|
|
|
+ class PINN(torch.nn.Module):
|
|
|
+ def __init__(self) -> None:
|
|
|
+ super(DINN.PINN, self).__init__()
|
|
|
+
|
|
|
+ self.input = torch.nn.Sequential(torch.nn.Linear(1, 20), torch.nn.ReLU())
|
|
|
+ self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(20, 20), torch.nn.ReLU()) for _ in range(7)])
|
|
|
+ self.output = torch.nn.Linear(20, 3)
|
|
|
+
|
|
|
+ def forward(self, t):
|
|
|
+ x = self.input(t)
|
|
|
+ x = self.hidden(x)
|
|
|
+ x = self.output(x)
|
|
|
+ return x
|
|
|
+
|
|
|
+ def __init__(self, data: SIR_Dataset):
|
|
|
+ self.data = data
|
|
|
+ self.pinn = DINN.PINN()
|
|
|
+
|
|
|
+ self.alpha_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))
|
|
|
+ self.beta_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))
|
|
|
+
|
|
|
+ self.losses = np.zeros(1)
|
|
|
+ self.alphas = np.zeros(1)
|
|
|
+ self.betas = np.zeros(1)
|
|
|
+
|
|
|
+ self.epochs = None
|
|
|
+
|
|
|
+ @property
|
|
|
+ def alpha(self):
|
|
|
+ return torch.tanh(self.alpha_tilda)
|
|
|
+
|
|
|
+ @property
|
|
|
+ def beta(self):
|
|
|
+ return torch.tanh(self.beta_tilda)
|
|
|
+
|
|
|
+ def train(self, epochs, lr):
|
|
|
+ optimizer = torch.optim.Adam(list(self.pinn.parameters()) + [self.alpha_tilda, self.beta_tilda], lr=lr)
|
|
|
+ scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
|
|
|
+
|
|
|
+ # arrays for plotting
|
|
|
+ self.losses = np.zeros(epochs)
|
|
|
+ self.alphas = np.zeros(epochs)
|
|
|
+ self.betas = np.zeros(epochs)
|
|
|
+
|
|
|
+ self.epochs = epochs
|
|
|
+ for epoch in range(epochs):
|
|
|
+ optimizer.zero_grad()
|
|
|
+
|
|
|
+ SIR_pred = self.pinn(self.data.t)
|
|
|
+ S_pred, I_pred, R_pred = SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]
|
|
|
+
|
|
|
+ dSdt = torch.autograd.grad(S_pred, self.data.t, torch.zeros_like(S_pred), create_graph=True)[0]
|
|
|
+ dIdt = torch.autograd.grad(I_pred, self.data.t, torch.zeros_like(I_pred), create_graph=True)[0]
|
|
|
+ dRdt = torch.autograd.grad(R_pred, self.data.t, torch.zeros_like(R_pred), create_graph=True)[0]
|
|
|
+
|
|
|
+ S = self.data.S_min + (self.data.S_max - self.data.S_min) * S_pred
|
|
|
+ I = self.data.I_min + (self.data.I_max - self.data.I_min) * I_pred
|
|
|
+ R = self.data.R_min + (self.data.R_max - self.data.R_min) * R_pred
|
|
|
+
|
|
|
+ S_residual = dSdt - (-self.beta * ((S * I) / self.data.N)) / (self.data.S_max - self.data.S_min)
|
|
|
+ I_residual = dIdt - (self.beta * ((S * I) / self.data.N) - self.alpha * I) / (self.data.I_max - self.data.I_min)
|
|
|
+ R_residual = dRdt - (self.alpha * I) / (self.data.R_max - self.data.R_min)
|
|
|
+
|
|
|
+ loss_physics = (torch.mean(torch.square(S_residual)) +
|
|
|
+ torch.mean(torch.square(I_residual)) +
|
|
|
+ torch.mean(torch.square(R_residual)))
|
|
|
+ loss_obs = (torch.mean(torch.square(self.data.S_norm - S_pred)) +
|
|
|
+ torch.mean(torch.square(self.data.I_norm - I_pred)) +
|
|
|
+ torch.mean(torch.square(self.data.R_norm - R_pred)))
|
|
|
+ loss = loss_obs + loss_physics
|
|
|
+
|
|
|
+ loss.backward()
|
|
|
+ optimizer.step()
|
|
|
+ scheduler.step()
|
|
|
+
|
|
|
+ self.losses[epoch] = loss.item()
|
|
|
+ self.alphas[epoch] = self.alpha.item()
|
|
|
+ self.betas[epoch] = self.beta.item()
|
|
|
+
|
|
|
+ if epoch % 1000 == 0:
|
|
|
+ print('\nEpoch ', epoch)
|
|
|
+
|
|
|
+ print(f'alpha:\t\t\tgoal|trained 0.333|{self.alpha.item()}')
|
|
|
+ print(f'beta:\t\t\tgoal|trained 0.5|{self.beta.item()}')
|
|
|
+ print('---------------------------------')
|
|
|
+ print(f'physics loss:\t\t{loss_physics.item()}')
|
|
|
+ print(f'observation loss:\t{loss_obs.item()}')
|
|
|
+ print(f'loss:\t\t\t{loss.item()}')
|
|
|
+
|
|
|
+ print('#################################')
|
|
|
+
|
|
|
+ def plot(self):
|
|
|
+ assert self.epochs != None
|
|
|
+ epochs = np.arange(0, self.epochs, 1)
|
|
|
+
|
|
|
+ plt.plot(epochs, self.alphas, label='parameter alpha')
|
|
|
+ plt.plot(epochs, np.ones(self.epochs) * 0.333, label='true alpha')
|
|
|
+ plt.legend()
|
|
|
+ plt.show()
|
|
|
+ plt.plot(epochs, self.betas, label='parameter bate')
|
|
|
+ plt.plot(epochs, np.ones(self.epochs) * 0.5, label='true beta')
|
|
|
+ plt.legend()
|
|
|
+ plt.show()
|
|
|
+
|
|
|
+ plt.plot(epochs, self.losses)
|
|
|
+ plt.title('Loss')
|
|
|
+ plt.yscale('log')
|
|
|
+ plt.show()
|
|
|
+
|
|
|
+ def predict(self, t):
|
|
|
+ t = torch.tensor(t, requires_grad=True).view(-1, 1).float()
|
|
|
+ return self.pinn(t)
|