소스 검색

do only I gif and reformet

phillip.rothenbeck 1 년 전
부모
커밋
c7fc587487
1개의 변경된 파일111개의 추가작업 그리고 70개의 파일을 삭제
  1. 111 70
      src/dinn.py

+ 111 - 70
src/dinn.py

@@ -8,38 +8,43 @@ from .dataset import PandemicDataset
 from .problem import PandemicProblem
 from .problem import PandemicProblem
 from .plotter import Plotter
 from .plotter import Plotter
 
 
+
 class Optimizer(Enum):
 class Optimizer(Enum):
-    ADAM=0
+    ADAM = 0
+
 
 
 class Scheduler(Enum):
 class Scheduler(Enum):
-    CYCLIC=0
-    CONSTANT=1
-    LINEAR=2
-    POLYNOMIAL=3
+    CYCLIC = 0
+    CONSTANT = 1
+    LINEAR = 2
+    POLYNOMIAL = 3
+
 
 
 class Activation(Enum):
 class Activation(Enum):
-    LINEAR=0
-    POWER=1
+    LINEAR = 0
+    POWER = 1
+
 
 
 def linear(x):
 def linear(x):
     return x
     return x
-        
+
+
 def power(x):
 def power(x):
     return torch.float_power(x, 2)
     return torch.float_power(x, 2)
 
 
 
 
 class DINN:
 class DINN:
     class NN(torch.nn.Module):
     class NN(torch.nn.Module):
-        def __init__(self, 
+        def __init__(self,
                      output_size: int,
                      output_size: int,
                      input_size: int,
                      input_size: int,
                      hidden_size: int,
                      hidden_size: int,
-                     hidden_layers: int, 
+                     hidden_layers: int,
                      activation_layer,
                      activation_layer,
                      t_init,
                      t_init,
                      t_final,
                      t_final,
                      output_activation_function=Activation.LINEAR,
                      output_activation_function=Activation.LINEAR,
-                     use_glorot_initialization = False,
+                     use_glorot_initialization=False,
                      use_t_scaled=True) -> None:
                      use_t_scaled=True) -> None:
             """Neural Network
             """Neural Network
 
 
@@ -69,7 +74,7 @@ class DINN:
                 for i in range(hidden_layers):
                 for i in range(hidden_layers):
                     torch.nn.init.xavier_uniform_(self.hidden[i][0].weight)
                     torch.nn.init.xavier_uniform_(self.hidden[i][0].weight)
                 torch.nn.init.xavier_uniform_(self.output.weight)
                 torch.nn.init.xavier_uniform_(self.output.weight)
-            
+
             self.__t_init = t_init
             self.__t_init = t_init
             self.__t_final = t_final
             self.__t_final = t_final
             self.__use_t_scaled = use_t_scaled
             self.__use_t_scaled = use_t_scaled
@@ -84,7 +89,7 @@ class DINN:
             x = self.hidden(x)
             x = self.hidden(x)
             x = self.output(x)
             x = self.output(x)
             return self.out_activation(x)
             return self.out_activation(x)
-    
+
     def __init__(self,
     def __init__(self,
                  output_size: int,
                  output_size: int,
                  data: PandemicDataset,
                  data: PandemicDataset,
@@ -93,12 +98,12 @@ class DINN:
                  plotter: Plotter,
                  plotter: Plotter,
                  state_variables=[],
                  state_variables=[],
                  parameter_regulator=torch.tanh,
                  parameter_regulator=torch.tanh,
-                 input_size=1, 
-                 hidden_size=20, 
-                 hidden_layers=7, 
+                 input_size=1,
+                 hidden_size=20,
+                 hidden_layers=7,
                  activation_layer=torch.nn.ReLU(),
                  activation_layer=torch.nn.ReLU(),
                  activation_output=Activation.LINEAR,
                  activation_output=Activation.LINEAR,
-                 use_glorot_initialization = False) -> None:
+                 use_glorot_initialization=False) -> None:
         """Desease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve Inverse Problems and find the 
         """Desease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve Inverse Problems and find the 
         parameters of a specific mathematical model.
         parameters of a specific mathematical model.
 
 
@@ -120,12 +125,12 @@ class DINN:
         self.device_name = data.device_name
         self.device_name = data.device_name
         self.plotter = plotter
         self.plotter = plotter
 
 
-        self.model = DINN.NN(output_size, 
-                             input_size, 
-                             hidden_size, 
-                             hidden_layers, 
+        self.model = DINN.NN(output_size,
+                             input_size,
+                             hidden_size,
+                             hidden_layers,
                              activation_layer,
                              activation_layer,
-                             data.t_init, 
+                             data.t_init,
                              data.t_final,
                              data.t_final,
                              activation_output,
                              activation_output,
                              use_glorot_initialization=use_glorot_initialization,
                              use_glorot_initialization=use_glorot_initialization,
@@ -138,7 +143,7 @@ class DINN:
 
 
         self.parameters_tilda = {}
         self.parameters_tilda = {}
         for parameter in parameter_list:
         for parameter in parameter_list:
-            self.parameters_tilda.update({parameter : torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
+            self.parameters_tilda.update({parameter: torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
 
 
         # new model has to be configured and then trained
         # new model has to be configured and then trained
         self.__is_configured = False
         self.__is_configured = False
@@ -164,7 +169,7 @@ class DINN:
             torch.Parameter: Regulated parameter object of the search parameter.
             torch.Parameter: Regulated parameter object of the search parameter.
         """
         """
         return self.parameter_regulator(self.parameters_tilda[parameter_name])
         return self.parameter_regulator(self.parameters_tilda[parameter_name])
-    
+
     def get_parameters_tilda(self):
     def get_parameters_tilda(self):
         """Function to get the original value (not forced into any range).
         """Function to get the original value (not forced into any range).
 
 
@@ -181,19 +186,18 @@ class DINN:
         """
         """
         return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]
         return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]
 
 
-    
     def get_output(self, index):
     def get_output(self, index):
         output = self.model(self.data.t_batch)
         output = self.model(self.data.t_batch)
         return output[:, index]
         return output[:, index]
-    
-    def configure_training(self, 
-                           lr:float, 
-                           epochs:int, 
-                           optimizer_class=Optimizer.ADAM, 
-                           scheduler_class=Scheduler.CYCLIC, 
-                           scheduler_factor = 1, 
-                           lambda_obs = 1,
-                           lambda_physics = 1,
+
+    def configure_training(self,
+                           lr: float,
+                           epochs: int,
+                           optimizer_class=Optimizer.ADAM,
+                           scheduler_class=Scheduler.CYCLIC,
+                           scheduler_factor=1,
+                           lambda_obs=1,
+                           lambda_physics=1,
                            verbose=False):
                            verbose=False):
         """This method sets the optimizer, scheduler, learning rate and number of epochs for the following training process.
         """This method sets the optimizer, scheduler, learning rate and number of epochs for the following training process.
 
 
@@ -225,9 +229,9 @@ class DINN:
             case Scheduler.CONSTANT:
             case Scheduler.CONSTANT:
                 self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1, total_iters=4)
                 self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1, total_iters=4)
             case Scheduler.LINEAR:
             case Scheduler.LINEAR:
-                self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs/scheduler_factor)
+                self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs / scheduler_factor)
             case Scheduler.POLYNOMIAL:
             case Scheduler.POLYNOMIAL:
-                self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs/scheduler_factor, power=1.0)
+                self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs / scheduler_factor, power=1.0)
             case _:
             case _:
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
                 if verbose:
                 if verbose:
@@ -241,8 +245,8 @@ class DINN:
 
 
         self.__is_configured = True
         self.__is_configured = True
 
 
-    
-    def train(self, 
+    def train(self,
+              plot_I_prediction=False,
               create_animation=False,
               create_animation=False,
               animation_sample_rate=500,
               animation_sample_rate=500,
               verbose=False,
               verbose=False,
@@ -258,7 +262,7 @@ class DINN:
         assert self.__is_configured, 'The model has to be configured before training through the use of self.configure training.'
         assert self.__is_configured, 'The model has to be configured before training through the use of self.configure training.'
         if verbose:
         if verbose:
             print(f'torch seed: {torch.seed()}')
             print(f'torch seed: {torch.seed()}')
-        
+
         # arrays to hold values for plotting
         # arrays to hold values for plotting
         self.losses = np.zeros(self.epochs)
         self.losses = np.zeros(self.epochs)
         self.obs_losses = np.zeros(self.epochs)
         self.obs_losses = np.zeros(self.epochs)
@@ -282,7 +286,7 @@ class DINN:
             for i, group in enumerate(self.data.group_names):
             for i, group in enumerate(self.data.group_names):
                 loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
                 loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
             loss_obs *= self.lambda_obs
             loss_obs *= self.lambda_obs
-            
+
             if do_split_training:
             if do_split_training:
                 if epoch < start_split:
                 if epoch < start_split:
                     loss = loss_obs
                     loss = loss_obs
@@ -305,14 +309,14 @@ class DINN:
             # do snapshot for prediction animation
             # do snapshot for prediction animation
             if epoch % animation_sample_rate == 0 and create_animation:
             if epoch % animation_sample_rate == 0 and create_animation:
                 # prediction
                 # prediction
-                prediction = self.model(self.data.t_batch)
+                """prediction = self.model(self.data.t_batch)
                 t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
                 t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
                 groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
                 groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
 
 
                 plot_labels = [name + '_pred' for name in self.data.group_names] + [name + '_true' for name in self.data.group_names]
                 plot_labels = [name + '_pred' for name in self.data.group_names] + [name + '_true' for name in self.data.group_names]
                 background_list = [0 for _ in self.data.group_names] + [1 for _ in self.data.group_names]
                 background_list = [0 for _ in self.data.group_names] + [1 for _ in self.data.group_names]
-                self.plotter.plot(t, 
-                                  list(groups) + list(self.data.data), 
+                self.plotter.plot(t,
+                                  list(groups) + list(self.data.data),
                                   plot_labels,
                                   plot_labels,
                                   'frame',
                                   'frame',
                                   f'epoch {epoch}',
                                   f'epoch {epoch}',
@@ -321,12 +325,30 @@ class DINN:
                                   is_background=background_list,
                                   is_background=background_list,
                                   lw=3,
                                   lw=3,
                                   legend_loc='upper right',
                                   legend_loc='upper right',
-                                  ylim=(0, self.data.N), 
+                                  ylim=(0, self.data.N),
+                                  xlabel='time / days',
+                                  ylabel='amount of people')"""
+                prediction = self.model(self.data.t_batch)
+                t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+                groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
+
+                plot_labels = ['I_pred', 'I_true']
+                background_list = [0, 1]
+                self.plotter.plot(t,
+                                  [groups[1]] + [self.data.data[1]],
+                                  plot_labels,
+                                  'Frame',
+                                  f'epoch {epoch}',
+                                  figure_shape=(12, 6),
+                                  is_frame=True,
+                                  is_background=background_list,
+                                  lw=3,
+                                  legend_loc='upper right',
                                   xlabel='time / days',
                                   xlabel='time / days',
                                   ylabel='amount of people')
                                   ylabel='amount of people')
 
 
             # print training advancements
             # print training advancements
-            if epoch % 1000 == 0 and verbose:          
+            if epoch % 1000 == 0 and verbose:
                 print(f'\nEpoch {epoch} | LR {self.scheduler.get_last_lr()[0]}')
                 print(f'\nEpoch {epoch} | LR {self.scheduler.get_last_lr()[0]}')
                 print(f'physics loss:\t\t{loss_physics.item()}')
                 print(f'physics loss:\t\t{loss_physics.item()}')
                 print(f'observation loss:\t{loss_obs.item()}')
                 print(f'observation loss:\t{loss_obs.item()}')
@@ -342,6 +364,25 @@ class DINN:
             self.plotter.animate(self.data.name + '_animation')
             self.plotter.animate(self.data.name + '_animation')
             self.plotter.reset_animation()
             self.plotter.reset_animation()
 
 
+        if plot_I_prediction:
+            prediction = self.model(self.data.t_batch)
+            t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+            groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
+
+            plot_labels = ['I_pred', 'I_true']
+            background_list = [0, 1]
+            self.plotter.plot(t,
+                              [groups[1]] + [self.data.data[1]],
+                              plot_labels,
+                              'Training_I_prediction',
+                              f'Prediction of I on JH data',
+                              figure_shape=(12, 6),
+                              is_background=background_list,
+                              lw=3,
+                              legend_loc='upper right',
+                              xlabel='time / days',
+                              ylabel='amount of people')
+
         self.__has_trained = True
         self.__has_trained = True
 
 
     def plot_training_graphs(self, ground_truth=[]):
     def plot_training_graphs(self, ground_truth=[]):
@@ -355,32 +396,32 @@ class DINN:
 
 
         # plot loss
         # plot loss
         self.plotter.plot(epochs, [self.losses, self.obs_losses, self.physics_losses], ['loss', 'observation loss', 'physics loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=True, xlabel='epochs')
         self.plotter.plot(epochs, [self.losses, self.obs_losses, self.physics_losses], ['loss', 'observation loss', 'physics loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=True, xlabel='epochs')
-        
+
         # plot parameters
         # plot parameters
         for i, parameter in enumerate(self.parameters):
         for i, parameter in enumerate(self.parameters):
             if len(ground_truth) > i:
             if len(ground_truth) > i:
-                self.plotter.plot(epochs, 
-                                  [parameter, 
-                                   np.ones_like(epochs) * ground_truth[i]], 
-                                   ['prediction', 'ground truth'], 
-                                   self.data.name + '_' + list(self.parameters_tilda.items())[i][0], 
-                                   list(self.parameters_tilda.items())[i][0], 
-                                   (6,6), 
-                                   is_background=[0, 1], 
-                                   xlabel='epochs')
+                self.plotter.plot(epochs,
+                                  [parameter,
+                                   np.ones_like(epochs) * ground_truth[i]],
+                                  ['prediction', 'ground truth'],
+                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
+                                  list(self.parameters_tilda.items())[i][0],
+                                  (6, 6),
+                                  is_background=[0, 1],
+                                  xlabel='epochs')
             else:
             else:
-                self.plotter.plot(epochs, 
-                                  [parameter], 
-                                  ['prediction'], 
-                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0], 
-                                  list(self.parameters_tilda.items())[i][0], (6,6), 
-                                  xlabel='epochs', 
+                self.plotter.plot(epochs,
+                                  [parameter],
+                                  ['prediction'],
+                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
+                                  list(self.parameters_tilda.items())[i][0], (6, 6),
+                                  xlabel='epochs',
                                   plot_legend=False)
                                   plot_legend=False)
-                
-    def save_training_process(self, title, save_predictions = True):
-        losses = {'loss' : self.losses,
-                  'obs_loss' : self.obs_losses,
-                  'physics_loss' : self.physics_losses}
+
+    def save_training_process(self, title, save_predictions=True):
+        losses = {'loss': self.losses,
+                  'obs_loss': self.obs_losses,
+                  'physics_loss': self.physics_losses}
         for loss in losses.keys():
         for loss in losses.keys():
             with open(f'./results/training_metrics/{title}_{loss}.csv', 'w', newline='') as csvfile:
             with open(f'./results/training_metrics/{title}_{loss}.csv', 'w', newline='') as csvfile:
                 writer = csv.writer(csvfile, delimiter=',')
                 writer = csv.writer(csvfile, delimiter=',')
@@ -405,13 +446,13 @@ class DINN:
 
 
     def plot_state_variables(self):
     def plot_state_variables(self):
         prediction = self.model(self.data.t_batch)
         prediction = self.model(self.data.t_batch)
-        for i in range(self.data.number_groups, self.data.number_groups+self.number_state_variables):
+        for i in range(self.data.number_groups, self.data.number_groups + self.number_state_variables):
             t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
             t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
             self.plotter.plot(t,
             self.plotter.plot(t,
                               [prediction[:, i]],
                               [prediction[:, i]],
-                              [self.__state_variables[i-self.data.number_groups]],
+                              [self.__state_variables[i - self.data.number_groups]],
                               f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
                               f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
-                              self.__state_variables[i-self.data.number_groups],
+                              self.__state_variables[i - self.data.number_groups],
                               figure_shape=(12, 6),
                               figure_shape=(12, 6),
                               plot_legend=True,
                               plot_legend=True,
-                              xlabel='time / days')
+                              xlabel='time / days')