Selaa lähdekoodia

do only I gif and reformet

phillip.rothenbeck 1 vuosi sitten
vanhempi
commit
c7fc587487
1 muutettua tiedostoa jossa 111 lisäystä ja 70 poistoa
  1. 111 70
      src/dinn.py

+ 111 - 70
src/dinn.py

@@ -8,38 +8,43 @@ from .dataset import PandemicDataset
 from .problem import PandemicProblem
 from .plotter import Plotter
 
+
 class Optimizer(Enum):
-    ADAM=0
+    ADAM = 0
+
 
 class Scheduler(Enum):
-    CYCLIC=0
-    CONSTANT=1
-    LINEAR=2
-    POLYNOMIAL=3
+    CYCLIC = 0
+    CONSTANT = 1
+    LINEAR = 2
+    POLYNOMIAL = 3
+
 
 class Activation(Enum):
-    LINEAR=0
-    POWER=1
+    LINEAR = 0
+    POWER = 1
+
 
 def linear(x):
     return x
-        
+
+
 def power(x):
     return torch.float_power(x, 2)
 
 
 class DINN:
     class NN(torch.nn.Module):
-        def __init__(self, 
+        def __init__(self,
                      output_size: int,
                      input_size: int,
                      hidden_size: int,
-                     hidden_layers: int, 
+                     hidden_layers: int,
                      activation_layer,
                      t_init,
                      t_final,
                      output_activation_function=Activation.LINEAR,
-                     use_glorot_initialization = False,
+                     use_glorot_initialization=False,
                      use_t_scaled=True) -> None:
             """Neural Network
 
@@ -69,7 +74,7 @@ class DINN:
                 for i in range(hidden_layers):
                     torch.nn.init.xavier_uniform_(self.hidden[i][0].weight)
                 torch.nn.init.xavier_uniform_(self.output.weight)
-            
+
             self.__t_init = t_init
             self.__t_final = t_final
             self.__use_t_scaled = use_t_scaled
@@ -84,7 +89,7 @@ class DINN:
             x = self.hidden(x)
             x = self.output(x)
             return self.out_activation(x)
-    
+
     def __init__(self,
                  output_size: int,
                  data: PandemicDataset,
@@ -93,12 +98,12 @@ class DINN:
                  plotter: Plotter,
                  state_variables=[],
                  parameter_regulator=torch.tanh,
-                 input_size=1, 
-                 hidden_size=20, 
-                 hidden_layers=7, 
+                 input_size=1,
+                 hidden_size=20,
+                 hidden_layers=7,
                  activation_layer=torch.nn.ReLU(),
                  activation_output=Activation.LINEAR,
-                 use_glorot_initialization = False) -> None:
+                 use_glorot_initialization=False) -> None:
         """Desease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve Inverse Problems and find the 
         parameters of a specific mathematical model.
 
@@ -120,12 +125,12 @@ class DINN:
         self.device_name = data.device_name
         self.plotter = plotter
 
-        self.model = DINN.NN(output_size, 
-                             input_size, 
-                             hidden_size, 
-                             hidden_layers, 
+        self.model = DINN.NN(output_size,
+                             input_size,
+                             hidden_size,
+                             hidden_layers,
                              activation_layer,
-                             data.t_init, 
+                             data.t_init,
                              data.t_final,
                              activation_output,
                              use_glorot_initialization=use_glorot_initialization,
@@ -138,7 +143,7 @@ class DINN:
 
         self.parameters_tilda = {}
         for parameter in parameter_list:
-            self.parameters_tilda.update({parameter : torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
+            self.parameters_tilda.update({parameter: torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
 
         # new model has to be configured and then trained
         self.__is_configured = False
@@ -164,7 +169,7 @@ class DINN:
             torch.Parameter: Regulated parameter object of the search parameter.
         """
         return self.parameter_regulator(self.parameters_tilda[parameter_name])
-    
+
     def get_parameters_tilda(self):
         """Function to get the original value (not forced into any range).
 
@@ -181,19 +186,18 @@ class DINN:
         """
         return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]
 
-    
     def get_output(self, index):
         output = self.model(self.data.t_batch)
         return output[:, index]
-    
-    def configure_training(self, 
-                           lr:float, 
-                           epochs:int, 
-                           optimizer_class=Optimizer.ADAM, 
-                           scheduler_class=Scheduler.CYCLIC, 
-                           scheduler_factor = 1, 
-                           lambda_obs = 1,
-                           lambda_physics = 1,
+
+    def configure_training(self,
+                           lr: float,
+                           epochs: int,
+                           optimizer_class=Optimizer.ADAM,
+                           scheduler_class=Scheduler.CYCLIC,
+                           scheduler_factor=1,
+                           lambda_obs=1,
+                           lambda_physics=1,
                            verbose=False):
         """This method sets the optimizer, scheduler, learning rate and number of epochs for the following training process.
 
@@ -225,9 +229,9 @@ class DINN:
             case Scheduler.CONSTANT:
                 self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1, total_iters=4)
             case Scheduler.LINEAR:
-                self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs/scheduler_factor)
+                self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs / scheduler_factor)
             case Scheduler.POLYNOMIAL:
-                self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs/scheduler_factor, power=1.0)
+                self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs / scheduler_factor, power=1.0)
             case _:
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
                 if verbose:
@@ -241,8 +245,8 @@ class DINN:
 
         self.__is_configured = True
 
-    
-    def train(self, 
+    def train(self,
+              plot_I_prediction=False,
               create_animation=False,
               animation_sample_rate=500,
               verbose=False,
@@ -258,7 +262,7 @@ class DINN:
         assert self.__is_configured, 'The model has to be configured before training through the use of self.configure training.'
         if verbose:
             print(f'torch seed: {torch.seed()}')
-        
+
         # arrays to hold values for plotting
         self.losses = np.zeros(self.epochs)
         self.obs_losses = np.zeros(self.epochs)
@@ -282,7 +286,7 @@ class DINN:
             for i, group in enumerate(self.data.group_names):
                 loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
             loss_obs *= self.lambda_obs
-            
+
             if do_split_training:
                 if epoch < start_split:
                     loss = loss_obs
@@ -305,14 +309,14 @@ class DINN:
             # do snapshot for prediction animation
             if epoch % animation_sample_rate == 0 and create_animation:
                 # prediction
-                prediction = self.model(self.data.t_batch)
+                """prediction = self.model(self.data.t_batch)
                 t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
                 groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
 
                 plot_labels = [name + '_pred' for name in self.data.group_names] + [name + '_true' for name in self.data.group_names]
                 background_list = [0 for _ in self.data.group_names] + [1 for _ in self.data.group_names]
-                self.plotter.plot(t, 
-                                  list(groups) + list(self.data.data), 
+                self.plotter.plot(t,
+                                  list(groups) + list(self.data.data),
                                   plot_labels,
                                   'frame',
                                   f'epoch {epoch}',
@@ -321,12 +325,30 @@ class DINN:
                                   is_background=background_list,
                                   lw=3,
                                   legend_loc='upper right',
-                                  ylim=(0, self.data.N), 
+                                  ylim=(0, self.data.N),
+                                  xlabel='time / days',
+                                  ylabel='amount of people')"""
+                prediction = self.model(self.data.t_batch)
+                t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+                groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
+
+                plot_labels = ['I_pred', 'I_true']
+                background_list = [0, 1]
+                self.plotter.plot(t,
+                                  [groups[1]] + [self.data.data[1]],
+                                  plot_labels,
+                                  'Frame',
+                                  f'epoch {epoch}',
+                                  figure_shape=(12, 6),
+                                  is_frame=True,
+                                  is_background=background_list,
+                                  lw=3,
+                                  legend_loc='upper right',
                                   xlabel='time / days',
                                   ylabel='amount of people')
 
             # print training advancements
-            if epoch % 1000 == 0 and verbose:          
+            if epoch % 1000 == 0 and verbose:
                 print(f'\nEpoch {epoch} | LR {self.scheduler.get_last_lr()[0]}')
                 print(f'physics loss:\t\t{loss_physics.item()}')
                 print(f'observation loss:\t{loss_obs.item()}')
@@ -342,6 +364,25 @@ class DINN:
             self.plotter.animate(self.data.name + '_animation')
             self.plotter.reset_animation()
 
+        if plot_I_prediction:
+            prediction = self.model(self.data.t_batch)
+            t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+            groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
+
+            plot_labels = ['I_pred', 'I_true']
+            background_list = [0, 1]
+            self.plotter.plot(t,
+                              [groups[1]] + [self.data.data[1]],
+                              plot_labels,
+                              'Training_I_prediction',
+                              f'Prediction of I on JH data',
+                              figure_shape=(12, 6),
+                              is_background=background_list,
+                              lw=3,
+                              legend_loc='upper right',
+                              xlabel='time / days',
+                              ylabel='amount of people')
+
         self.__has_trained = True
 
     def plot_training_graphs(self, ground_truth=[]):
@@ -355,32 +396,32 @@ class DINN:
 
         # plot loss
         self.plotter.plot(epochs, [self.losses, self.obs_losses, self.physics_losses], ['loss', 'observation loss', 'physics loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=True, xlabel='epochs')
-        
+
         # plot parameters
         for i, parameter in enumerate(self.parameters):
             if len(ground_truth) > i:
-                self.plotter.plot(epochs, 
-                                  [parameter, 
-                                   np.ones_like(epochs) * ground_truth[i]], 
-                                   ['prediction', 'ground truth'], 
-                                   self.data.name + '_' + list(self.parameters_tilda.items())[i][0], 
-                                   list(self.parameters_tilda.items())[i][0], 
-                                   (6,6), 
-                                   is_background=[0, 1], 
-                                   xlabel='epochs')
+                self.plotter.plot(epochs,
+                                  [parameter,
+                                   np.ones_like(epochs) * ground_truth[i]],
+                                  ['prediction', 'ground truth'],
+                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
+                                  list(self.parameters_tilda.items())[i][0],
+                                  (6, 6),
+                                  is_background=[0, 1],
+                                  xlabel='epochs')
             else:
-                self.plotter.plot(epochs, 
-                                  [parameter], 
-                                  ['prediction'], 
-                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0], 
-                                  list(self.parameters_tilda.items())[i][0], (6,6), 
-                                  xlabel='epochs', 
+                self.plotter.plot(epochs,
+                                  [parameter],
+                                  ['prediction'],
+                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
+                                  list(self.parameters_tilda.items())[i][0], (6, 6),
+                                  xlabel='epochs',
                                   plot_legend=False)
-                
-    def save_training_process(self, title, save_predictions = True):
-        losses = {'loss' : self.losses,
-                  'obs_loss' : self.obs_losses,
-                  'physics_loss' : self.physics_losses}
+
+    def save_training_process(self, title, save_predictions=True):
+        losses = {'loss': self.losses,
+                  'obs_loss': self.obs_losses,
+                  'physics_loss': self.physics_losses}
         for loss in losses.keys():
             with open(f'./results/training_metrics/{title}_{loss}.csv', 'w', newline='') as csvfile:
                 writer = csv.writer(csvfile, delimiter=',')
@@ -405,13 +446,13 @@ class DINN:
 
     def plot_state_variables(self):
         prediction = self.model(self.data.t_batch)
-        for i in range(self.data.number_groups, self.data.number_groups+self.number_state_variables):
+        for i in range(self.data.number_groups, self.data.number_groups + self.number_state_variables):
             t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
             self.plotter.plot(t,
                               [prediction[:, i]],
-                              [self.__state_variables[i-self.data.number_groups]],
+                              [self.__state_variables[i - self.data.number_groups]],
                               f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
-                              self.__state_variables[i-self.data.number_groups],
+                              self.__state_variables[i - self.data.number_groups],
                               figure_shape=(12, 6),
                               plot_legend=True,
-                              xlabel='time / days')
+                              xlabel='time / days')