@@ -1,153 +0,0 @@
-import torch
-import numpy as np
-import matplotlib.pyplot as plt
-
-from .dataset.dataset import SIR_Dataset
-
-class DINN:
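-    # fully connected network that maps time t to the normalized compartments (S, I, R)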
-    class PINN(torch.nn.Module):
-        def __init__(self) -> None:
-            super().__init__()
-
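-            # MLP: 1 input (time t) -> 7 hidden layers of width 20 -> 3 outputs (S, I, R)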
-            self.input = torch.nn.Sequential(torch.nn.Linear(1, 20), torch.nn.ReLU())
-            self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(20, 20), torch.nn.ReLU()) for _ in range(7)])
-            self.output = torch.nn.Linear(20, 3)
-
-        def forward(self, t):
-            x = self.input(t)
-            x = self.hidden(x)
-            x = self.output(x)
-            return x
-
-    def __init__(self, data: SIR_Dataset):
-        self.data = data
-        self.pinn = DINN.PINN()
-
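-        # unconstrained parameters; the alpha/beta properties below map them through tanh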
-        self.alpha_tilda = torch.nn.Parameter(torch.rand(1))
-        self.beta_tilda = torch.nn.Parameter(torch.rand(1))
-
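-        # history buffers for plotting, resized to the epoch count in train()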
-        self.losses = np.zeros(1)
-        self.alphas = np.zeros(1)
-        self.betas = np.zeros(1)
-
-        self.epochs = None
-
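-        # one-hot weight vectors for backward(): each selects a single output
-        # column so that t_raw.grad holds that compartment's time derivative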
-        self.S_grads = torch.zeros((len(self.data.t_raw), 3)); self.S_grads[:, 0] = 1
-        self.I_grads = torch.zeros((len(self.data.t_raw), 3)); self.I_grads[:, 1] = 1
-        self.R_grads = torch.zeros((len(self.data.t_raw), 3)); self.R_grads[:, 2] = 1
-
-    @property
-    def alpha(self):
-        return torch.tanh(self.alpha_tilda)
-
-    @property
-    def beta(self):
-        return torch.tanh(self.beta_tilda)
-
-    def train(self, epochs, lr):
-        optimizer = torch.optim.Adam(list(self.pinn.parameters()) + [self.alpha_tilda, self.beta_tilda], lr=lr)
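-        # note: CyclicLR drives the learning rate between base_lr and max_lr,
-        # so the lr argument passed to Adam is effectively overridden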
-        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
-
-        # arrays for plotting
-        self.losses = np.zeros(epochs)
-        self.alphas = np.zeros(epochs)
-        self.betas = np.zeros(epochs)
-
-        self.epochs = epochs
-        for epoch in range(epochs):
-
-            # clear any gradient left on t_raw by the previous epoch's loss.backward(),
-            # otherwise it would contaminate the derivative estimates below
-            if self.data.t_raw.grad is not None:
-                self.data.t_raw.grad.zero_()
-
-            SIR_pred = self.pinn(self.data.t_batch)
-            S_pred, I_pred, R_pred = SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]
-
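-            # backward passes with one-hot weights extract dS/dt, dI/dt, dR/dt;
-            # t_batch is built from the leaf tensor t_raw, so they land in t_raw.grad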
-            SIR_pred.backward(self.S_grads, retain_graph=True)
-            dSdt = self.data.t_raw.grad.clone()
-            self.data.t_raw.grad.zero_()
-
-            SIR_pred.backward(self.I_grads, retain_graph=True)
-            dIdt = self.data.t_raw.grad.clone()
-            self.data.t_raw.grad.zero_()
-
-            SIR_pred.backward(self.R_grads, retain_graph=True)
-            dRdt = self.data.t_raw.grad.clone()
-            self.data.t_raw.grad.zero_()
-
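-            # undo the min-max normalization to recover absolute compartment sizes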
-            S = self.data.S_min + (self.data.S_max - self.data.S_min) * S_pred
-            I = self.data.I_min + (self.data.I_max - self.data.I_min) * I_pred
-            R = self.data.R_min + (self.data.R_max - self.data.R_min) * R_pred
-
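-            # SIR model residuals, written in the normalized output scale:
-            # dS/dt = -beta*S*I/N,  dI/dt = beta*S*I/N - alpha*I,  dR/dt = alpha*I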
-            S_residual = dSdt - (-self.beta * ((S * I) / self.data.N)) / (self.data.S_max - self.data.S_min)
-            I_residual = dIdt - (self.beta * ((S * I) / self.data.N) - self.alpha * I) / (self.data.I_max - self.data.I_min)
-            R_residual = dRdt - (self.alpha * I) / (self.data.R_max - self.data.R_min)
-
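-            # zero_grad also discards the parameter gradients accumulated by the
-            # three derivative-extraction backward passes above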
-            optimizer.zero_grad()
-
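-            # physics loss penalizes the ODE residuals; observation loss fits the normalized data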
-            loss_physics = (torch.mean(torch.square(S_residual)) +
-                            torch.mean(torch.square(I_residual)) +
-                            torch.mean(torch.square(R_residual)))
-            loss_obs = (torch.mean(torch.square(self.data.S_norm - S_pred)) +
-                        torch.mean(torch.square(self.data.I_norm - I_pred)) +
-                        torch.mean(torch.square(self.data.R_norm - R_pred)))
-            loss = loss_obs + loss_physics
-
-            loss.backward()
-            optimizer.step()
-            scheduler.step()
-
-            self.losses[epoch] = loss.item()
-            self.alphas[epoch] = self.alpha.item()
-            self.betas[epoch] = self.beta.item()
-
-            if epoch % 1000 == 0:
-                print('\nEpoch ', epoch)
-
-                print(f'alpha:\t\t\tgoal|trained 0.333|{self.alpha.item()}')
-                print(f'beta:\t\t\tgoal|trained 0.5|{self.beta.item()}')
-                print('---------------------------------')
-                print(f'physics loss:\t\t{loss_physics.item()}')
-                print(f'observation loss:\t{loss_obs.item()}')
-                print(f'loss:\t\t\t{loss.item()}')
-
-                print('#################################')
-
-    def plot(self):
-        assert self.epochs is not None, 'train() must be called before plot()'
-        epochs = np.arange(self.epochs)
-
-        plt.plot(epochs, self.alphas, label='parameter alpha')
-        plt.plot(epochs, np.ones(self.epochs) * 0.333, label='true alpha')
-        plt.legend()
-        plt.show()
-        plt.plot(epochs, self.betas, label='parameter beta')
-        plt.plot(epochs, np.ones(self.epochs) * 0.5, label='true beta')
-        plt.legend()
-        plt.show()
-
-        plt.plot(epochs, self.losses)
-        plt.title('Loss')
-        plt.yscale('log')
-        plt.show()
-
-    def predict(self, t):
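-        # note: the network returns min-max normalized (S, I, R) values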
-        t = torch.tensor(t, requires_grad=True).view(-1, 1).float()
-        return self.pinn(t)