Browse Source

add animation function

phillip.rothenbeck 1 year ago
parent
commit
c5ce9366cb
1 changed file with 51 additions and 58 deletions

+ 51 - 58
src/dinn.py

@@ -1,15 +1,15 @@
 import torch
+import os
+import imageio
 import numpy as np
 import matplotlib.pyplot as plt
 from matplotlib import rcParams
 
 from .dataset import PandemicDataset
 from .problem import PandemicProblem
+from .plotter import Plotter
+
 
-FONT_COLOR = '#595959'
-SUSCEPTIBLE = '#6399f7'
-INFECTIOUS = '#f56262'
-REMOVED = '#83eb5e'
 
 class DINN:
     class NN(torch.nn.Module):
@@ -45,6 +45,7 @@ class DINN:
                  data: PandemicDataset,
                  parameter_list: list,
                  problem: PandemicProblem,
+                 plotter: Plotter,
                  parameter_regulator=torch.tanh,
                  input_size=1,
                  hidden_size=20,
@@ -65,22 +66,28 @@ class DINN:
             activation_layer (optional): Class of the activation function. Defaults to torch.nn.ReLU().
         """
         
-        self.device = torch.device('cpu')
+        self.device = torch.device(data.device_name)
+        self.device_name = data.device_name
+        self.plotter = plotter
 
         self.model = DINN.NN(number_groups, input_size, hidden_size, hidden_layers, activation_layer)
+        self.model = self.model.to(self.device)
         self.data = data
         self.parameter_regulator = parameter_regulator
         self.problem = problem
 
         self.parameters_tilda = {}
         for parameter in parameter_list:
-            self.parameters_tilda.update({parameter : torch.nn.Parameter(torch.rand(1, requires_grad=True))})
+            self.parameters_tilda.update({parameter : torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
         
         self.epochs = None
 
         self.losses = np.zeros(1)
         self.parameters = [np.zeros(1) for _ in range(len(parameter_list))]
 
+        self.frames = []
+
+
     def get_regulated_param(self, parameter_name: str):
         """Function to get the searched parameters, forced into a certain range.
 
@@ -111,13 +118,17 @@ class DINN:
     def train(self, 
               epochs: int, 
               lr: float, 
-              optimizer_class=torch.optim.Adam):
+              optimizer_class=torch.optim.Adam,
+              create_animation=False,
+              animation_sample_rate=500):
         """Training routine for the DINN
 
         Args:
             epochs (int): Number of epochs the NN is supposed to be trained for.
             lr (float): Learning rate for the optimizer.
             optimizer_class (optional): Class of the optimizer. Defaults to torch.optim.Adam.
+            create_animation (bool, optional): Whether to create a prediction animation during training. Defaults to False.
+            animation_sample_rate (int, optional): Sample rate of the prediction animation. Only used when create_animation=True. Defaults to 500.
         """
         """
 
 
         # define optimizer and scheduler
         # define optimizer and scheduler
@@ -133,7 +144,7 @@ class DINN:
         for epoch in range(epochs):
             # get the prediction and the fitting residuals
             prediction = self.model(self.data.t_batch)
-            residuals = self.problem.residual(prediction, self.data, *self.get_regulated_param_list())
+            residuals = self.problem.residual(prediction, *self.get_regulated_param_list())
 
             optimizer.zero_grad()
 
@@ -158,6 +169,27 @@ class DINN:
             for i, parameter in enumerate(self.parameters_tilda.items()):
                 self.parameters[i][epoch] = self.get_regulated_param(parameter[0]).item()
 
+            # take a snapshot for the prediction animation
+            if epoch % animation_sample_rate == 0 and create_animation:
+                # prediction
+                prediction = self.model(self.data.t_batch)
+                t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+                
+                groups = self.problem.denormalization(prediction)
+                self.plotter.plot(t, 
+                                  groups + tuple(self.data.get_data()), 
+                                  [name + '_pred' for name in self.data.get_group_names()] + [name + '_true' for name in self.data.get_group_names()],
+                                  'frame',
+                                  f'epoch {epoch}',
+                                  figure_shape=(12, 6),
+                                  is_frame=True,
+                                  is_background=[0, 0, 0, 1, 1, 1],
+                                  lw=3,
+                                  legend_loc='upper right',
+                                  ylim=(0, self.data.N), 
+                                  xlabel='time / days',
+                                  ylabel='amount of people')
+
             # print training advancements
             if epoch % 1000 == 0:          
                 print('\nEpoch ', epoch)
@@ -167,7 +199,11 @@ class DINN:
                 print('---------------------------------')
                 for parameter in self.parameters_tilda.items():
                     print(f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
-                print('#################################') 
+                print('#################################')
+
+        # create prediction animation
+        if create_animation:
+            self.plotter.animate(self.data.name + '_animation')
 
     def plot_training_graphs(self, ground_truth=[]):
         """Plot the loss graph and the graphs of the advancements of the parameters.
@@ -178,57 +214,14 @@ class DINN:
         assert self.epochs != None
         epochs = np.arange(0, self.epochs, 1)
 
-        rcParams['font.family'] = 'Comfortaa'
-        rcParams['font.size'] = 12
-
-        rcParams['text.color'] = FONT_COLOR
-        rcParams['axes.labelcolor'] = FONT_COLOR
-        rcParams['xtick.color'] = FONT_COLOR
-        rcParams['ytick.color'] = FONT_COLOR
-
         # plot loss
-        plt.plot(epochs, self.losses)
-        plt.title('Loss')
-        plt.yscale('log')
-        plt.show()
-
+        self.plotter.plot(epochs, [self.losses], ['loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=False, xlabel='epochs')
+        
         # plot parameters
         for i, parameter in enumerate(self.parameters):
-            figure = plt.figure(figsize=(6,6))
-            ax = figure.add_subplot(111, facecolor='#dddddd', axisbelow=True)
-            ax.set_facecolor('xkcd:white')
-
-            ax.plot(epochs, parameter, c=FONT_COLOR, lw=3, label='prediction')
             if len(ground_truth) > i:
-                ax.axhline(y=ground_truth[i], color=INFECTIOUS, linestyle='-', lw=3, label='ground truth')
-        
-            plt.xlabel('epochs')
-            plt.title(list(self.parameters_tilda.items())[i][0])
-            ax.yaxis.set_tick_params(length=0)
-
-            ax.yaxis.set_tick_params(length=0, which='both')
-            ax.xaxis.set_tick_params(length=0, which='both')
-            ax.grid(which='major', c='black', lw=0.2, ls='-')
-
-            plt.legend()
-
-            for spine in ('top', 'right', 'bottom', 'left'):
-                ax.spines[spine].set_visible(False)
-
-            figure.savefig(f'visualizations/{list(self.parameters_tilda.items())[i][0]}.png', transparent=True)
-            plt.show()
-
-    def to_cuda(self):
-        """Move training to cuda device
-        """
-        assert torch.cuda.is_available(), "CUDA is not available"
-        
-        self.device = torch.device('cuda')
-        self.model = self.model.to(self.device)
-        self.data.to_device(self.device)
-        self.problem.to_device(self.device)
-
-        for parameter in self.parameters_tilda:
-            self.parameters_tilda[parameter] = self.parameters_tilda[parameter].to(self.device).detach().requires_grad_()
-        
+                self.plotter.plot(epochs, [parameter, np.ones_like(epochs) * ground_truth[i]], ['prediction', 'ground truth'], self.data.name + '_' + list(self.parameters_tilda.items())[i][0], list(self.parameters_tilda.items())[i][0], (6,6), is_background=[0, 1], xlabel='epochs')
+            else:
+                self.plotter.plot(epochs, [parameter], ['prediction'], self.data.name + '_' + list(self.parameters_tilda.items())[i][0], list(self.parameters_tilda.items())[i][0], (6,6), xlabel='epochs', plot_legend=False)
 
+
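Usage note (not part of this commit): the animation path added above is driven entirely through the injected Plotter. Every animation_sample_rate epochs, train() renders one frame via Plotter.plot(..., is_frame=True), and after the last epoch Plotter.animate(data.name + '_animation') assembles the frames. Below is a minimal sketch of how a training script might exercise the new flags; the Plotter() constructor call, the parameter names, and the prepared data/problem objects are assumptions, since none of them appear in this diff.

    from src.plotter import Plotter
    from src.dinn import DINN

    # `data` is a prepared PandemicDataset and `problem` a matching
    # PandemicProblem; their construction is not part of this diff.
    plotter = Plotter()                  # assumed default constructor
    dinn = DINN(number_groups=3,         # e.g. S, I, R
                data=data,
                parameter_list=['alpha', 'beta'],  # hypothetical names
                problem=problem,
                plotter=plotter)         # new: a Plotter is now required

    # Snapshot a frame every 500 epochs, then write the animation at the end.
    dinn.train(epochs=10000,
               lr=1e-3,
               create_animation=True,
               animation_sample_rate=500)

The new os and imageio imports (and the still-unused self.frames list) suggest the frames are written to disk and stitched into a GIF, for which imageio.mimsave(path, frames) is the usual call; presumably that happens inside Plotter.animate, which is not shown here.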