phillip.rothenbeck 4 ay önce
ebeveyn
işleme
b5e9e225f3
1 değiştirilmiş dosya ile 58 ekleme ve 68 silme
  1. 58 68
      src/dinn.py

+ 58 - 68
src/dinn.py

@@ -65,8 +65,10 @@ class DINN:
                 print('Set output activation to default: linear')
                 self.out_activation = self.linear
 
-            self.input = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size), activation_layer)
-            self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
+            self.input = torch.nn.Sequential(torch.nn.Linear(
+                input_size, hidden_size), activation_layer)
+            self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(
+                hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
             self.output = torch.nn.Linear(hidden_size, output_size)
 
             if use_glorot_initialization:
@@ -82,7 +84,8 @@ class DINN:
         def forward(self, t):
             # normalize input
             if self.__use_t_scaled:
-                t_forward = (t - self.__t_init) / (self.__t_final - self.__t_init)
+                t_forward = (t - self.__t_init) / \
+                    (self.__t_final - self.__t_init)
             else:
                 t_forward = t
             x = self.input(t_forward)
@@ -120,7 +123,8 @@ class DINN:
             hidden_layers (int, optional): Number of the hidden layers for the NN. Defaults to 7.
             activation_layer (optional): Class of the activation function. Defaults to torch.nn.ReLU().
         """
-        assert len(state_variables) + data.number_groups == output_size, f'The number of groups plus the number of state variable must result in the output size\nGroups:\t{data.number_groups}\nState variables:\t{len(state_variables)}\noutput_size: {output_size}\n'
+        assert len(state_variables) + \
+            data.number_groups == output_size, f'The number of groups plus the number of state variable must result in the output size\nGroups:\t{data.number_groups}\nState variables:\t{len(state_variables)}\noutput_size: {output_size}\n'
         self.device = torch.device(data.device_name)
         self.device_name = data.device_name
         self.plotter = plotter
@@ -143,7 +147,8 @@ class DINN:
 
         self.parameters_tilda = {}
         for parameter in parameter_list:
-            self.parameters_tilda.update({parameter: torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
+            self.parameters_tilda.update({parameter: torch.nn.Parameter(
+                torch.rand(1, requires_grad=True, device=self.device_name))})
 
         # new model has to be configured and then trained
         self.__is_configured = False
@@ -208,7 +213,8 @@ class DINN:
             scheduler_name (str, optional): Name of the scheduler class that is supposed to be used. Defaults to 'CyclicLR'.
             verbose (bool, optional): Controles if the configuration process, is to be verbosed. Defaults to False.
         """
-        parameter_list = list(self.model.parameters()) + list(self.parameters_tilda.values())
+        parameter_list = list(self.model.parameters()) + \
+            list(self.parameters_tilda.values())
         self.epochs = epochs
         self.lambda_obs = lambda_obs
         self.lambda_physics = lambda_physics
@@ -219,34 +225,41 @@ class DINN:
                 self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
                 if verbose:
                     print('---------------------------------')
-                    print(f' Entered unknown optimizer name: {optimizer_class.name}\n Defaulted to ADAM.')
+                    print(
+                        f' Entered unknown optimizer name: {optimizer_class.name}\n Defaulted to ADAM.')
                     print('---------------------------------')
                 optimizer_class = Optimizer.ADAM
 
         match scheduler_class:
             case Scheduler.CYCLIC:
-                self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
+                self.scheduler = torch.optim.lr_scheduler.CyclicLR(
+                    self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
             case Scheduler.CONSTANT:
-                self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1, total_iters=4)
+                self.scheduler = torch.optim.lr_scheduler.ConstantLR(
+                    self.optimizer, factor=1, total_iters=4)
             case Scheduler.LINEAR:
-                self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs / scheduler_factor)
+                self.scheduler = torch.optim.lr_scheduler.LinearLR(
+                    self.optimizer, start_factor=lr, total_iters=epochs / scheduler_factor)
             case Scheduler.POLYNOMIAL:
-                self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs / scheduler_factor, power=1.0)
+                self.scheduler = torch.optim.lr_scheduler.PolynomialLR(
+                    self.optimizer, total_iters=epochs / scheduler_factor, power=1.0)
             case _:
-                self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
+                self.scheduler = torch.optim.lr_scheduler.CyclicLR(
+                    self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
                 if verbose:
                     print('---------------------------------')
-                    print(f' Entered unknown scheduler name: {scheduler_class.name}\n Defaulted to CYCLIC.')
+                    print(
+                        f' Entered unknown scheduler name: {scheduler_class.name}\n Defaulted to CYCLIC.')
                     print('---------------------------------')
                 scheduler_class = Scheduler.CYCLIC
 
         if verbose:
-            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_class.name}\nScheduler:\t{scheduler_class.name}\n')
+            print(
+                f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_class.name}\nScheduler:\t{scheduler_class.name}\n')
 
         self.__is_configured = True
 
     def train(self,
-              plot_I_prediction=False,
               create_animation=False,
               animation_sample_rate=500,
               verbose=False,
@@ -272,7 +285,8 @@ class DINN:
         for epoch in range(self.epochs):
             # get the prediction and the fitting residuals
             prediction = self.model(self.data.t_batch)
-            residuals = self.problem.residual(prediction, *self.get_regulated_param_list())
+            residuals = self.problem.residual(
+                prediction, *self.get_regulated_param_list())
             self.optimizer.zero_grad()
 
             # calculate loss from the differential system
@@ -284,7 +298,8 @@ class DINN:
             # calculate loss from the dataset
             loss_obs = 0
             for i, group in enumerate(self.data.group_names):
-                loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
+                loss_obs += torch.mean(torch.square(
+                    self.data.get_norm(group) - prediction[:, i]))
             loss_obs *= self.lambda_obs
 
             if do_split_training:
@@ -304,33 +319,17 @@ class DINN:
             self.obs_losses[epoch] = loss_obs.item()
             self.physics_losses[epoch] = loss_physics.item()
             for i, parameter in enumerate(self.parameters_tilda.items()):
-                self.parameters[i][epoch] = self.get_regulated_param(parameter[0]).item()
+                self.parameters[i][epoch] = self.get_regulated_param(
+                    parameter[0]).item()
 
             # do snapshot for prediction animation
             if epoch % animation_sample_rate == 0 and create_animation:
                 # prediction
-                """prediction = self.model(self.data.t_batch)
-                t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
-                groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
-
-                plot_labels = [name + '_pred' for name in self.data.group_names] + [name + '_true' for name in self.data.group_names]
-                background_list = [0 for _ in self.data.group_names] + [1 for _ in self.data.group_names]
-                self.plotter.plot(t,
-                                  list(groups) + list(self.data.data),
-                                  plot_labels,
-                                  'frame',
-                                  f'epoch {epoch}',
-                                  figure_shape=(12, 6),
-                                  is_frame=True,
-                                  is_background=background_list,
-                                  lw=3,
-                                  legend_loc='upper right',
-                                  ylim=(0, self.data.N),
-                                  xlabel='time / days',
-                                  ylabel='amount of people')"""
                 prediction = self.model(self.data.t_batch)
-                t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
-                groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
+                t = torch.arange(
+                    0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+                groups = self.data.get_denormalized_data(
+                    [prediction[:, i] for i in range(self.data.number_groups)])
 
                 plot_labels = ['I_pred', 'I_true']
                 background_list = [0, 1]
@@ -349,14 +348,16 @@ class DINN:
 
             # print training advancements
             if epoch % 1000 == 0 and verbose:
-                print(f'\nEpoch {epoch} | LR {self.scheduler.get_last_lr()[0]}')
+                print(
+                    f'\nEpoch {epoch} | LR {self.scheduler.get_last_lr()[0]}')
                 print(f'physics loss:\t\t{loss_physics.item()}')
                 print(f'observation loss:\t{loss_obs.item()}')
                 print(f'loss:\t\t\t{loss.item()}')
                 print('---------------------------------')
                 if len(self.parameters_tilda.items()) != 0:
                     for parameter in self.parameters_tilda.items():
-                        print(f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
+                        print(
+                            f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
                     print('#################################')
 
         # create prediction animation
@@ -364,25 +365,6 @@ class DINN:
             self.plotter.animate(self.data.name + '_animation')
             self.plotter.reset_animation()
 
-        if plot_I_prediction:
-            prediction = self.model(self.data.t_batch)
-            t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
-            groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
-
-            plot_labels = ['I_pred', 'I_true']
-            background_list = [0, 1]
-            self.plotter.plot(t,
-                              [groups[1]] + [self.data.data[1]],
-                              plot_labels,
-                              'Training_I_prediction',
-                              f'Prediction of I on JH data',
-                              figure_shape=(12, 6),
-                              is_background=background_list,
-                              lw=3,
-                              legend_loc='upper right',
-                              xlabel='time / days',
-                              ylabel='amount of people')
-
         self.__has_trained = True
 
     def plot_training_graphs(self, ground_truth=[]):
@@ -395,7 +377,8 @@ class DINN:
         epochs = np.arange(0, self.epochs, 1)
 
         # plot loss
-        self.plotter.plot(epochs, [self.losses, self.obs_losses, self.physics_losses], ['loss', 'observation loss', 'physics loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=True, xlabel='epochs')
+        self.plotter.plot(epochs, [self.losses, self.obs_losses, self.physics_losses], ['loss', 'observation loss',
+                          'physics loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=True, xlabel='epochs')
 
         # plot parameters
         for i, parameter in enumerate(self.parameters):
@@ -404,7 +387,8 @@ class DINN:
                                   [parameter,
                                    np.ones_like(epochs) * ground_truth[i]],
                                   ['prediction', 'ground truth'],
-                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
+                                  self.data.name + '_' +
+                                  list(self.parameters_tilda.items())[i][0],
                                   list(self.parameters_tilda.items())[i][0],
                                   (6, 6),
                                   is_background=[0, 1],
@@ -413,8 +397,10 @@ class DINN:
                 self.plotter.plot(epochs,
                                   [parameter],
                                   ['prediction'],
-                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
-                                  list(self.parameters_tilda.items())[i][0], (6, 6),
+                                  self.data.name + '_' +
+                                  list(self.parameters_tilda.items())[i][0],
+                                  list(self.parameters_tilda.items())[
+                                      i][0], (6, 6),
                                   xlabel='epochs',
                                   plot_legend=False)
 
@@ -434,9 +420,11 @@ class DINN:
         if save_predictions:
             prediction = self.model(self.data.t_batch)
             for i, group in enumerate(self.data.group_names):
-                t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0]).detach().cpu().numpy()
+                t = torch.linspace(
+                    0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0]).detach().cpu().numpy()
                 true = self.data.get_group(group).detach().cpu().numpy()
-                pred = self.data.get_denormalized_data([prediction[:, i]])[0].detach().cpu().numpy()
+                pred = self.data.get_denormalized_data([prediction[:, i]])[
+                    0].detach().cpu().numpy()
                 print(t.shape, true.shape)
                 with open(f'./results/I_predictions/{title}_I_prediction.csv', 'w', newline='') as csvfile:
                     writer = csv.writer(csvfile, delimiter=',')
@@ -447,12 +435,14 @@ class DINN:
     def plot_state_variables(self):
         prediction = self.model(self.data.t_batch)
         for i in range(self.data.number_groups, self.data.number_groups + self.number_state_variables):
-            t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
+            t = torch.linspace(
+                0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
             self.plotter.plot(t,
                               [prediction[:, i]],
                               [self.__state_variables[i - self.data.number_groups]],
                               f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
-                              self.__state_variables[i - self.data.number_groups],
+                              self.__state_variables[i -
+                                                     self.data.number_groups],
                               figure_shape=(12, 6),
                               plot_legend=True,
                               xlabel='time / days')