Browse Source

add optimizer/scheduler configuration possibility + changes for modified SIR

phillip.rothenbeck 1 year ago
parent
commit
0ab0e0ce57
2 changed files with 184 additions and 75 deletions
  1. 131 54
      src/dinn.py
  2. 53 21
      src/problem.py

+ 131 - 54
src/dinn.py

@@ -1,16 +1,10 @@
 import torch
 import torch
-import os
-import imageio
 import numpy as np
 import numpy as np
-import matplotlib.pyplot as plt
-from matplotlib import rcParams
 
 
 from .dataset import PandemicDataset
 from .dataset import PandemicDataset
 from .problem import PandemicProblem
 from .problem import PandemicProblem
 from .plotter import Plotter
 from .plotter import Plotter
 
 
-
-
 class DINN:
 class DINN:
     class NN(torch.nn.Module):
     class NN(torch.nn.Module):
         def __init__(self, 
         def __init__(self, 
@@ -48,55 +42,68 @@ class DINN:
             return x
             return x
 
 
     def __init__(self,
     def __init__(self,
-                 number_groups: int,
+                 output_size: int,
                  data: PandemicDataset,
                  data: PandemicDataset,
                  parameter_list: list,
                  parameter_list: list,
                  problem: PandemicProblem,
                  problem: PandemicProblem,
                  plotter: Plotter,
                  plotter: Plotter,
+                 state_variables=[],
                  parameter_regulator=torch.tanh,
                  parameter_regulator=torch.tanh,
-                 input_size=1,
-                 hidden_size=20,
+                 input_size=1, 
+                 hidden_size=20, 
                  hidden_layers=7, 
                  hidden_layers=7, 
                  activation_layer=torch.nn.ReLU()) -> None:
                  activation_layer=torch.nn.ReLU()) -> None:
         """Desease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve Inverse Problems and find the 
         """Desease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve Inverse Problems and find the 
         parameters of a specific mathematical model.
         parameters of a specific mathematical model.
 
 
         Args:
         Args:
-            number_groups (int): The number of groups, that the population is split into.
+            output_size (int): Number of the output nodes of the NN.
             data (PandemicDataset): Data collected showing the course of the pandemic
             data (PandemicDataset): Data collected showing the course of the pandemic
             parameter_list (list): List of the parameter names(strings), that are supposed to be found.
             parameter_list (list): List of the parameter names(strings), that are supposed to be found.
             problem (PandemicProblem): Problem class implementing the calculation of the residuals.
             problem (PandemicProblem): Problem class implementing the calculation of the residuals.
             plotter (Plotter): Plotter object to plot dataset curves.
             plotter (Plotter): Plotter object to plot dataset curves.
+            state_variables (list, optional): List of the names of state variables. Defaults to [].
             parameter_regulator (optional): Function to force the parameters to be in a certain range. Defaults to torch.tanh.
             parameter_regulator (optional): Function to force the parameters to be in a certain range. Defaults to torch.tanh.
             input_size (int, optional): Number of the input nodes of the NN. Defaults to 1.
             input_size (int, optional): Number of the input nodes of the NN. Defaults to 1.
             hidden_size (int, optional): Number of the hidden nodes of the NN. Defaults to 20.
             hidden_size (int, optional): Number of the hidden nodes of the NN. Defaults to 20.
             hidden_layers (int, optional): Number of the hidden layers for the NN. Defaults to 7.
             hidden_layers (int, optional): Number of the hidden layers for the NN. Defaults to 7.
             activation_layer (optional): Class of the activation function. Defaults to torch.nn.ReLU().
             activation_layer (optional): Class of the activation function. Defaults to torch.nn.ReLU().
         """
         """
-        
+        assert len(state_variables) + data.number_groups == output_size, f'The number of groups plus the number of state variable must result in the output size\nGroups:\t{data.number_groups}\nState variables:\t{len(state_variables)}\noutput_size: {output_size}\n'
         self.device = torch.device(data.device_name)
         self.device = torch.device(data.device_name)
         self.device_name = data.device_name
         self.device_name = data.device_name
         self.plotter = plotter
         self.plotter = plotter
 
 
-        self.model = DINN.NN(number_groups, input_size, hidden_size, hidden_layers, activation_layer, data.t_init, data.t_final)
+        self.model = DINN.NN(output_size, 
+                             input_size, 
+                             hidden_size, 
+                             hidden_layers, 
+                             activation_layer, 
+                             data.t_init, 
+                             data.t_final)
         self.model = self.model.to(self.device)
         self.model = self.model.to(self.device)
         self.data = data
         self.data = data
         self.parameter_regulator = parameter_regulator
         self.parameter_regulator = parameter_regulator
         self.problem = problem
         self.problem = problem
+        self.problem.def_grad_matrix(output_size)
 
 
         self.parameters_tilda = {}
         self.parameters_tilda = {}
         for parameter in parameter_list:
         for parameter in parameter_list:
             self.parameters_tilda.update({parameter : torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
             self.parameters_tilda.update({parameter : torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
-        
-        self.epochs = None
 
 
-        self.losses = np.zeros(1)
-        self.obs_losses = np.zeros(1)
-        self.physics_losses = np.zeros(1)
+        # new model has to be configured and then trained
+        self.__is_configured = False
+        self.__has_trained = False
+
+        self.__state_variables = state_variables
+
         self.parameters = [np.zeros(1) for _ in range(len(parameter_list))]
         self.parameters = [np.zeros(1) for _ in range(len(parameter_list))]
 
 
         self.frames = []
         self.frames = []
 
 
+    @property
+    def number_state_variables(self):
+        return len(self.__state_variables)
 
 
     def get_regulated_param(self, parameter_name: str):
     def get_regulated_param(self, parameter_name: str):
         """Function to get the searched parameters, forced into a certain range.
         """Function to get the searched parameters, forced into a certain range.
@@ -125,40 +132,77 @@ class DINN:
         """
         """
         return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]
         return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]
     
     
+    def configure_training(self, lr:float, epochs:int, optimizer_name='Adam', scheduler_name='CyclicLR', scheduler_factor = 1, verbose=False):
+        """This method sets the optimizer, scheduler, learning rate and number of epochs for the following training process.
+
+        Args:
+            lr (float): Learning rate for the optimizer.
+            epochs (int): Number of epochs the NN is supposed to be trained for.
+            optimizer_name (str, optional): Name of the optimizer class that is supposed to be used. Defaults to 'Adam'.
+            scheduler_name (str, optional): Name of the scheduler class that is supposed to be used. Defaults to 'CyclicLR'.
+            verbose (bool, optional): Controls whether the configuration process prints status messages. Defaults to False.
+        """
+        parameter_list = list(self.model.parameters()) + list(self.parameters_tilda.values())
+        self.epochs = epochs
+        match optimizer_name:
+            case 'Adam':
+                self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
+            case _:
+                self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
+                if verbose:
+                    print('---------------------------------')
+                    print(f' Entered unknown optimizer name: {optimizer_name}\n Defaulted to Adam.')
+                    print('---------------------------------')
+                optimizer_name = 'Adam'
+
+        match scheduler_name:
+            case 'CyclicLR':
+                self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
+            case 'LinearLR':
+                self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs/scheduler_factor)
+            case 'PolynomialLR':
+                self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs/scheduler_factor, power=1.0)
+            case _:
+                self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
+                if verbose:
+                    print('---------------------------------')
+                    print(f' Entered unknown scheduler name: {scheduler_name}\n Defaulted to CyclicLR.')
+                    print('---------------------------------')
+                scheduler_name = 'CyclicLR'
+
+        if verbose:
+            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_name}\nScheduler:\t{scheduler_name}\n')
+
+        self.__is_configured = True
+
+    
     def train(self, 
     def train(self, 
-              epochs: int, 
-              lr: float, 
-              optimizer_class=torch.optim.Adam,
               create_animation=False,
               create_animation=False,
-              animation_sample_rate=500):
-        """Training routine for the DINN
+              animation_sample_rate=500,
+              verbose=False):
+        """Training routine for the DINN.
 
 
         Args:
         Args:
-            epochs (int): Number of epochs the NN is supposed to be trained for.
-            lr (float): Learning rate for the optimizer.
-            optimizer_class (optional): Class of the optimizer. Defaults to torch.optim.Adam.
             create_animation (boolean, optional): Decides whether a prediction animation is created during training. Defaults to False.
             create_animation (boolean, optional): Decides whether a prediction animation is created during training. Defaults to False.
             animation_sample_rate (int, optional): Sample rate of the prediction animation. Only used, when create_animation=True. Defaults to 500.
             animation_sample_rate (int, optional): Sample rate of the prediction animation. Only used, when create_animation=True. Defaults to 500.
+            verbose (bool, optional): Controls whether training progress is printed. Defaults to False.
         """
         """
-
-        # define optimizer and scheduler
-        optimizer = optimizer_class(list(self.model.parameters()) + list(self.parameters_tilda.values()), lr=lr)
-        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
-
-        self.epochs = epochs
-
+        assert self.__is_configured, 'The model has to be configured before training through the use of self.configure training.'
+        if verbose:
+            print(f'torch seed: {torch.seed()}')
+        
         # arrays to hold values for plotting
         # arrays to hold values for plotting
-        self.losses = np.zeros(epochs)
-        self.obs_losses = np.zeros(epochs)
-        self.physics_losses = np.zeros(epochs)
-        self.parameters = [np.zeros(epochs) for _ in self.parameters]
+        self.losses = np.zeros(self.epochs)
+        self.obs_losses = np.zeros(self.epochs)
+        self.physics_losses = np.zeros(self.epochs)
+        self.parameters = [np.zeros(self.epochs) for _ in self.parameters]
 
 
-        for epoch in range(epochs):
+        for epoch in range(self.epochs):
             # get the prediction and the fitting residuals
             # get the prediction and the fitting residuals
             prediction = self.model(self.data.t_batch)
             prediction = self.model(self.data.t_batch)
             residuals = self.problem.residual(prediction, *self.get_regulated_param_list())
             residuals = self.problem.residual(prediction, *self.get_regulated_param_list())
 
 
-            optimizer.zero_grad()
+            self.optimizer.zero_grad()
 
 
             # calculate loss from the differential system
             # calculate loss from the differential system
             loss_physics = 0
             loss_physics = 0
@@ -170,11 +214,11 @@ class DINN:
             for i, group in enumerate(self.data.group_names):
             for i, group in enumerate(self.data.group_names):
                 loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
                 loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
             
             
-            loss = loss_physics + loss_obs
+            loss = loss_obs + loss_physics
 
 
             loss.backward()
             loss.backward()
-            optimizer.step()
-            scheduler.step()
+            self.optimizer.step()
+            self.scheduler.step()
 
 
             # append values for plotting
             # append values for plotting
             self.losses[epoch] = loss.item()
             self.losses[epoch] = loss.item()
@@ -188,15 +232,18 @@ class DINN:
                 # prediction
                 # prediction
                 prediction = self.model(self.data.t_batch)
                 prediction = self.model(self.data.t_batch)
                 t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
                 t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
-                groups = self.data.get_denormalized_data([prediction[:, 0], prediction[:, 1], prediction[:, 2]])
+                groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
+
+                plot_labels = [name + '_pred' for name in self.data.group_names] + [name + '_true' for name in self.data.group_names]
+                background_list = [0 for _ in self.data.group_names] + [1 for _ in self.data.group_names]
                 self.plotter.plot(t, 
                 self.plotter.plot(t, 
                                   list(groups) + list(self.data.data), 
                                   list(groups) + list(self.data.data), 
-                                  [name + '_pred' for name in self.data.group_names] + [name + '_true' for name in self.data.group_names],
+                                  plot_labels,
                                   'frame',
                                   'frame',
                                   f'epoch {epoch}',
                                   f'epoch {epoch}',
                                   figure_shape=(12, 6),
                                   figure_shape=(12, 6),
                                   is_frame=True,
                                   is_frame=True,
-                                  is_background=[0, 0, 0, 1, 1, 1],
+                                  is_background=background_list,
                                   lw=3,
                                   lw=3,
                                   legend_loc='upper right',
                                   legend_loc='upper right',
                                   ylim=(0, self.data.N), 
                                   ylim=(0, self.data.N), 
@@ -204,19 +251,23 @@ class DINN:
                                   ylabel='amount of people')
                                   ylabel='amount of people')
 
 
             # print training advancements
             # print training advancements
-            if epoch % 1000 == 0:          
-                print('\nEpoch ', epoch)
+            if epoch % 1000 == 0 and verbose:          
+                print(f'\nEpoch {epoch} | LR {self.scheduler.get_last_lr()[0]}')
                 print(f'physics loss:\t\t{loss_physics.item()}')
                 print(f'physics loss:\t\t{loss_physics.item()}')
                 print(f'observation loss:\t{loss_obs.item()}')
                 print(f'observation loss:\t{loss_obs.item()}')
                 print(f'loss:\t\t\t{loss.item()}')
                 print(f'loss:\t\t\t{loss.item()}')
                 print('---------------------------------')
                 print('---------------------------------')
-                for parameter in self.parameters_tilda.items():
-                    print(f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
-                print('#################################')
+                if len(self.parameters_tilda.items()) != 0:
+                    for parameter in self.parameters_tilda.items():
+                        print(f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
+                    print('#################################')
 
 
         # create prediction animation
         # create prediction animation
         if create_animation:
         if create_animation:
             self.plotter.animate(self.data.name + '_animation')
             self.plotter.animate(self.data.name + '_animation')
+            self.plotter.reset_animation()
+
+        self.__has_trained = True
 
 
     def plot_training_graphs(self, ground_truth=[]):
     def plot_training_graphs(self, ground_truth=[]):
         """Plot the loss graph and the graphs of the advancements of the parameters.
         """Plot the loss graph and the graphs of the advancements of the parameters.
@@ -224,7 +275,7 @@ class DINN:
         Args:
         Args:
             ground_truth (list): List of the ground truth parameters
             ground_truth (list): List of the ground truth parameters
         """
         """
-        assert self.epochs != None
+        assert self.__has_trained, 'Model has to be trained, before plotting the training graphs'
         epochs = np.arange(0, self.epochs, 1)
         epochs = np.arange(0, self.epochs, 1)
 
 
         # plot loss
         # plot loss
@@ -233,8 +284,34 @@ class DINN:
         # plot parameters
         # plot parameters
         for i, parameter in enumerate(self.parameters):
         for i, parameter in enumerate(self.parameters):
             if len(ground_truth) > i:
             if len(ground_truth) > i:
-                self.plotter.plot(epochs, [parameter, np.ones_like(epochs) * ground_truth[i]], ['prediction', 'ground truth'], self.data.name + '_' + list(self.parameters_tilda.items())[i][0], list(self.parameters_tilda.items())[i][0], (6,6), is_background=[0, 1], xlabel='epochs')
+                self.plotter.plot(epochs, 
+                                  [parameter, 
+                                   np.ones_like(epochs) * ground_truth[i]], 
+                                   ['prediction', 'ground truth'], 
+                                   self.data.name + '_' + list(self.parameters_tilda.items())[i][0], 
+                                   list(self.parameters_tilda.items())[i][0], (6,6), 
+                                   is_background=[0, 1], 
+                                   xlabel='epochs')
             else:
             else:
-                self.plotter.plot(epochs, [parameter], ['prediction'], self.data.name + '_' + list(self.parameters_tilda.items())[i][0], list(self.parameters_tilda.items())[i][0], (6,6), xlabel='epochs', plot_legend=False)
-
-        
+                self.plotter.plot(epochs, 
+                                  [parameter], 
+                                  ['prediction'], 
+                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0], 
+                                  list(self.parameters_tilda.items())[i][0], (6,6), 
+                                  xlabel='epochs', 
+                                  plot_legend=False)
+
+    def plot_state_variables(self):
+        for i in range(self.data.number_groups, self.data.number_groups+self.number_state_variables):
+            prediction = self.model(self.data.t_batch)
+            groups = [prediction[:, i] for i in range(self.data.number_groups)]
+            t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+            self.plotter.plot(t,
+                              [prediction[:, i]] + groups,
+                              [self.__state_variables[i-self.data.number_groups]] + self.data.group_names,
+                              f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
+                              self.__state_variables[i-self.data.number_groups],
+                              is_background=[0, 1, 1],
+                              figure_shape=(12, 6),
+                              plot_legend=True,
+                              xlabel='time / days')

+ 53 - 21
src/problem.py

@@ -9,41 +9,73 @@ class PandemicProblem:
             data (PandemicDataset): Dataset holding the time values used.
             data (PandemicDataset): Dataset holding the time values used.
         """
         """
 
 
-        self.data = data
-        self.device_name = data.device_name
+        self._data = data
+        self._device_name = data.device_name
 
 
-        #store the gradients for each group
-        self.gradients = [torch.zeros((len(data.t_raw), data.number_groups), device=self.device_name) for _ in range(data.number_groups)]
-
-        for i in range(data.number_groups):
-            self.gradients[i][:, i] = 1
+        self._gradients = None
 
 
     def residual(self):
     def residual(self):
         """NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS
         """NEEDS TO BE IMPLEMENTED WHEN INHERITING FROM THIS CLASS
         """
         """
-        pass
+        assert self._gradients != None, 'Gradientmatrix need to be defined'
+        
+
+    def def_grad_matrix(self, number:int):
+        assert self._gradients == None, 'Gradientmatrix is already defined'
+        self._gradients = [torch.zeros((len(self._data.t_raw), number), device=self._device_name) for _ in range(number)]
+        for i in range(number):
+            self._gradients[i][:, i] = 1
 
 
 class SIRProblem(PandemicProblem):
 class SIRProblem(PandemicProblem):
     def __init__(self, data: PandemicDataset):
     def __init__(self, data: PandemicDataset):
         super().__init__(data)
         super().__init__(data)
 
 
     def residual(self, SIR_pred, alpha, beta):
     def residual(self, SIR_pred, alpha, beta):
-        SIR_pred.backward(self.gradients[0], retain_graph=True)
-        dSdt = self.data.t_raw.grad.clone()
-        self.data.t_raw.grad.zero_()
+        super().residual()
+        SIR_pred.backward(self._gradients[0], retain_graph=True)
+        dSdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
 
 
-        SIR_pred.backward(self.gradients[1], retain_graph=True)
-        dIdt = self.data.t_raw.grad.clone()
-        self.data.t_raw.grad.zero_()
+        SIR_pred.backward(self._gradients[1], retain_graph=True)
+        dIdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
 
 
-        SIR_pred.backward(self.gradients[2], retain_graph=True)
-        dRdt = self.data.t_raw.grad.clone()
-        self.data.t_raw.grad.zero_()
+        SIR_pred.backward(self._gradients[2], retain_graph=True)
+        dRdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
 
 
-        S, I, R = self.data.get_denormalized_data([SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]])
+        S, I, _ = self._data.get_denormalized_data([SIR_pred[:, 0], SIR_pred[:, 1], SIR_pred[:, 2]])
 
 
-        S_residual = dSdt - (-beta * ((S * I) / self.data.N)) / (self.data.get_max('S') - self.data.get_min('S'))
-        I_residual = dIdt - (beta * ((S * I) / self.data.N) - alpha * I) / (self.data.get_max('I') - self.data.get_min('I'))
-        R_residual = dRdt - (alpha * I) / (self.data.get_max('R') - self.data.get_min('R'))
+        S_residual = dSdt - (-beta * ((S * I) / self._data.N)) / (self._data.get_max('S') - self._data.get_min('S'))
+        I_residual = dIdt - (beta * ((S * I) / self._data.N) - alpha * I) / (self._data.get_max('I') - self._data.get_min('I'))
+        R_residual = dRdt - (alpha * I) / (self._data.get_max('R') - self._data.get_min('R'))
 
 
         return S_residual, I_residual, R_residual
         return S_residual, I_residual, R_residual
+
+
+class ReducedSIRProblem(PandemicProblem):
+    def __init__(self, data: PandemicDataset, alpha:float):
+        super().__init__(data)
+        self.alpha = alpha
+
+    def residual(self, SI_pred):
+        super().residual()
+        SI_pred.backward(self._gradients[0], retain_graph=True)
+        dSdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
+
+        SI_pred.backward(self._gradients[1], retain_graph=True)
+        dIdt = self._data.t_raw.grad.clone()
+        self._data.t_raw.grad.zero_()
+
+        _, I = self._data.get_denormalized_data([SI_pred[:, 0], SI_pred[:, 1]])
+        R_t = SI_pred[:, 2]
+        # I = SI_pred[:, 1]
+
+        S_residual = dSdt - (-self.alpha * R_t * I)
+        I_residual = dIdt - (self.alpha * (R_t - 1) * I)
+
+        # print(f'\nTrue:\tI_min: {I.min()}, I_max: {I.max()}\nNorm:\tI_min: {SI_pred[:, 1].min()}, I_max: {SI_pred[:, 1].max()}\nResidual:\t{torch.mean(torch.square(I_residual))}')
+
+        return S_residual, I_residual
+