
adding generalized dinn class

phillip.rothenbeck 1 year ago
parent
commit
3ef34d9e7e
1 changed file with 197 additions and 0 deletions

+ 197 - 0
src/dinn.py

@@ -0,0 +1,197 @@
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+
+from .dataset import PandemicDataset
+from .problem import PandemicProblem
+
+class DINN:
+    class NN(torch.nn.Module):
+        def __init__(self, 
+                     output_size: int,
+                     input_size: int,
+                     hidden_size: int,
+                     hidden_layers: int, 
+                     activation_layer) -> None:
+            """Neural Network
+
+            Args:
+                output_size (int): number of outputs
+                input_size (int): number of inputs
+                hidden_size (int): number of hidden nodes per layer
+                hidden_layers (int): number of hidden layers
+                activation_layer (torch.nn.Module): activation function module
+            """
+            super(DINN.NN, self).__init__()
+
+            self.input = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size), activation_layer)
+            self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
+            self.output = torch.nn.Linear(hidden_size, output_size)
+
+        def forward(self, t):
+            x = self.input(t)
+            x = self.hidden(x)
+            x = self.output(x)
+            return x
+
+    def __init__(self,
+                 number_groups: int,
+                 data: PandemicDataset,
+                 parameter_list: list,
+                 problem: PandemicProblem,
+                 parameter_regulator=torch.tanh,
+                 input_size=1,
+                 hidden_size=20,
+                 hidden_layers=7, 
+                 activation_layer=torch.nn.ReLU()) -> None:
+        """Desease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve Inverse Problems and find the 
+        parameters of a specific mathematical model.
+
+        Args:
+            number_groups (int): The number of groups, that the population is split into.
+            data (PandemicDataset): Data collected showing the course of the pandemic
+            parameter_list (list): List of the parameter names(strings), that are supposed to be found.
+            problem (PandemicProblem): Problem class implementing the calculation of the residuals.
+            parameter_regulator (optional): Function to force the parameters to be in a certain range. Defaults to torch.tanh.
+            input_size (int, optional): Number of the input nodes of the NN. Defaults to 1.
+            hidden_size (int, optional): Number of the hidden nodes of the NN. Defaults to 20.
+            hidden_layers (int, optional): Number of the hidden layers for the NN. Defaults to 7.
+            activation_layer (optional): Class of the activation function. Defaults to torch.nn.ReLU().
+        """
+        
+        self.device = torch.device('cpu')
+
+        self.model = DINN.NN(number_groups, input_size, hidden_size, hidden_layers, activation_layer)
+        self.data = data
+        self.parameter_regulator = parameter_regulator
+        self.problem = problem
+
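+        # raw ('tilda') parameters; parameter_regulator maps them into the admissible range when they are read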
+        self.parameters_tilda = {}
+        for parameter in parameter_list:
+            self.parameters_tilda.update({parameter : torch.nn.Parameter(torch.rand(1, requires_grad=True))})
+        
+        self.epochs = None
+
+        self.losses = np.zeros(1)
+        self.parameters = [np.zeros(1) for _ in range(len(parameter_list))]
+
+    def get_regulated_param(self, parameter_name: str):
+        """Function to get the searched parameters, forced into a certain range.
+
+        Args:
+            parameter_name (str): Name of the parameter to be returned.
+
+        Returns:
+            torch.Tensor: Regulated value of the searched parameter.
+        """
+        return self.parameter_regulator(self.parameters_tilda[parameter_name])
+    
+    def get_parameters_tilda(self):
+        """Function to get the original value (not forced into any range).
+
+        Returns:
+            list: List of the raw parameter objects of the searched parameters.
+        """
+        return list(self.parameters_tilda.values())
+
+    def get_regulated_param_list(self):
+        """Get the list of regulated parameters (forced into a specific range).
+
+        Returns:
+            list: list of regulated parameters
+        """
+        return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]
+    
+    def train(self, 
+              epochs: int, 
+              lr: float, 
+              optimizer_class=torch.optim.Adam):
+        """Training routine for the DINN
+
+        Args:
+            epochs (int): Number of epochs the NN is supposed to be trained for.
+            lr (float): Learning rate for the optimizer.
+            optimizer_class (optional): Class of the optimizer. Defaults to torch.optim.Adam.
+        """
+
+        # define optimizer and scheduler
+        optimizer = optimizer_class(list(self.model.parameters()) + list(self.parameters_tilda.values()), lr=lr)
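+        # the cyclic schedule below varies the learning rate between fixed bounds (1e-5 to 1e-3), independent of the lr argument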
+        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
+
+        self.epochs = epochs
+
+        # arrays to hold values for plotting
+        self.losses = np.zeros(epochs)
+        self.parameters = [np.zeros(epochs) for _ in self.parameters]
+
+        for epoch in range(epochs):
+            # get the prediction and the fitting residuals
+            prediction = self.model(self.data.t_batch)
+            residuals = self.problem.residual(prediction, self.data, *self.get_regulated_param_list())
+
+            optimizer.zero_grad()
+
+            # calculate loss from the differential system
+            loss_physics = 0
+            for residual in residuals:
+                loss_physics += torch.mean(torch.square(residual))
+
+            # calculate loss from the dataset
+            loss_obs = 0
+            for i, group in enumerate(self.data.get_group_names()):
+                loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
+            
+            loss = loss_physics + loss_obs
+
+            loss.backward()
+            optimizer.step()
+            scheduler.step()
+
+            # append values for plotting
+            self.losses[epoch] = loss.item()
+            for i, parameter in enumerate(self.parameters_tilda.items()):
+                self.parameters[i][epoch] = parameter[1].item()
+
+            # print training advancements
+            if epoch % 1000 == 0:          
+                print('\nEpoch ', epoch)
+                print(f'physics loss:\t\t{loss_physics.item()}')
+                print(f'observation loss:\t{loss_obs.item()}')
+                print(f'loss:\t\t\t{loss.item()}')
+                print('---------------------------------')
+                for parameter in self.parameters_tilda.items():
+                    print(f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
+                print('#################################') 
+
+    def plot_training_graphs(self):
+        """Plot the loss graph and the graphs of the advancements of the parameters.
+        """
+        assert self.epochs != None
+        epochs = np.arange(0, self.epochs, 1)
+
+        # plot loss
+        plt.plot(epochs, self.losses)
+        plt.title('Loss')
+        plt.yscale('log')
+        plt.show()
+
+        # plot parameters
+        for i, parameter in enumerate(self.parameters):
+            plt.plot(epochs, parameter)
+            plt.title(list(self.parameters_tilda.items())[i][0])
+            plt.show()
+
+    def to_cuda(self):
+        """Move training to cuda device
+        """
+        assert torch.cuda.is_available(), "CUDA is not available"
+        
+        self.device = torch.device('cuda')
+        self.model = self.model.to(self.device)
+        self.data.to_device(self.device)
+        self.problem.to_device(self.device)
+
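+        # .to() returns new tensors, so detach and re-enable grad to obtain fresh leaf tensors on the GPU;
+        # call to_cuda() before train(), which hands these tensors to the optimizer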
+        for parameter in self.parameters_tilda:
+            self.parameters_tilda[parameter] = self.parameters_tilda[parameter].to(self.device).detach().requires_grad_()
+        
+
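
A rough usage sketch of the new class, assuming a PandemicDataset and a PandemicProblem subclass that are built elsewhere (their constructors are not part of this file); the parameter names 'alpha' and 'beta' and all values below are illustrative only:

    import torch
    from src.dinn import DINN
    from src.dataset import PandemicDataset
    from src.problem import PandemicProblem

    data: PandemicDataset = ...       # dataset with the recorded course of the pandemic
    problem: PandemicProblem = ...    # subclass implementing residual() for the chosen model

    # three groups (e.g. S, I, R) and two unknown model parameters (names are illustrative)
    dinn = DINN(number_groups=3,
                data=data,
                parameter_list=['alpha', 'beta'],
                problem=problem)

    if torch.cuda.is_available():
        dinn.to_cuda()                # optional: move model, data and parameters to the GPU

    dinn.train(epochs=10000, lr=1e-3)
    dinn.plot_training_graphs()
    print(dinn.get_regulated_param('alpha'))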