@@ -1,16 +1,10 @@
 import torch
-import os
-import imageio
 import numpy as np
-import matplotlib.pyplot as plt
-from matplotlib import rcParams

 from .dataset import PandemicDataset
 from .problem import PandemicProblem
 from .plotter import Plotter

-
-
 class DINN:

     class NN(torch.nn.Module):
         def __init__(self,
@@ -48,55 +42,68 @@ class DINN:
             return x

     def __init__(self,
-                 number_groups: int,
+                 output_size: int,
                  data: PandemicDataset,
                  parameter_list: list,
                  problem: PandemicProblem,
                  plotter: Plotter,
+                 state_variables=[],
                  parameter_regulator=torch.tanh,
-                 input_size=1,
-                 hidden_size=20,
+                 input_size=1,
+                 hidden_size=20,
                  hidden_layers=7,
                  activation_layer=torch.nn.ReLU()) -> None:
"""Desease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve Inverse Problems and find the
|
|
|
parameters of a specific mathematical model.
|
|
|
|
|
|
Args:
|
|
|
- number_groups (int): The number of groups, that the population is split into.
|
|
|
+ output_size (int): Number of the output nodes of the NN.
|
|
|
data (PandemicDataset): Data collected showing the course of the pandemic
|
|
|
parameter_list (list): List of the parameter names(strings), that are supposed to be found.
|
|
|
problem (PandemicProblem): Problem class implementing the calculation of the residuals.
|
|
|
plotter (Plotter): Plotter object to plot dataset curves.
|
|
|
+ state_variables (list, optional): List of the names of state variables. Defaults to [].
|
|
|
parameter_regulator (optional): Function to force the parameters to be in a certain range. Defaults to torch.tanh.
|
|
|
input_size (int, optional): Number of the input nodes of the NN. Defaults to 1.
|
|
|
hidden_size (int, optional): Number of the hidden nodes of the NN. Defaults to 20.
|
|
|
hidden_layers (int, optional): Number of the hidden layers for the NN. Defaults to 7.
|
|
|
activation_layer (optional): Class of the activation function. Defaults to torch.nn.ReLU().
|
|
|
"""
|
|
|
-
|
|
|
+ assert len(state_variables) + data.number_groups == output_size, f'The number of groups plus the number of state variable must result in the output size\nGroups:\t{data.number_groups}\nState variables:\t{len(state_variables)}\noutput_size: {output_size}\n'
|
|
|
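+        # the first data.number_groups outputs of the NN correspond to the observed
+        # groups; any remaining outputs are treated as additional state variables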
         self.device = torch.device(data.device_name)
         self.device_name = data.device_name
         self.plotter = plotter

-        self.model = DINN.NN(number_groups, input_size, hidden_size, hidden_layers, activation_layer, data.t_init, data.t_final)
+        self.model = DINN.NN(output_size,
+                             input_size,
+                             hidden_size,
+                             hidden_layers,
+                             activation_layer,
+                             data.t_init,
+                             data.t_final)
         self.model = self.model.to(self.device)
         self.data = data
         self.parameter_regulator = parameter_regulator
         self.problem = problem
+        self.problem.def_grad_matrix(output_size)

         self.parameters_tilda = {}
         for parameter in parameter_list:
             self.parameters_tilda.update({parameter: torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
-
-        self.epochs = None

-        self.losses = np.zeros(1)
-        self.obs_losses = np.zeros(1)
-        self.physics_losses = np.zeros(1)
+        # new model has to be configured and then trained
+        self.__is_configured = False
+        self.__has_trained = False
+
+        self.__state_variables = state_variables
+
         self.parameters = [np.zeros(1) for _ in range(len(parameter_list))]

         self.frames = []

+    @property
+    def number_state_variables(self):
+        return len(self.__state_variables)
+
     def get_regulated_param(self, parameter_name: str):
         """Function to get the searched parameters, forced into a certain range.
@@ -125,40 +132,77 @@ class DINN:
         """
         return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]

+    def configure_training(self, lr: float, epochs: int, optimizer_name='Adam', scheduler_name='CyclicLR', scheduler_factor=1, verbose=False):
+        """Sets the optimizer, scheduler, learning rate and number of epochs for the following training process.
+
+        Args:
+            lr (float): Learning rate for the optimizer.
+            epochs (int): Number of epochs the NN is supposed to be trained for.
+            optimizer_name (str, optional): Name of the optimizer class that is supposed to be used. Defaults to 'Adam'.
+            scheduler_name (str, optional): Name of the scheduler class that is supposed to be used. Defaults to 'CyclicLR'.
+            scheduler_factor (optional): Factor by which epochs is divided to obtain the scheduler's total_iters. Only used by LinearLR and PolynomialLR. Defaults to 1.
+            verbose (bool, optional): Controls whether the configuration process is logged. Defaults to False.
+        """
+        parameter_list = list(self.model.parameters()) + list(self.parameters_tilda.values())
+        self.epochs = epochs
+        match optimizer_name:
+            case 'Adam':
+                self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
+            case _:
+                self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
+                if verbose:
+                    print('---------------------------------')
+                    print(f' Entered unknown optimizer name: {optimizer_name}\n Defaulted to Adam.')
+                    print('---------------------------------')
+                optimizer_name = 'Adam'
+
+        match scheduler_name:
+            case 'CyclicLR':
+                self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
+            case 'LinearLR':
+                self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs/scheduler_factor)
+            case 'PolynomialLR':
+                self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs/scheduler_factor, power=1.0)
+            case _:
+                self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
+                if verbose:
+                    print('---------------------------------')
+                    print(f' Entered unknown scheduler name: {scheduler_name}\n Defaulted to CyclicLR.')
+                    print('---------------------------------')
+                scheduler_name = 'CyclicLR'
+
+        if verbose:
+            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_name}\nScheduler:\t{scheduler_name}\n')
+
+        self.__is_configured = True
+
     def train(self,
-              epochs: int,
-              lr: float,
-              optimizer_class=torch.optim.Adam,
               create_animation=False,
-              animation_sample_rate=500):
-        """Training routine for the DINN
+              animation_sample_rate=500,
+              verbose=False):
+        """Training routine for the DINN.

         Args:
-            epochs (int): Number of epochs the NN is supposed to be trained for.
-            lr (float): Learning rate for the optimizer.
-            optimizer_class (optional): Class of the optimizer. Defaults to torch.optim.Adam.
             create_animation (bool, optional): Decides whether a prediction animation is supposed to be created during training. Defaults to False.
             animation_sample_rate (int, optional): Sample rate of the prediction animation. Only used when create_animation=True. Defaults to 500.
+            verbose (bool, optional): Controls whether the training process is logged. Defaults to False.
         """
-
-        # define optimizer and scheduler
-        optimizer = optimizer_class(list(self.model.parameters()) + list(self.parameters_tilda.values()), lr=lr)
-        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
-
-        self.epochs = epochs
-
+        assert self.__is_configured, 'The model has to be configured before training through the use of self.configure_training.'
+        if verbose:
+            print(f'torch seed: {torch.seed()}')
+
         # arrays to hold values for plotting
-        self.losses = np.zeros(epochs)
-        self.obs_losses = np.zeros(epochs)
-        self.physics_losses = np.zeros(epochs)
-        self.parameters = [np.zeros(epochs) for _ in self.parameters]
+        self.losses = np.zeros(self.epochs)
+        self.obs_losses = np.zeros(self.epochs)
+        self.physics_losses = np.zeros(self.epochs)
+        self.parameters = [np.zeros(self.epochs) for _ in self.parameters]

-        for epoch in range(epochs):
+        for epoch in range(self.epochs):
             # get the prediction and the fitting residuals
             prediction = self.model(self.data.t_batch)
             residuals = self.problem.residual(prediction, *self.get_regulated_param_list())

-            optimizer.zero_grad()
+            self.optimizer.zero_grad()

             # calculate loss from the differential system
             loss_physics = 0
@@ -170,11 +214,11 @@ class DINN:
             for i, group in enumerate(self.data.group_names):
                 loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))

-            loss = loss_physics + loss_obs
+            loss = loss_obs + loss_physics

             loss.backward()
-            optimizer.step()
-            scheduler.step()
+            self.optimizer.step()
+            self.scheduler.step()

             # append values for plotting
             self.losses[epoch] = loss.item()
@@ -188,15 +232,18 @@ class DINN:
                 # prediction
                 prediction = self.model(self.data.t_batch)
                 t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
-                groups = self.data.get_denormalized_data([prediction[:, 0], prediction[:, 1], prediction[:, 2]])
+                groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
+
+                plot_labels = [name + '_pred' for name in self.data.group_names] + [name + '_true' for name in self.data.group_names]
+                background_list = [0 for _ in self.data.group_names] + [1 for _ in self.data.group_names]
                 self.plotter.plot(t,
                                   list(groups) + list(self.data.data),
-                                  [name + '_pred' for name in self.data.group_names] + [name + '_true' for name in self.data.group_names],
+                                  plot_labels,
                                   'frame',
                                   f'epoch {epoch}',
                                   figure_shape=(12, 6),
                                   is_frame=True,
-                                  is_background=[0, 0, 0, 1, 1, 1],
+                                  is_background=background_list,
                                   lw=3,
                                   legend_loc='upper right',
                                   ylim=(0, self.data.N),
@@ -204,19 +251,23 @@ class DINN:
                                   ylabel='amount of people')

             # print training advancements
-            if epoch % 1000 == 0:
-                print('\nEpoch ', epoch)
+            if epoch % 1000 == 0 and verbose:
+                print(f'\nEpoch {epoch} | LR {self.scheduler.get_last_lr()[0]}')
                 print(f'physics loss:\t\t{loss_physics.item()}')
                 print(f'observation loss:\t{loss_obs.item()}')
                 print(f'loss:\t\t\t{loss.item()}')
                 print('---------------------------------')
-                for parameter in self.parameters_tilda.items():
-                    print(f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
-                print('#################################')
+                if len(self.parameters_tilda.items()) != 0:
+                    for parameter in self.parameters_tilda.items():
+                        print(f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
+                    print('#################################')

         # create prediction animation
         if create_animation:
             self.plotter.animate(self.data.name + '_animation')
+            self.plotter.reset_animation()
+
+        self.__has_trained = True

     def plot_training_graphs(self, ground_truth=[]):
         """Plot the loss graph and the graphs of the advancements of the parameters.
@@ -224,7 +275,7 @@ class DINN:

         Args:
             ground_truth (list): List of the ground truth parameters.
         """
-        assert self.epochs != None
+        assert self.__has_trained, 'Model has to be trained before plotting the training graphs.'
         epochs = np.arange(0, self.epochs, 1)

         # plot loss
@@ -233,8 +284,34 @@ class DINN:
         # plot parameters
         for i, parameter in enumerate(self.parameters):
             if len(ground_truth) > i:
-                self.plotter.plot(epochs, [parameter, np.ones_like(epochs) * ground_truth[i]], ['prediction', 'ground truth'], self.data.name + '_' + list(self.parameters_tilda.items())[i][0], list(self.parameters_tilda.items())[i][0], (6,6), is_background=[0, 1], xlabel='epochs')
+                self.plotter.plot(epochs,
+                                  [parameter,
+                                   np.ones_like(epochs) * ground_truth[i]],
+                                  ['prediction', 'ground truth'],
+                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
+                                  list(self.parameters_tilda.items())[i][0], (6, 6),
+                                  is_background=[0, 1],
+                                  xlabel='epochs')
             else:
-                self.plotter.plot(epochs, [parameter], ['prediction'], self.data.name + '_' + list(self.parameters_tilda.items())[i][0], list(self.parameters_tilda.items())[i][0], (6,6), xlabel='epochs', plot_legend=False)
-
-
+                self.plotter.plot(epochs,
+                                  [parameter],
+                                  ['prediction'],
+                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
+                                  list(self.parameters_tilda.items())[i][0], (6, 6),
+                                  xlabel='epochs',
+                                  plot_legend=False)
+
+    def plot_state_variables(self):
+        """Plot each predicted state variable together with the predicted group curves."""
+        for i in range(self.data.number_groups, self.data.number_groups + self.number_state_variables):
+            prediction = self.model(self.data.t_batch)
+            groups = [prediction[:, j] for j in range(self.data.number_groups)]
+            t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+            self.plotter.plot(t,
+                              [prediction[:, i]] + groups,
+                              [self.__state_variables[i - self.data.number_groups]] + self.data.group_names,
+                              f'{self.data.name}_{self.__state_variables[i - self.data.number_groups]}',
+                              self.__state_variables[i - self.data.number_groups],
+                              is_background=[0] + [1] * self.data.number_groups,
+                              figure_shape=(12, 6),
+                              plot_legend=True,
+                              xlabel='time / days')
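
Example of the new configure-then-train workflow (a minimal sketch, assuming an already constructed PandemicDataset `dataset`, PandemicProblem `problem` and Plotter `plotter`; the group count and the parameter names 'alpha' and 'beta' are placeholders, not part of this diff):

    # three observed groups (e.g. S, I, R) and no extra state variables,
    # so output_size == dataset.number_groups == 3
    dinn = DINN(output_size=3,
                data=dataset,
                parameter_list=['alpha', 'beta'],
                problem=problem,
                plotter=plotter)

    # training now has to be configured explicitly before train() is called;
    # with lr=1e-6 the CyclicLR case reproduces the previously hard-coded range
    # (base_lr = lr * 10 = 1e-5, max_lr = lr * 1e3 = 1e-3)
    dinn.configure_training(lr=1e-6,
                            epochs=10000,
                            optimizer_name='Adam',
                            scheduler_name='CyclicLR',
                            verbose=True)
    dinn.train(verbose=True)

    dinn.plot_training_graphs(ground_truth=[0.3, 0.1])  # hypothetical true values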