
add scaling, norm, optimizer and scheduler choosing

phillip.rothenbeck 4 months ago
commit c38d74dc4c
1 changed file with 129 additions and 32 deletions
src/dinn.py  +129 -32

@@ -1,10 +1,33 @@
 import torch
+import csv
 import numpy as np
 
+from enum import Enum
+
 from .dataset import PandemicDataset
 from .problem import PandemicProblem
 from .plotter import Plotter
 
+class Optimizer(Enum):
+    ADAM=0
+
+class Scheduler(Enum):
+    CYCLIC=0
+    CONSTANT=1
+    LINEAR=2
+    POLYNOMIAL=3
+
+class Activation(Enum):
+    LINEAR=0
+    POWER=1
+
+def linear(x):
+    return x
+        
+def power(x):
+    return torch.float_power(x, 2)
+
+
 class DINN:
     class NN(torch.nn.Module):
         def __init__(self, 
@@ -12,9 +35,12 @@ class DINN:
                      input_size: int,
                      hidden_size: int,
                      hidden_layers: int, 
-                     activation_layer, 
+                     activation_layer,
                      t_init,
-                     t_final) -> None:
+                     t_final,
+                     output_activation_function=Activation.LINEAR,
+                     use_glorot_initialization = False,
+                     use_t_scaled=True) -> None:
             """Neural Network
             """Neural Network
 
 
             Args:
             Args:
@@ -26,21 +52,39 @@ class DINN:
             """
             """
             super(DINN.NN, self).__init__()
             super(DINN.NN, self).__init__()
 
 
+            if output_activation_function == Activation.LINEAR:
+                self.out_activation = linear
+            elif output_activation_function == Activation.POWER:
+                self.out_activation = power
+            else:
+                print('Set output activation to default: linear')
+                self.out_activation = linear
+
             self.input = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size), activation_layer)
             self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
             self.output = torch.nn.Linear(hidden_size, output_size)
+
+            if use_glorot_initialization:
+                torch.nn.init.xavier_uniform_(self.input[0].weight)
+                for i in range(hidden_layers):
+                    torch.nn.init.xavier_uniform_(self.hidden[i][0].weight)
+                torch.nn.init.xavier_uniform_(self.output.weight)
 
             self.__t_init = t_init
             self.__t_final = t_final
+            self.__use_t_scaled = use_t_scaled
 
         def forward(self, t):
             # normalize input
-            t_scaled = (t - self.__t_init) / (self.__t_final - self.__t_init)
-            x = self.input(t_scaled)
+            if self.__use_t_scaled:
+                t_forward = (t - self.__t_init) / (self.__t_final - self.__t_init)
+            else:
+                t_forward = t
+            x = self.input(t_forward)
             x = self.hidden(x)
             x = self.output(x)
-            return x
-
+            return self.out_activation(x)
+    
     def __init__(self,
                  output_size: int,
                  data: PandemicDataset,
@@ -52,7 +96,9 @@ class DINN:
                  input_size=1, 
                  hidden_size=20, 
                  hidden_layers=7, 
-                 activation_layer=torch.nn.ReLU()) -> None:
+                 activation_layer=torch.nn.ReLU(),
+                 activation_output=Activation.LINEAR,
+                 use_glorot_initialization = False) -> None:
         """Desease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve Inverse Problems and find the 
         """Desease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve Inverse Problems and find the 
         parameters of a specific mathematical model.
         parameters of a specific mathematical model.
 
 
@@ -78,9 +124,12 @@ class DINN:
                              input_size, 
                              hidden_size, 
                              hidden_layers, 
-                             activation_layer, 
+                             activation_layer,
                              data.t_init, 
-                             data.t_final)
+                             data.t_final,
+                             activation_output,
+                             use_glorot_initialization=use_glorot_initialization,
+                             use_t_scaled=data.use_scaled_time)
         self.model = self.model.to(self.device)
         self.data = data
         self.parameter_regulator = parameter_regulator
@@ -131,8 +180,21 @@ class DINN:
             list: list of regulated parameters
         """
         return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]
+
     
-    def configure_training(self, lr:float, epochs:int, optimizer_name='Adam', scheduler_name='CyclicLR', scheduler_factor = 1, verbose=False):
+    def get_output(self, index):
+        output = self.model(self.data.t_batch)
+        return output[:, index]
+    
+    def configure_training(self, 
+                           lr:float, 
+                           epochs:int, 
+                           optimizer_class=Optimizer.ADAM, 
+                           scheduler_class=Scheduler.CYCLIC, 
+                           scheduler_factor = 1, 
+                           lambda_obs = 1,
+                           lambda_physics = 1,
+                           verbose=False):
         """This method sets the optimizer, scheduler, learning rate and number of epochs for the following training process.
         """This method sets the optimizer, scheduler, learning rate and number of epochs for the following training process.
 
 
         Args:
         Args:
@@ -144,36 +206,38 @@ class DINN:
         """
         """
         parameter_list = list(self.model.parameters()) + list(self.parameters_tilda.values())
         parameter_list = list(self.model.parameters()) + list(self.parameters_tilda.values())
         self.epochs = epochs
         self.epochs = epochs
-        match optimizer_name:
-            case 'Adam':
+        self.lambda_obs = lambda_obs
+        self.lambda_physics = lambda_physics
+        match optimizer_class:
+            case Optimizer.ADAM:
                 self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
             case _:
                 self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
                 if verbose:
                     print('---------------------------------')
-                    print(f' Entered unknown optimizer name: {optimizer_name}\n Defaulted to Adam.')
+                    print(f' Entered unknown optimizer name: {optimizer_class.name}\n Defaulted to ADAM.')
                     print('---------------------------------')
-                optimizer_name = 'Adam'
+                optimizer_class = Optimizer.ADAM
 
-        match scheduler_name:
-            case 'CyclicLR':
+        match scheduler_class:
+            case Scheduler.CYCLIC:
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
-            case 'ConstantLR':
+            case Scheduler.CONSTANT:
                 self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1, total_iters=4)
-            case 'LinearLR':
+            case Scheduler.LINEAR:
                 self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs/scheduler_factor)
-            case 'PolynomialLR':
+            case Scheduler.POLYNOMIAL:
                 self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs/scheduler_factor, power=1.0)
             case _:
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
                 if verbose:
                     print('---------------------------------')
-                    print(f' Entered unknown scheduler name: {scheduler_name}\n Defaulted to CyclicLR.')
+                    print(f' Entered unknown scheduler name: {scheduler_class.name}\n Defaulted to CYCLIC.')
                     print('---------------------------------')
-                scheduler_name = 'CyclicLR'
+                scheduler_class = Scheduler.CYCLIC
 
         if verbose:
-            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_name}\nScheduler:\t{scheduler_name}\n')
+            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_class.name}\nScheduler:\t{scheduler_class.name}\n')
 
         self.__is_configured = True
 
@@ -181,7 +245,9 @@ class DINN:
     def train(self, 
               create_animation=False,
               animation_sample_rate=500,
-              verbose=False):
+              verbose=False,
+              do_split_training=False,
+              start_split=10000):
         """Training routine for the DINN.
         """Training routine for the DINN.
 
 
         Args:
         Args:
@@ -203,20 +269,27 @@ class DINN:
             # get the prediction and the fitting residuals
             prediction = self.model(self.data.t_batch)
             residuals = self.problem.residual(prediction, *self.get_regulated_param_list())
-
             self.optimizer.zero_grad()
 
             # calculate loss from the differential system
             loss_physics = 0
             for residual in residuals:
                 loss_physics += torch.mean(torch.square(residual))
+            loss_physics *= self.lambda_physics
 
             # calculate loss from the dataset
             loss_obs = 0
             for i, group in enumerate(self.data.group_names):
                 loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
+            loss_obs *= self.lambda_obs
             
-            loss = loss_obs + loss_physics
+            if do_split_training:
+                if epoch < start_split:
+                    loss = loss_obs
+                else:
+                    loss = loss_obs + loss_physics
+            else:
+                loss = loss_obs + loss_physics
 
             loss.backward()
             self.optimizer.step()
@@ -291,7 +364,8 @@ class DINN:
                                    np.ones_like(epochs) * ground_truth[i]], 
                                    ['prediction', 'ground truth'], 
                                    self.data.name + '_' + list(self.parameters_tilda.items())[i][0], 
-                                   list(self.parameters_tilda.items())[i][0], (6,6), 
+                                   list(self.parameters_tilda.items())[i][0], 
+                                   (6,6), 
                                    is_background=[0, 1], 
                                    xlabel='epochs')
             else:
@@ -302,19 +376,42 @@ class DINN:
                                   list(self.parameters_tilda.items())[i][0], (6,6), 
                                   xlabel='epochs', 
                                   plot_legend=False)
+                
+    def save_training_process(self, title, save_predictions = True):
+        losses = {'loss' : self.losses,
+                  'obs_loss' : self.obs_losses,
+                  'physics_loss' : self.physics_losses}
+        for loss in losses.keys():
+            with open(f'./results/training_metrics/{title}_{loss}.csv', 'w', newline='') as csvfile:
+                writer = csv.writer(csvfile, delimiter=',')
+                writer.writerow(losses[loss])
+
+        for i, parameter in enumerate(self.parameters):
+            with open(f'./results/training_metrics/{title}_{list(self.parameters_tilda.items())[i][0]}.csv', 'w', newline='') as csvfile:
+                writer = csv.writer(csvfile, delimiter=',')
+                writer.writerow(parameter)
+        if save_predictions:
+            prediction = self.model(self.data.t_batch)
+            for i, group in enumerate(self.data.group_names):
+                t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0]).detach().cpu().numpy()
+                true = self.data.get_group(group).detach().cpu().numpy()
+                pred = self.data.get_denormalized_data([prediction[:, i]])[0].detach().cpu().numpy()
+                print(t.shape, true.shape)
+                with open(f'./results/I_predictions/{title}_I_prediction.csv', 'w', newline='') as csvfile:
+                    writer = csv.writer(csvfile, delimiter=',')
+                    writer.writerow(t)
+                    writer.writerow(true)
+                    writer.writerow(pred)
 
     def plot_state_variables(self):
         prediction = self.model(self.data.t_batch)
-        groups = [prediction[:, i] for i in range(self.data.number_groups)]
-        fore_background = [0] + [1 for _ in groups]
         for i in range(self.data.number_groups, self.data.number_groups+self.number_state_variables):
-            t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+            t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
             self.plotter.plot(t,
-                              [prediction[:, i]] + groups,
-                              [self.__state_variables[i-self.data.number_groups]] + self.data.group_names,
+                              [prediction[:, i]],
+                              [self.__state_variables[i-self.data.number_groups]],
                               f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
                               self.__state_variables[i-self.data.number_groups],
-                              is_background=fore_background,
                               figure_shape=(12, 6),
                               plot_legend=True,
                               xlabel='time / days')
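
For reference, the new options introduced by this commit can be exercised roughly as follows. This is a minimal sketch, not part of the commit: the `dinn` instance, the learning rate, the epoch count and the import path are illustrative assumptions, and the constructor arguments not shown in the diff are passed as before.

    from src.dinn import DINN, Optimizer, Scheduler, Activation

    # build the DINN as before; the new keyword arguments are optional, e.g.
    # dinn = DINN(..., activation_output=Activation.POWER, use_glorot_initialization=True)

    dinn.configure_training(lr=1e-4,                               # illustrative value
                            epochs=25000,                          # illustrative value
                            optimizer_class=Optimizer.ADAM,        # was optimizer_name='Adam'
                            scheduler_class=Scheduler.POLYNOMIAL,  # was scheduler_name='PolynomialLR'
                            scheduler_factor=1,
                            lambda_obs=1,                          # weight of the data-fit loss
                            lambda_physics=1,                      # weight of the ODE-residual loss
                            verbose=True)

    # optionally fit only the observation loss for the first `start_split` epochs
    dinn.train(verbose=True, do_split_training=True, start_split=10000)

    # writes loss curves, parameter traces and predictions as CSV files;
    # assumes ./results/training_metrics/ and ./results/I_predictions/ exist
    dinn.save_training_process('experiment_name')

Replacing the former string arguments with the Optimizer and Scheduler enums means a misspelled choice now fails immediately with an AttributeError instead of silently falling back, while valid but unhandled members still reach the default branches kept in configure_training.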