
add scaling, norm, optimizer and scheduler choosing

phillip.rothenbeck 4 months ago
parent
commit
c38d74dc4c
1 changed file with 129 additions and 32 deletions

  1. src/dinn.py (+129, -32)

src/dinn.py  +129 -32

@@ -1,10 +1,33 @@
 import torch
+import csv
 import numpy as np
 
+from enum import Enum
+
 from .dataset import PandemicDataset
 from .problem import PandemicProblem
 from .plotter import Plotter
 
+class Optimizer(Enum):
+    ADAM=0
+
+class Scheduler(Enum):
+    CYCLIC=0
+    CONSTANT=1
+    LINEAR=2
+    POLYNOMIAL=3
+
+class Activation(Enum):
+    LINEAR=0
+    POWER=1
+
+def linear(x):
+    return x
+        
+def power(x):
+    return torch.float_power(x, 2)
+
+
 class DINN:
     class NN(torch.nn.Module):
         def __init__(self, 
@@ -12,9 +35,12 @@ class DINN:
                      input_size: int,
                      hidden_size: int,
                      hidden_layers: int, 
-                     activation_layer, 
+                     activation_layer,
                      t_init,
-                     t_final) -> None:
+                     t_final,
+                     output_activation_function=Activation.LINEAR,
+                     use_glorot_initialization = False,
+                     use_t_scaled=True) -> None:
             """Neural Network
 
             Args:
@@ -26,21 +52,39 @@ class DINN:
             """
             super(DINN.NN, self).__init__()
 
+            if output_activation_function == Activation.LINEAR:
+                self.out_activation = linear
+            elif output_activation_function == Activation.POWER:
+                self.out_activation = power
+            else:
+                print('Set output activation to default: linear')
+                self.out_activation = linear
+
             self.input = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size), activation_layer)
             self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
             self.output = torch.nn.Linear(hidden_size, output_size)
+
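+            # optional Glorot/Xavier uniform initialization of every linear layer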
+            if use_glorot_initialization:
+                torch.nn.init.xavier_uniform_(self.input[0].weight)
+                for i in range(hidden_layers):
+                    torch.nn.init.xavier_uniform_(self.hidden[i][0].weight)
+                torch.nn.init.xavier_uniform_(self.output.weight)
             
             self.__t_init = t_init
             self.__t_final = t_final
+            self.__use_t_scaled = use_t_scaled
 
         def forward(self, t):
             # normalize input
-            t_scaled = (t - self.__t_init) / (self.__t_final - self.__t_init)
-            x = self.input(t_scaled)
+            if self.__use_t_scaled:
+                t_forward = (t - self.__t_init) / (self.__t_final - self.__t_init)
+            else:
+                t_forward = t
+            x = self.input(t_forward)
             x = self.hidden(x)
             x = self.output(x)
-            return x
-
+            return self.out_activation(x)
+    
     def __init__(self,
                  output_size: int,
                  data: PandemicDataset,
@@ -52,7 +96,9 @@ class DINN:
                  input_size=1, 
                  hidden_size=20, 
                  hidden_layers=7, 
-                 activation_layer=torch.nn.ReLU()) -> None:
+                 activation_layer=torch.nn.ReLU(),
+                 activation_output=Activation.LINEAR,
+                 use_glorot_initialization = False) -> None:
         """Desease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve Inverse Problems and find the 
         parameters of a specific mathematical model.
 
@@ -78,9 +124,12 @@ class DINN:
                              input_size, 
                              hidden_size, 
                              hidden_layers, 
-                             activation_layer, 
+                             activation_layer,
                              data.t_init, 
-                             data.t_final)
+                             data.t_final,
+                             activation_output,
+                             use_glorot_initialization=use_glorot_initialization,
+                             use_t_scaled=data.use_scaled_time)
         self.model = self.model.to(self.device)
         self.data = data
         self.parameter_regulator = parameter_regulator
@@ -131,8 +180,21 @@ class DINN:
             list: list of regulated parameters
         """
         return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]
+
     
-    def configure_training(self, lr:float, epochs:int, optimizer_name='Adam', scheduler_name='CyclicLR', scheduler_factor = 1, verbose=False):
+    def get_output(self, index):
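+        """Return one column of the network output evaluated on the full time batch.
+
+        Args:
+            index (int): column index of the compartment/state variable to return.
+        """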
+        output = self.model(self.data.t_batch)
+        return output[:, index]
+    
+    def configure_training(self, 
+                           lr:float, 
+                           epochs:int, 
+                           optimizer_class=Optimizer.ADAM, 
+                           scheduler_class=Scheduler.CYCLIC, 
+                           scheduler_factor = 1, 
+                           lambda_obs = 1,
+                           lambda_physics = 1,
+                           verbose=False):
         """This method sets the optimizer, scheduler, learning rate and number of epochs for the following training process.
 
         Args:
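+            optimizer_class (Optimizer, optional): optimizer to use. Defaults to Optimizer.ADAM.
+            scheduler_class (Scheduler, optional): learning rate scheduler to use. Defaults to Scheduler.CYCLIC.
+            lambda_obs (float, optional): weight of the observation loss term. Defaults to 1.
+            lambda_physics (float, optional): weight of the physics residual loss term. Defaults to 1.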
@@ -144,36 +206,38 @@ class DINN:
         """
         parameter_list = list(self.model.parameters()) + list(self.parameters_tilda.values())
         self.epochs = epochs
-        match optimizer_name:
-            case 'Adam':
+        self.lambda_obs = lambda_obs
+        self.lambda_physics = lambda_physics
+        match optimizer_class:
+            case Optimizer.ADAM:
                 self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
             case _:
                 self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
                 if verbose:
                     print('---------------------------------')
-                    print(f' Entered unknown optimizer name: {optimizer_name}\n Defaulted to Adam.')
+                    print(f' Entered unknown optimizer: {optimizer_class}\n Defaulted to ADAM.')
                     print('---------------------------------')
-                optimizer_name = 'Adam'
+                optimizer_class = Optimizer.ADAM
 
-        match scheduler_name:
-            case 'CyclicLR':
+        match scheduler_class:
+            case Scheduler.CYCLIC:
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
-            case 'ConstantLR':
+            case Scheduler.CONSTANT:
                 self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1, total_iters=4)
-            case 'LinearLR':
+            case Scheduler.LINEAR:
                 self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs/scheduler_factor)
-            case 'PolynomialLR':
+            case Scheduler.POLYNOMIAL:
                 self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs/scheduler_factor, power=1.0)
             case _:
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
                 if verbose:
                     print('---------------------------------')
-                    print(f' Entered unknown scheduler name: {scheduler_name}\n Defaulted to CyclicLR.')
+                    print(f' Entered unknown scheduler: {scheduler_class}\n Defaulted to CYCLIC.')
                     print('---------------------------------')
-                scheduler_name = 'CyclicLR'
+                scheduler_class = Scheduler.CYCLIC
 
         if verbose:
-            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_name}\nScheduler:\t{scheduler_name}\n')
+            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_class.name}\nScheduler:\t{scheduler_class.name}\n')
 
         self.__is_configured = True
 
@@ -181,7 +245,9 @@ class DINN:
     def train(self, 
               create_animation=False,
               animation_sample_rate=500,
-              verbose=False):
+              verbose=False,
+              do_split_training=False,
+              start_split=10000):
         """Training routine for the DINN.
 
         Args:
@@ -203,20 +269,27 @@ class DINN:
             # get the prediction and the fitting residuals
             prediction = self.model(self.data.t_batch)
             residuals = self.problem.residual(prediction, *self.get_regulated_param_list())
-
             self.optimizer.zero_grad()
 
             # calculate loss from the differential system
             loss_physics = 0
             for residual in residuals:
                 loss_physics += torch.mean(torch.square(residual))
+            loss_physics *= self.lambda_physics
 
             # calculate loss from the dataset
             loss_obs = 0
             for i, group in enumerate(self.data.group_names):
                 loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
+            loss_obs *= self.lambda_obs
             
-            loss = loss_obs + loss_physics
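+            # split training: for the first start_split epochs only the observation loss is minimized,
+            # the physics residual is added afterwards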
+            if do_split_training:
+                if epoch < start_split:
+                    loss = loss_obs
+                else:
+                    loss = loss_obs + loss_physics
+            else:
+                loss = loss_obs + loss_physics
 
             loss.backward()
             self.optimizer.step()
@@ -291,7 +364,8 @@ class DINN:
                                    np.ones_like(epochs) * ground_truth[i]], 
                                    ['prediction', 'ground truth'], 
                                    self.data.name + '_' + list(self.parameters_tilda.items())[i][0], 
-                                   list(self.parameters_tilda.items())[i][0], (6,6), 
+                                   list(self.parameters_tilda.items())[i][0], 
+                                   (6,6), 
                                    is_background=[0, 1], 
                                    xlabel='epochs')
             else:
@@ -302,19 +376,42 @@ class DINN:
                                   list(self.parameters_tilda.items())[i][0], (6,6), 
                                   xlabel='epochs', 
                                   plot_legend=False)
+                
+    def save_training_process(self, title, save_predictions = True):
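+        """Write the loss curves, the parameter histories and (optionally) the model
+        predictions of this run to CSV files under ./results/.
+
+        Args:
+            title (str): prefix used for the generated CSV file names.
+            save_predictions (bool, optional): also export time, data and prediction. Defaults to True.
+        """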
+        losses = {'loss' : self.losses,
+                  'obs_loss' : self.obs_losses,
+                  'physics_loss' : self.physics_losses}
+        for loss in losses.keys():
+            with open(f'./results/training_metrics/{title}_{loss}.csv', 'w', newline='') as csvfile:
+                writer = csv.writer(csvfile, delimiter=',')
+                writer.writerow(losses[loss])
+
+        for i, parameter in enumerate(self.parameters):
+            with open(f'./results/training_metrics/{title}_{list(self.parameters_tilda.items())[i][0]}.csv', 'w', newline='') as csvfile:
+                writer = csv.writer(csvfile, delimiter=',')
+                writer.writerow(parameter)
+        if save_predictions:
+            prediction = self.model(self.data.t_batch)
+            for i, group in enumerate(self.data.group_names):
+                t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0]).detach().cpu().numpy()
+                true = self.data.get_group(group).detach().cpu().numpy()
+                pred = self.data.get_denormalized_data([prediction[:, i]])[0].detach().cpu().numpy()
+                with open(f'./results/I_predictions/{title}_I_prediction.csv', 'w', newline='') as csvfile:
+                    writer = csv.writer(csvfile, delimiter=',')
+                    writer.writerow(t)
+                    writer.writerow(true)
+                    writer.writerow(pred)
 
     def plot_state_variables(self):
         prediction = self.model(self.data.t_batch)
-        groups = [prediction[:, i] for i in range(self.data.number_groups)]
-        fore_background = [0] + [1 for _ in groups]
         for i in range(self.data.number_groups, self.data.number_groups+self.number_state_variables):
-            t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+            t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
             self.plotter.plot(t,
-                              [prediction[:, i]] + groups,
-                              [self.__state_variables[i-self.data.number_groups]] + self.data.group_names,
+                              [prediction[:, i]],
+                              [self.__state_variables[i-self.data.number_groups]],
                               f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
                               self.__state_variables[i-self.data.number_groups],
-                              is_background=fore_background,
                               figure_shape=(12, 6),
                               plot_legend=True,
                               xlabel='time / days')
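
Taken together, this commit replaces the string-based optimizer/scheduler selection with enums and adds loss weighting, optional split training and CSV export. A minimal usage sketch follows; it assumes an already constructed DINN instance `dinn` (built from a PandemicDataset and PandemicProblem as elsewhere in this repository), and the learning rate, epoch count and file title are illustrative values only, not taken from this commit.

from src.dinn import Optimizer, Scheduler

# Only the options introduced in this commit are exercised here.
dinn.configure_training(lr=1e-4,
                        epochs=25000,
                        optimizer_class=Optimizer.ADAM,        # only ADAM is defined so far
                        scheduler_class=Scheduler.POLYNOMIAL,  # or CYCLIC / CONSTANT / LINEAR
                        scheduler_factor=1,
                        lambda_obs=1,        # weight of the observation loss
                        lambda_physics=1,    # weight of the ODE residual loss
                        verbose=True)

# Split training: fit only the data term for the first 10000 epochs,
# then optimize the combined data + physics loss.
dinn.train(do_split_training=True, start_split=10000, verbose=True)

# Export loss curves, parameter histories and predictions as CSV files
# under ./results/ (the output directories are assumed to exist).
dinn.save_training_process('my_experiment', save_predictions=True)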