
working sir dinn code

phillip.rothenbeck 1 year ago
parent
commit
2ab14ef2e2
3 changed files with 664 additions and 0 deletions
  1. src/sir_dinn/dataset/dataset.py   +23 -0
  2. src/sir_dinn/dinn.py   +119 -0
  3. synth_sir_dinn.ipynb   +522 -0

+ 23 - 0
src/sir_dinn/dataset/dataset.py

@@ -0,0 +1,23 @@
+import torch
+
+class SIR_Dataset:
+    def __init__(self, N, t, S, I, R):
+        self.N = N
+        self.t = torch.tensor(t, requires_grad=True).view(-1, 1).float()
+        print(torch.min(self.t), torch.max(self.t))
+
+        self.S = torch.tensor(S)
+        self.I = torch.tensor(I)
+        self.R = torch.tensor(R)
+
+        self.S_min = torch.min(self.S)
+        self.I_min = torch.min(self.I)
+        self.R_min = torch.min(self.R)
+
+        self.S_max = torch.max(self.S)
+        self.I_max = torch.max(self.I)
+        self.R_max = torch.max(self.R)
+
+        self.S_norm = (self.S - self.S_min) / (self.S_max - self.S_min)
+        self.I_norm = (self.I - self.I_min) / (self.I_max - self.I_min)
+        self.R_norm = (self.R - self.R_min) / (self.R_max - self.R_min)

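The new dataset module wraps raw time and compartment arrays, converts them to tensors, and min-max normalizes S, I and R to [0, 1]. A minimal sketch of how it might be fed with synthetic data (not part of this commit; the import path, population size, time grid and forward-Euler integration are assumptions, while alpha = 1/3 and beta = 0.5 mirror the "goal" values printed during training in dinn.py):

# Hypothetical sketch (not in the diff): generate synthetic SIR data with a
# simple forward-Euler step and wrap it in SIR_Dataset.
import numpy as np
from src.sir_dinn.dataset.dataset import SIR_Dataset  # import path assumed

N = 1_000                      # total population (assumption)
alpha, beta = 1 / 3, 0.5       # recovery and infection rates (goal values)
t = np.linspace(0.0, 100.0, 500)
dt = t[1] - t[0]

S = np.empty_like(t); I = np.empty_like(t); R = np.empty_like(t)
S[0], I[0], R[0] = N - 1.0, 1.0, 0.0
for k in range(len(t) - 1):
    new_infections = beta * S[k] * I[k] / N
    new_recoveries = alpha * I[k]
    S[k + 1] = S[k] - dt * new_infections
    I[k + 1] = I[k] + dt * (new_infections - new_recoveries)
    R[k + 1] = R[k] + dt * new_recoveries

data = SIR_Dataset(N, t, S, I, R)   # normalizes each compartment to [0, 1]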
+ 119 - 0
src/sir_dinn/dinn.py

@@ -0,0 +1,119 @@
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+
+from .dataset.dataset import SIR_Dataset
+
+class DINN:
+    class PINN(torch.nn.Module):
+        def __init__(self) -> None:
+            super(DINN.PINN, self).__init__()
+
+            self.input = torch.nn.Sequential(torch.nn.Linear(1, 20), torch.nn.ReLU())
+            self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(20, 20), torch.nn.ReLU()) for _ in range(7)])
+            self.output = torch.nn.Linear(20, 3)
+
+        def forward(self, t):
+            x = self.input(t)
+            x = self.hidden(x)
+            x = self.output(x)
+            return x
+        
+    def __init__(self, data: SIR_Dataset):
+        self.data = data
+        self.pinn = DINN.PINN()
+
+        self.alpha_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))
+        self.beta_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))
+
+        self.losses = np.zeros(1)
+        self.alphas = np.zeros(1)
+        self.betas = np.zeros(1)
+
+        self.epochs = None
+
+    @property
+    def alpha(self):
+        return torch.tanh(self.alpha_tilda)
+    
+    @property
+    def beta(self):
+        return torch.tanh(self.beta_tilda)
+    
+    def train(self, epochs, lr):
+        optimizer = torch.optim.Adam(list(self.pinn.parameters()) + [self.alpha_tilda, self.beta_tilda], lr=lr)
+        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
+
+        # arrays for plotting
+        self.losses = np.zeros(epochs)
+        self.alphas = np.zeros(epochs)
+        self.betas = np.zeros(epochs)
+
+        self.epochs = epochs
+        for epoch in range(epochs):
+            optimizer.zero_grad()
+
+            SIR_pred = self.pinn(self.data.t)
+            S_pred, I_pred, R_pred = SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]
+            
+            dSdt = torch.autograd.grad(S_pred, self.data.t, torch.ones_like(S_pred), create_graph=True)[0]
+            dIdt = torch.autograd.grad(I_pred, self.data.t, torch.ones_like(I_pred), create_graph=True)[0]
+            dRdt = torch.autograd.grad(R_pred, self.data.t, torch.ones_like(R_pred), create_graph=True)[0]
+            
+            S = self.data.S_min + (self.data.S_max - self.data.S_min) * S_pred
+            I = self.data.I_min + (self.data.I_max - self.data.I_min) * I_pred
+            R = self.data.R_min + (self.data.R_max - self.data.R_min) * R_pred
+            
+            S_residual = dSdt - (-self.beta * ((S * I) / self.data.N)) / (self.data.S_max - self.data.S_min)
+            I_residual = dIdt - (self.beta * ((S * I) / self.data.N) - self.alpha * I) / (self.data.I_max - self.data.I_min)
+            R_residual = dRdt - (self.alpha * I) / (self.data.R_max - self.data.R_min)
+
+            loss_physics = (torch.mean(torch.square(S_residual)) + 
+                            torch.mean(torch.square(I_residual)) +
+                            torch.mean(torch.square(R_residual)))
+            loss_obs = (torch.mean(torch.square(self.data.S_norm - S_pred)) + 
+                        torch.mean(torch.square(self.data.I_norm - I_pred)) + 
+                        torch.mean(torch.square(self.data.R_norm - R_pred)))
+            loss = loss_obs + loss_physics
+
+            loss.backward()
+            optimizer.step()
+            scheduler.step()
+
+            self.losses[epoch] = loss.item()
+            self.alphas[epoch] = self.alpha.item()
+            self.betas[epoch] = self.beta.item()
+
+            if epoch % 1000 == 0:          
+                print('\nEpoch ', epoch)
+
+                print(f'alpha:\t\t\tgoal|trained 0.333|{self.alpha.item()}')
+                print(f'beta:\t\t\tgoal|trained 0.5|{self.beta.item()}')
+                print('---------------------------------')
+                print(f'physics loss:\t\t{loss_physics.item()}')
+                print(f'observation loss:\t{loss_obs.item()}')
+                print(f'loss:\t\t\t{loss.item()}')
+
+                print('#################################') 
+
+    def plot(self):
+        assert self.epochs is not None
+        epochs = np.arange(0, self.epochs, 1)
+        
+        plt.plot(epochs, self.alphas, label='parameter alpha')
+        plt.plot(epochs, np.ones(self.epochs) * 0.333, label='true alpha')
+        plt.legend()
+        plt.show()
+        plt.plot(epochs, self.betas, label='parameter beta')
+        plt.plot(epochs, np.ones(self.epochs) * 0.5, label='true beta')
+        plt.legend()
+        plt.show()
+
+        plt.plot(epochs, self.losses)
+        plt.title('Loss')
+        plt.yscale('log')
+        plt.show()
+
+    def predict(self, t):
+        t = torch.tensor(t, requires_grad=True).view(-1, 1).float()
+        return self.pinn(t)

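dinn.py adds the disease-informed neural network itself: a small fully connected network (1 -> 20 -> ... -> 3, ReLU activations) maps time to the normalized compartments (S, I, R), and the physics residuals enforce the SIR equations dS/dt = -beta*S*I/N, dI/dt = beta*S*I/N - alpha*I, dR/dt = alpha*I, each right-hand side divided by the corresponding (max - min) range because the derivatives are taken with respect to the normalized outputs. The suppressed notebook synth_sir_dinn.ipynb presumably drives the training; a hedged end-to-end sketch under that assumption (epoch count and learning rate are illustrative):

# Hypothetical usage sketch (not shown in this commit). `data` is the
# SIR_Dataset from the sketch above; epochs and lr are assumptions.
from src.sir_dinn.dinn import DINN   # import path assumed

dinn = DINN(data)
dinn.train(epochs=10_000, lr=1e-3)   # Adam + CyclicLR schedule set inside train()
dinn.plot()                          # alpha/beta trajectories and loss curve

SIR_pred = dinn.predict(t)           # normalized S, I, R on the training grid
print('alpha:', dinn.alpha.item(), 'beta:', dinn.beta.item())

Because alpha and beta are passed through tanh, the recovered rates are constrained to (-1, 1), which comfortably contains the target values 0.333 and 0.5 printed in the training loop.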
File diff suppressed because it is too large
+ 522 - 0
synth_sir_dinn.ipynb


Some files were not shown because too many files changed in this diff