
add optimizer/scheduler configuration possibility + changes for modified SIR

phillip.rothenbeck 1 year ago
parent
commit
0ab0e0ce57
2 changed files with 184 additions and 75 deletions
  1. src/dinn.py (+131, -54)
  2. src/problem.py (+53, -21)

src/dinn.py (+131, -54)

@@ -1,16 +1,10 @@
 import torch
-import os
-import imageio
 import numpy as np
-import matplotlib.pyplot as plt
-from matplotlib import rcParams
 
 from .dataset import PandemicDataset
 from .problem import PandemicProblem
 from .plotter import Plotter
 
-
-
 class DINN:
     class NN(torch.nn.Module):
         def __init__(self, 
@@ -48,55 +42,68 @@ class DINN:
             return x
 
     def __init__(self,
-                 number_groups: int,
+                 output_size: int,
                  data: PandemicDataset,
                  parameter_list: list,
                  problem: PandemicProblem,
                  plotter: Plotter,
+                 state_variables=[],
                  parameter_regulator=torch.tanh,
-                 input_size=1,
-                 hidden_size=20,
+                 input_size=1, 
+                 hidden_size=20, 
                  hidden_layers=7, 
                  activation_layer=torch.nn.ReLU()) -> None:
         """Desease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve Inverse Problems and find the 
         parameters of a specific mathematical model.
 
         Args:
-            number_groups (int): The number of groups, that the population is split into.
+            output_size (int): Number of the output nodes of the NN.
             data (PandemicDataset): Data collected showing the course of the pandemic
             parameter_list (list): List of the parameter names(strings), that are supposed to be found.
             problem (PandemicProblem): Problem class implementing the calculation of the residuals.
             plotter (Plotter): Plotter object to plot dataset curves.
+            state_variables (list, optional): List of the names of state variables. Defaults to [].
             parameter_regulator (optional): Function to force the parameters to be in a certain range. Defaults to torch.tanh.
             input_size (int, optional): Number of the input nodes of the NN. Defaults to 1.
             hidden_size (int, optional): Number of the hidden nodes of the NN. Defaults to 20.
             hidden_layers (int, optional): Number of the hidden layers for the NN. Defaults to 7.
             activation_layer (optional): Class of the activation function. Defaults to torch.nn.ReLU().
         """
-        
+        assert len(state_variables) + data.number_groups == output_size, f'The number of groups plus the number of state variables must equal the output size\nGroups:\t{data.number_groups}\nState variables:\t{len(state_variables)}\nOutput size:\t{output_size}\n'
         self.device = torch.device(data.device_name)
         self.device_name = data.device_name
         self.plotter = plotter
 
-        self.model = DINN.NN(number_groups, input_size, hidden_size, hidden_layers, activation_layer, data.t_init, data.t_final)
+        self.model = DINN.NN(output_size, 
+                             input_size, 
+                             hidden_size, 
+                             hidden_layers, 
+                             activation_layer, 
+                             data.t_init, 
+                             data.t_final)
         self.model = self.model.to(self.device)
         self.data = data
         self.parameter_regulator = parameter_regulator
         self.problem = problem
+        self.problem.def_grad_matrix(output_size)
 
         self.parameters_tilda = {}
         for parameter in parameter_list:
             self.parameters_tilda.update({parameter : torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
-        
-        self.epochs = None
 
-        self.losses = np.zeros(1)
-        self.obs_losses = np.zeros(1)
-        self.physics_losses = np.zeros(1)
+        # new model has to be configured and then trained
+        self.__is_configured = False
+        self.__has_trained = False
+
+        self.__state_variables = state_variables
+
         self.parameters = [np.zeros(1) for _ in range(len(parameter_list))]
 
         self.frames = []
 
+    @property
+    def number_state_variables(self):
+        return len(self.__state_variables)
 
     def get_regulated_param(self, parameter_name: str):
         """Function to get the searched parameters, forced into a certain range.
@@ -125,40 +132,77 @@ class DINN:
         """
         return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]
     
+    def configure_training(self, lr: float, epochs: int, optimizer_name='Adam', scheduler_name='CyclicLR', scheduler_factor=1, verbose=False):
+        """This method sets the optimizer, scheduler, learning rate and number of epochs for the following training process.
+
+        Args:
+            lr (float): Learning rate for the optimizer.
+            epochs (int): Number of epochs the NN is supposed to be trained for.
+            optimizer_name (str, optional): Name of the optimizer class that is supposed to be used. Defaults to 'Adam'.
+            scheduler_name (str, optional): Name of the scheduler class that is supposed to be used. Defaults to 'CyclicLR'.
+            scheduler_factor (int, optional): Factor by which the number of epochs is divided to obtain the scheduler's total_iters (used by LinearLR and PolynomialLR). Defaults to 1.
+            verbose (bool, optional): Controls whether the configuration process is logged. Defaults to False.
+        """
+        parameter_list = list(self.model.parameters()) + list(self.parameters_tilda.values())
+        self.epochs = epochs
+        match optimizer_name:
+            case 'Adam':
+                self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
+            case _:
+                self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
+                if verbose:
+                    print('---------------------------------')
+                    print(f' Entered unknown optimizer name: {optimizer_name}\n Defaulted to Adam.')
+                    print('---------------------------------')
+                optimizer_name = 'Adam'
+
+        match scheduler_name:
+            case 'CyclicLR':
+                self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
+            case 'LinearLR':
+                self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs/scheduler_factor)
+            case 'PolynomialLR':
+                self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs/scheduler_factor, power=1.0)
+            case _:
+                self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
+                if verbose:
+                    print('---------------------------------')
+                    print(f' Entered unknown scheduler name: {scheduler_name}\n Defaulted to CyclicLR.')
+                    print('---------------------------------')
+                scheduler_name = 'CyclicLR'
+
+        if verbose:
+            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_name}\nScheduler:\t{scheduler_name}\n')
+
+        self.__is_configured = True
+
+    
     def train(self, 
-              epochs: int, 
-              lr: float, 
-              optimizer_class=torch.optim.Adam,
               create_animation=False,
-              animation_sample_rate=500):
-        """Training routine for the DINN
+              animation_sample_rate=500,
+              verbose=False):
+        """Training routine for the DINN.
 
         Args:
-            epochs (int): Number of epochs the NN is supposed to be trained for.
-            lr (float): Learning rate for the optimizer.
-            optimizer_class (optional): Class of the optimizer. Defaults to torch.optim.Adam.
             create_animation (boolean, optional): Decides whether a prediction animation is created during training. Defaults to False.
             animation_sample_rate (int, optional): Sample rate of the prediction animation. Only used, when create_animation=True. Defaults to 500.
+            verbose (bool, optional): Controls whether training progress is logged. Defaults to False.
         """
-
-        # define optimizer and scheduler
-        optimizer = optimizer_class(list(self.model.parameters()) + list(self.parameters_tilda.values()), lr=lr)
-        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
-
-        self.epochs = epochs
-
+        assert self.__is_configured, 'The model has to be configured via self.configure_training() before training.'
+        if verbose:
+            print(f'torch seed: {torch.seed()}')
+        
         # arrays to hold values for plotting
-        self.losses = np.zeros(epochs)
-        self.obs_losses = np.zeros(epochs)
-        self.physics_losses = np.zeros(epochs)
-        self.parameters = [np.zeros(epochs) for _ in self.parameters]
+        self.losses = np.zeros(self.epochs)
+        self.obs_losses = np.zeros(self.epochs)
+        self.physics_losses = np.zeros(self.epochs)
+        self.parameters = [np.zeros(self.epochs) for _ in self.parameters]
 
-        for epoch in range(epochs):
+        for epoch in range(self.epochs):
             # get the prediction and the fitting residuals
             prediction = self.model(self.data.t_batch)
             residuals = self.problem.residual(prediction, *self.get_regulated_param_list())
 
-            optimizer.zero_grad()
+            self.optimizer.zero_grad()
 
             # calculate loss from the differential system
             loss_physics = 0
@@ -170,11 +214,11 @@ class DINN:
             for i, group in enumerate(self.data.group_names):
                 loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
             
-            loss = loss_physics + loss_obs
+            loss = loss_obs + loss_physics
 
             loss.backward()
-            optimizer.step()
-            scheduler.step()
+            self.optimizer.step()
+            self.scheduler.step()
 
             # append values for plotting
             self.losses[epoch] = loss.item()
@@ -188,15 +232,18 @@ class DINN:
                 # prediction
                 prediction = self.model(self.data.t_batch)
                 t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
-                groups = self.data.get_denormalized_data([prediction[:, 0], prediction[:, 1], prediction[:, 2]])
+                groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
+
+                plot_labels = [name + '_pred' for name in self.data.group_names] + [name + '_true' for name in self.data.group_names]
+                background_list = [0 for _ in self.data.group_names] + [1 for _ in self.data.group_names]
                 self.plotter.plot(t, 
                                   list(groups) + list(self.data.data), 
-                                  [name + '_pred' for name in self.data.group_names] + [name + '_true' for name in self.data.group_names],
+                                  plot_labels,
                                   'frame',
                                   f'epoch {epoch}',
                                   figure_shape=(12, 6),
                                   is_frame=True,
-                                  is_background=[0, 0, 0, 1, 1, 1],
+                                  is_background=background_list,
                                   lw=3,
                                   legend_loc='upper right',
                                   ylim=(0, self.data.N), 
@@ -204,19 +251,23 @@ class DINN:
                                   ylabel='amount of people')
 
             # print training advancements
-            if epoch % 1000 == 0:          
-                print('\nEpoch ', epoch)
+            if epoch % 1000 == 0 and verbose:          
+                print(f'\nEpoch {epoch} | LR {self.scheduler.get_last_lr()[0]}')
                 print(f'physics loss:\t\t{loss_physics.item()}')
                 print(f'observation loss:\t{loss_obs.item()}')
                 print(f'loss:\t\t\t{loss.item()}')
                 print('---------------------------------')
-                for parameter in self.parameters_tilda.items():
-                    print(f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
-                print('#################################')
+                if len(self.parameters_tilda.items()) != 0:
+                    for parameter in self.parameters_tilda.items():
+                        print(f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
+                    print('#################################')
 
         # create prediction animation
         if create_animation:
             self.plotter.animate(self.data.name + '_animation')
+            self.plotter.reset_animation()
+
+        self.__has_trained = True
 
     def plot_training_graphs(self, ground_truth=[]):
         """Plot the loss graph and the graphs of the advancements of the parameters.
@@ -224,7 +275,7 @@ class DINN:
         Args:
             ground_truth (list): List of the ground truth parameters
         """
-        assert self.epochs != None
+        assert self.__has_trained, 'The model has to be trained before plotting the training graphs.'
         epochs = np.arange(0, self.epochs, 1)
 
         # plot loss
@@ -233,8 +284,34 @@ class DINN:
         # plot parameters
         for i, parameter in enumerate(self.parameters):
             if len(ground_truth) > i:
-                self.plotter.plot(epochs, [parameter, np.ones_like(epochs) * ground_truth[i]], ['prediction', 'ground truth'], self.data.name + '_' + list(self.parameters_tilda.items())[i][0], list(self.parameters_tilda.items())[i][0], (6,6), is_background=[0, 1], xlabel='epochs')
+                self.plotter.plot(epochs, 
+                                  [parameter, 
+                                   np.ones_like(epochs) * ground_truth[i]], 
+                                   ['prediction', 'ground truth'], 
+                                   self.data.name + '_' + list(self.parameters_tilda.items())[i][0], 
+                                   list(self.parameters_tilda.items())[i][0], (6,6), 
+                                   is_background=[0, 1], 
+                                   xlabel='epochs')
             else:
-                self.plotter.plot(epochs, [parameter], ['prediction'], self.data.name + '_' + list(self.parameters_tilda.items())[i][0], list(self.parameters_tilda.items())[i][0], (6,6), xlabel='epochs', plot_legend=False)
-
-        
+                self.plotter.plot(epochs, 
+                                  [parameter], 
+                                  ['prediction'], 
+                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0], 
+                                  list(self.parameters_tilda.items())[i][0], (6,6), 
+                                  xlabel='epochs', 
+                                  plot_legend=False)
+
+    def plot_state_variables(self):
+        """Plot each predicted state variable together with the predicted group curves."""
+        for i in range(self.data.number_groups, self.data.number_groups + self.number_state_variables):
+            prediction = self.model(self.data.t_batch)
+            groups = [prediction[:, g] for g in range(self.data.number_groups)]
+            t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+            self.plotter.plot(t,
+                              [prediction[:, i]] + groups,
+                              [self.__state_variables[i-self.data.number_groups]] + self.data.group_names,
+                              f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
+                              self.__state_variables[i-self.data.number_groups],
+                              is_background=[0] + [1 for _ in self.data.group_names],
+                              figure_shape=(12, 6),
+                              plot_legend=True,
+                              xlabel='time / days')
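
With this commit, training becomes a two-step workflow: configure_training() sets the optimizer, scheduler, learning rate and epoch count, and train() then runs without those arguments. A minimal usage sketch follows; the dataset/plotter constructor arguments, the learning rate, the epoch count and the ground-truth values are illustrative placeholders, not part of this commit.

data = PandemicDataset(...)      # placeholder -- see src/dataset.py for the actual constructor
problem = SIRProblem(data)
plotter = Plotter(...)           # placeholder -- see src/plotter.py for the actual constructor

dinn = DINN(output_size=3,       # S, I, R -> three output nodes, no extra state variables
            data=data,
            parameter_list=['alpha', 'beta'],
            problem=problem,
            plotter=plotter)

# new in this commit: optimizer/scheduler are configured before training ...
dinn.configure_training(lr=1e-6,
                        epochs=10000,
                        optimizer_name='Adam',
                        scheduler_name='CyclicLR',
                        verbose=True)

# ... and train() no longer takes epochs, lr or optimizer_class
dinn.train(verbose=True)
dinn.plot_training_graphs(ground_truth=[1/3, 1/2])
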

src/problem.py (+53, -21)

@@ -9,41 +9,73 @@ class PandemicProblem:
             data (PandemicDataset): Dataset holding the time values used.
         """
 
-        self.data = data
-        self.device_name = data.device_name
+        self._data = data
+        self._device_name = data.device_name
 
-        #store the gradients for each group
-        self.gradients = [torch.zeros((len(data.t_raw), data.number_groups), device=self.device_name) for _ in range(data.number_groups)]
-
-        for i in range(data.number_groups):
-            self.gradients[i][:, i] = 1
+        self._gradients = None
 
     def residual(self):
         """NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS
         """
-        pass
+        assert self._gradients is not None, 'Gradient matrix needs to be defined via def_grad_matrix before computing residuals'
+        
+
+    def def_grad_matrix(self, number: int):
+        """Create one one-hot gradient matrix per output column; used by residual() to extract per-output time derivatives."""
+        assert self._gradients is None, 'Gradient matrix is already defined'
+        self._gradients = [torch.zeros((len(self._data.t_raw), number), device=self._device_name) for _ in range(number)]
+        for i in range(number):
+            self._gradients[i][:, i] = 1
 
 class SIRProblem(PandemicProblem):
     def __init__(self, data: PandemicDataset):
         super().__init__(data)
 
     def residual(self, SIR_pred, alpha, beta):
-        SIR_pred.backward(self.gradients[0], retain_graph=True)
-        dSdt = self.data.t_raw.grad.clone()
-        self.data.t_raw.grad.zero_()
+        super().residual()
+        SIR_pred.backward(self._gradients[0], retain_graph=True)
+        dSdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
 
-        SIR_pred.backward(self.gradients[1], retain_graph=True)
-        dIdt = self.data.t_raw.grad.clone()
-        self.data.t_raw.grad.zero_()
+        SIR_pred.backward(self._gradients[1], retain_graph=True)
+        dIdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
 
-        SIR_pred.backward(self.gradients[2], retain_graph=True)
-        dRdt = self.data.t_raw.grad.clone()
-        self.data.t_raw.grad.zero_()
+        SIR_pred.backward(self._gradients[2], retain_graph=True)
+        dRdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
 
-        S, I, R = self.data.get_denormalized_data([SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]])
+        S, I, _ = self._data.get_denormalized_data([SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]])
 
-        S_residual = dSdt - (-beta * ((S * I) / self.data.N)) / (self.data.get_max('S') - self.data.get_min('S'))
-        I_residual = dIdt - (beta * ((S * I) / self.data.N) - alpha * I) / (self.data.get_max('I') - self.data.get_min('I'))
-        R_residual = dRdt - (alpha * I) / (self.data.get_max('R') - self.data.get_min('R'))
+        S_residual = dSdt - (-beta * ((S * I) / self._data.N)) / (self._data.get_max('S') - self._data.get_min('S'))
+        I_residual = dIdt - (beta * ((S * I) / self._data.N) - alpha * I) / (self._data.get_max('I') - self._data.get_min('I'))
+        R_residual = dRdt - (alpha * I) / (self._data.get_max('R') - self._data.get_min('R'))
 
         return S_residual, I_residual, R_residual
+
+
+class ReducedSIRProblem(PandemicProblem):
+    def __init__(self, data: PandemicDataset, alpha:float):
+        super().__init__(data)
+        self.alpha = alpha
+
+    def residual(self, SI_pred):
+        super().residual()
+        SI_pred.backward(self._gradients[0], retain_graph=True)
+        dSdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
+
+        SI_pred.backward(self._gradients[1], retain_graph=True)
+        dIdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
+
+        _, I = self._data.get_denormalized_data([SI_pred[:, 0], SI_pred[:, 1]])
+        R_t = SI_pred[:, 2]
+        # I = SI_pred[:, 1]
+
+        S_residual = dSdt - (-self.alpha * R_t * I)
+        I_residual = dIdt - (self.alpha * (R_t - 1) * I)
+
+
+        return S_residual, I_residual
+
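
For reference, the residual of ReducedSIRProblem enforces the reduced SIR system dS/dt = -alpha * R_t * I and dI/dt = alpha * (R_t - 1) * I, where the time-dependent reproduction number R_t is predicted by the network as an additional state variable. Below is a minimal forward-Euler sketch of that system, holding R_t constant and using illustrative parameter values; it is not part of this commit.

# forward-Euler integration of the ODE system the residual above encodes;
# alpha, R_t, the step size and the initial conditions are illustrative assumptions
alpha = 1.0 / 14.0            # recovery rate
R_t = 1.5                     # held constant here; the DINN predicts it over time
S, I = 0.99, 0.01             # normalized initial susceptible/infectious fractions
dt = 0.1
for _ in range(500):          # integrate 50 time units
    dS = -alpha * R_t * I             # dS/dt = -alpha * R_t * I
    dI = alpha * (R_t - 1.0) * I      # dI/dt =  alpha * (R_t - 1) * I
    S, I = S + dt * dS, I + dt * dI
print(f'S = {S:.3f}, I = {I:.3f}')
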