|
|
@@ -1,15 +1,15 @@
|
|
|
import torch
|
|
|
+import os
|
|
|
+import imageio
|
|
|
import numpy as np
|
|
|
import matplotlib.pyplot as plt
|
|
|
from matplotlib import rcParams
|
|
|
|
|
|
from .dataset import PandemicDataset
|
|
|
from .problem import PandemicProblem
|
|
|
+from .plotter import Plotter
|
|
|
+
|
|
|
|
|
|
-FONT_COLOR = '#595959'
|
|
|
-SUSCEPTIBLE = '#6399f7'
|
|
|
-INFECTIOUS = '#f56262'
|
|
|
-REMOVED = '#83eb5e'
|
|
|
|
|
|
class DINN:
|
|
|
class NN(torch.nn.Module):
|
|
|
@@ -45,6 +45,7 @@ class DINN:
|
|
|
data: PandemicDataset,
|
|
|
parameter_list: list,
|
|
|
problem: PandemicProblem,
|
|
|
+ plotter: Plotter,
|
|
|
parameter_regulator=torch.tanh,
|
|
|
input_size=1,
|
|
|
hidden_size=20,
|
|
|
@@ -65,22 +66,28 @@ class DINN:
|
|
|
activation_layer (optional): Class of the activation function. Defaults to torch.nn.ReLU().
|
|
|
"""
|
|
|
|
|
|
- self.device = torch.device('cpu')
|
|
|
+ self.device = torch.device(data.device_name)
|
|
|
+ self.device_name = data.device_name
|
|
|
+ self.plotter = plotter
|
|
|
|
|
|
self.model = DINN.NN(number_groups, input_size, hidden_size, hidden_layers, activation_layer)
|
|
|
+ self.model = self.model.to(self.device)
|
|
|
self.data = data
|
|
|
self.parameter_regulator = parameter_regulator
|
|
|
self.problem = problem
|
|
|
|
|
|
self.parameters_tilda = {}
|
|
|
for parameter in parameter_list:
|
|
|
- self.parameters_tilda.update({parameter : torch.nn.Parameter(torch.rand(1, requires_grad=True))})
|
|
|
+ self.parameters_tilda.update({parameter : torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
|
|
|
|
|
|
self.epochs = None
|
|
|
|
|
|
self.losses = np.zeros(1)
|
|
|
self.parameters = [np.zeros(1) for _ in range(len(parameter_list))]
|
|
|
|
|
|
+ self.frames = []
|
|
|
+
|
|
|
+
|
|
|
def get_regulated_param(self, parameter_name: str):
|
|
|
"""Function to get the searched parameters, forced into a certain range.
|
|
|
|
|
|
@@ -111,13 +118,17 @@ class DINN:
|
|
|
def train(self,
|
|
|
epochs: int,
|
|
|
lr: float,
|
|
|
- optimizer_class=torch.optim.Adam):
|
|
|
+ optimizer_class=torch.optim.Adam,
|
|
|
+ create_animation=False,
|
|
|
+ animation_sample_rate=500):
|
|
|
"""Training routine for the DINN
|
|
|
|
|
|
Args:
|
|
|
epochs (int): Number of epochs the NN is supposed to be trained for.
|
|
|
lr (float): Learning rate for the optimizer.
|
|
|
optimizer_class (optional): Class of the optimizer. Defaults to torch.optim.Adam.
|
|
|
+ create_animation (boolean, optional): Decides on wether a prediction animation is supposed to be created during training. Defaults to False.
|
|
|
+ animation_sample_rate (int, optional): Sample rate of the prediction animation. Only used, when create_animation=True. Defaults to 500.
|
|
|
"""
|
|
|
|
|
|
# define optimizer and scheduler
|
|
|
@@ -133,7 +144,7 @@ class DINN:
|
|
|
for epoch in range(epochs):
|
|
|
# get the prediction and the fitting residuals
|
|
|
prediction = self.model(self.data.t_batch)
|
|
|
- residuals = self.problem.residual(prediction, self.data, *self.get_regulated_param_list())
|
|
|
+ residuals = self.problem.residual(prediction, *self.get_regulated_param_list())
|
|
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
|
@@ -158,6 +169,27 @@ class DINN:
|
|
|
for i, parameter in enumerate(self.parameters_tilda.items()):
|
|
|
self.parameters[i][epoch] = self.get_regulated_param(parameter[0]).item()
|
|
|
|
|
|
+ # do snapshot for prediction animation
|
|
|
+ if epoch % animation_sample_rate == 0 and create_animation:
|
|
|
+ # prediction
|
|
|
+ prediction = self.model(self.data.t_batch)
|
|
|
+ t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
|
|
|
+
|
|
|
+ groups = self.problem.denormalization(prediction)
|
|
|
+ self.plotter.plot(t,
|
|
|
+ groups + tuple(self.data.get_data()),
|
|
|
+ [name + '_pred' for name in self.data.get_group_names()] + [name + '_true' for name in self.data.get_group_names()],
|
|
|
+ 'frame',
|
|
|
+ f'epoch {epoch}',
|
|
|
+ figure_shape=(12, 6),
|
|
|
+ is_frame=True,
|
|
|
+ is_background=[0, 0, 0, 1, 1, 1],
|
|
|
+ lw=3,
|
|
|
+ legend_loc='upper right',
|
|
|
+ ylim=(0, self.data.N),
|
|
|
+ xlabel='time / days',
|
|
|
+ ylabel='amount of people')
|
|
|
+
|
|
|
# print training advancements
|
|
|
if epoch % 1000 == 0:
|
|
|
print('\nEpoch ', epoch)
|
|
|
@@ -167,7 +199,11 @@ class DINN:
|
|
|
print('---------------------------------')
|
|
|
for parameter in self.parameters_tilda.items():
|
|
|
print(f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
|
|
|
- print('#################################')
|
|
|
+ print('#################################')
|
|
|
+
|
|
|
+ # create prediction animation
|
|
|
+ if create_animation:
|
|
|
+ self.plotter.animate(self.data.name + '_animation')
|
|
|
|
|
|
def plot_training_graphs(self, ground_truth=[]):
|
|
|
"""Plot the loss graph and the graphs of the advancements of the parameters.
|
|
|
@@ -178,57 +214,14 @@ class DINN:
|
|
|
assert self.epochs != None
|
|
|
epochs = np.arange(0, self.epochs, 1)
|
|
|
|
|
|
- rcParams['font.family'] = 'Comfortaa'
|
|
|
- rcParams['font.size'] = 12
|
|
|
-
|
|
|
- rcParams['text.color'] = FONT_COLOR
|
|
|
- rcParams['axes.labelcolor'] = FONT_COLOR
|
|
|
- rcParams['xtick.color'] = FONT_COLOR
|
|
|
- rcParams['ytick.color'] = FONT_COLOR
|
|
|
-
|
|
|
# plot loss
|
|
|
- plt.plot(epochs, self.losses)
|
|
|
- plt.title('Loss')
|
|
|
- plt.yscale('log')
|
|
|
- plt.show()
|
|
|
-
|
|
|
+ self.plotter.plot(epochs, [self.losses], ['loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=False, xlabel='epochs')
|
|
|
+
|
|
|
# plot parameters
|
|
|
for i, parameter in enumerate(self.parameters):
|
|
|
- figure = plt.figure(figsize=(6,6))
|
|
|
- ax = figure.add_subplot(111, facecolor='#dddddd', axisbelow=True)
|
|
|
- ax.set_facecolor('xkcd:white')
|
|
|
-
|
|
|
- ax.plot(epochs, parameter, c=FONT_COLOR, lw=3, label='prediction')
|
|
|
if len(ground_truth) > i:
|
|
|
- ax.axhline(y=ground_truth[i], color=INFECTIOUS, linestyle='-', lw=3, label='ground truth')
|
|
|
-
|
|
|
- plt.xlabel('epochs')
|
|
|
- plt.title(list(self.parameters_tilda.items())[i][0])
|
|
|
- ax.yaxis.set_tick_params(length=0)
|
|
|
-
|
|
|
- ax.yaxis.set_tick_params(length=0, which='both')
|
|
|
- ax.xaxis.set_tick_params(length=0, which='both')
|
|
|
- ax.grid(which='major', c='black', lw=0.2, ls='-')
|
|
|
-
|
|
|
- plt.legend()
|
|
|
-
|
|
|
- for spine in ('top', 'right', 'bottom', 'left'):
|
|
|
- ax.spines[spine].set_visible(False)
|
|
|
-
|
|
|
- figure.savefig(f'visualizations/{list(self.parameters_tilda.items())[i][0]}.png', transparent=True)
|
|
|
- plt.show()
|
|
|
-
|
|
|
- def to_cuda(self):
|
|
|
- """Move training to cuda device
|
|
|
- """
|
|
|
- assert torch.cuda.is_available(), "CUDA is not available"
|
|
|
-
|
|
|
- self.device = torch.device('cuda')
|
|
|
- self.model = self.model.to(self.device)
|
|
|
- self.data.to_device(self.device)
|
|
|
- self.problem.to_device(self.device)
|
|
|
-
|
|
|
- for parameter in self.parameters_tilda:
|
|
|
- self.parameters_tilda[parameter] = self.parameters_tilda[parameter].to(self.device).detach().requires_grad_()
|
|
|
-
|
|
|
+ self.plotter.plot(epochs, [parameter, np.ones_like(epochs) * ground_truth[i]], ['prediction', 'ground truth'], self.data.name + '_' + list(self.parameters_tilda.items())[i][0], list(self.parameters_tilda.items())[i][0], (6,6), is_background=[0, 1], xlabel='epochs')
|
|
|
+ else:
|
|
|
+ self.plotter.plot(epochs, [parameter], ['prediction'], self.data.name + '_' + list(self.parameters_tilda.items())[i][0], list(self.parameters_tilda.items())[i][0], (6,6), xlabel='epochs', plot_legend=False)
|
|
|
|
|
|
+
|