Forráskód Böngészése

add animation function

phillip.rothenbeck 1 éve
szülő
commit
c5ce9366cb
1 módosított fájl, 51 hozzáadás és 58 törlés
  1. 51 58
      src/dinn.py

+ 51 - 58
src/dinn.py

@@ -1,15 +1,15 @@
 import torch
+import os
+import imageio
 import numpy as np
 import matplotlib.pyplot as plt
 from matplotlib import rcParams
 
 from .dataset import PandemicDataset
 from .problem import PandemicProblem
+from .plotter import Plotter
+
 
-FONT_COLOR = '#595959'
-SUSCEPTIBLE = '#6399f7'
-INFECTIOUS = '#f56262'
-REMOVED = '#83eb5e'
 
 class DINN:
     class NN(torch.nn.Module):
@@ -45,6 +45,7 @@ class DINN:
                  data: PandemicDataset,
                  parameter_list: list,
                  problem: PandemicProblem,
+                 plotter: Plotter,
                  parameter_regulator=torch.tanh,
                  input_size=1,
                  hidden_size=20,
@@ -65,22 +66,28 @@ class DINN:
             activation_layer (optional): Class of the activation function. Defaults to torch.nn.ReLU().
         """
         
-        self.device = torch.device('cpu')
+        self.device = torch.device(data.device_name)
+        self.device_name = data.device_name
+        self.plotter = plotter
 
         self.model = DINN.NN(number_groups, input_size, hidden_size, hidden_layers, activation_layer)
+        self.model = self.model.to(self.device)
         self.data = data
         self.parameter_regulator = parameter_regulator
         self.problem = problem
 
         self.parameters_tilda = {}
         for parameter in parameter_list:
-            self.parameters_tilda.update({parameter : torch.nn.Parameter(torch.rand(1, requires_grad=True))})
+            self.parameters_tilda.update({parameter : torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
         
         self.epochs = None
 
         self.losses = np.zeros(1)
         self.parameters = [np.zeros(1) for _ in range(len(parameter_list))]
 
+        self.frames = []
+
+
     def get_regulated_param(self, parameter_name: str):
         """Function to get the searched parameters, forced into a certain range.
 
@@ -111,13 +118,17 @@ class DINN:
     def train(self, 
               epochs: int, 
               lr: float, 
-              optimizer_class=torch.optim.Adam):
+              optimizer_class=torch.optim.Adam,
+              create_animation=False,
+              animation_sample_rate=500):
         """Training routine for the DINN
 
         Args:
             epochs (int): Number of epochs the NN is supposed to be trained for.
             lr (float): Learning rate for the optimizer.
             optimizer_class (optional): Class of the optimizer. Defaults to torch.optim.Adam.
+            create_animation (boolean, optional): Decides on wether a prediction animation is supposed to be created during training. Defaults to False.
+            animation_sample_rate (int, optional): Sample rate of the prediction animation. Only used, when create_animation=True. Defaults to 500.
         """
 
         # define optimizer and scheduler
@@ -133,7 +144,7 @@ class DINN:
         for epoch in range(epochs):
             # get the prediction and the fitting residuals
             prediction = self.model(self.data.t_batch)
-            residuals = self.problem.residual(prediction, self.data, *self.get_regulated_param_list())
+            residuals = self.problem.residual(prediction, *self.get_regulated_param_list())
 
             optimizer.zero_grad()
 
@@ -158,6 +169,27 @@ class DINN:
             for i, parameter in enumerate(self.parameters_tilda.items()):
                 self.parameters[i][epoch] = self.get_regulated_param(parameter[0]).item()
 
+            # do snapshot for prediction animation
+            if epoch % animation_sample_rate == 0 and create_animation:
+                # prediction
+                prediction = self.model(self.data.t_batch)
+                t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+                
+                groups = self.problem.denormalization(prediction)
+                self.plotter.plot(t, 
+                                  groups + tuple(self.data.get_data()), 
+                                  [name + '_pred' for name in self.data.get_group_names()] + [name + '_true' for name in self.data.get_group_names()],
+                                  'frame',
+                                  f'epoch {epoch}',
+                                  figure_shape=(12, 6),
+                                  is_frame=True,
+                                  is_background=[0, 0, 0, 1, 1, 1],
+                                  lw=3,
+                                  legend_loc='upper right',
+                                  ylim=(0, self.data.N), 
+                                  xlabel='time / days',
+                                  ylabel='amount of people')
+
             # print training advancements
             if epoch % 1000 == 0:          
                 print('\nEpoch ', epoch)
@@ -167,7 +199,11 @@ class DINN:
                 print('---------------------------------')
                 for parameter in self.parameters_tilda.items():
                     print(f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
-                print('#################################') 
+                print('#################################')
+
+        # create prediction animation
+        if create_animation:
+            self.plotter.animate(self.data.name + '_animation')
 
     def plot_training_graphs(self, ground_truth=[]):
         """Plot the loss graph and the graphs of the advancements of the parameters.
@@ -178,57 +214,14 @@ class DINN:
         assert self.epochs != None
         epochs = np.arange(0, self.epochs, 1)
 
-        rcParams['font.family'] = 'Comfortaa'
-        rcParams['font.size'] = 12
-
-        rcParams['text.color'] = FONT_COLOR
-        rcParams['axes.labelcolor'] = FONT_COLOR
-        rcParams['xtick.color'] = FONT_COLOR
-        rcParams['ytick.color'] = FONT_COLOR
-
         # plot loss
-        plt.plot(epochs, self.losses)
-        plt.title('Loss')
-        plt.yscale('log')
-        plt.show()
-
+        self.plotter.plot(epochs, [self.losses], ['loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=False, xlabel='epochs')
+        
         # plot parameters
         for i, parameter in enumerate(self.parameters):
-            figure = plt.figure(figsize=(6,6))
-            ax = figure.add_subplot(111, facecolor='#dddddd', axisbelow=True)
-            ax.set_facecolor('xkcd:white')
-
-            ax.plot(epochs, parameter, c=FONT_COLOR, lw=3, label='prediction')
             if len(ground_truth) > i:
-                ax.axhline(y=ground_truth[i], color=INFECTIOUS, linestyle='-', lw=3, label='ground truth')
-        
-            plt.xlabel('epochs')
-            plt.title(list(self.parameters_tilda.items())[i][0])
-            ax.yaxis.set_tick_params(length=0)
-
-            ax.yaxis.set_tick_params(length=0, which='both')
-            ax.xaxis.set_tick_params(length=0, which='both')
-            ax.grid(which='major', c='black', lw=0.2, ls='-')
-
-            plt.legend()
-
-            for spine in ('top', 'right', 'bottom', 'left'):
-                ax.spines[spine].set_visible(False)
-
-            figure.savefig(f'visualizations/{list(self.parameters_tilda.items())[i][0]}.png', transparent=True)
-            plt.show()
-
-    def to_cuda(self):
-        """Move training to cuda device
-        """
-        assert torch.cuda.is_available(), "CUDA is not available"
-        
-        self.device = torch.device('cuda')
-        self.model = self.model.to(self.device)
-        self.data.to_device(self.device)
-        self.problem.to_device(self.device)
-
-        for parameter in self.parameters_tilda:
-            self.parameters_tilda[parameter] = self.parameters_tilda[parameter].to(self.device).detach().requires_grad_()
-        
+                self.plotter.plot(epochs, [parameter, np.ones_like(epochs) * ground_truth[i]], ['prediction', 'ground truth'], self.data.name + '_' + list(self.parameters_tilda.items())[i][0], list(self.parameters_tilda.items())[i][0], (6,6), is_background=[0, 1], xlabel='epochs')
+            else:
+                self.plotter.plot(epochs, [parameter], ['prediction'], self.data.name + '_' + list(self.parameters_tilda.items())[i][0], list(self.parameters_tilda.items())[i][0], (6,6), xlabel='epochs', plot_legend=False)
 
+