import torch
import csv
import numpy as np
from enum import Enum

from .dataset import PandemicDataset
from .problem import PandemicProblem
from .plotter import Plotter


class Optimizer(Enum):
    ADAM = 0


class Scheduler(Enum):
    CYCLIC = 0
    CONSTANT = 1
    LINEAR = 2
    POLYNOMIAL = 3


class Activation(Enum):
    LINEAR = 0
    POWER = 1


def linear(x):
    return x


def power(x):
    return torch.float_power(x, 2)
class DINN:

    class NN(torch.nn.Module):
        def __init__(self,
                     output_size: int,
                     input_size: int,
                     hidden_size: int,
                     hidden_layers: int,
                     activation_layer,
                     t_init,
                     t_final,
                     output_activation_function=Activation.LINEAR,
                     use_glorot_initialization=False,
                     use_t_scaled=True) -> None:
            """Neural network approximating the course of the pandemic over time.

            Args:
                output_size (int): number of outputs
                input_size (int): number of inputs
                hidden_size (int): number of hidden nodes per layer
                hidden_layers (int): number of hidden layers
                activation_layer: activation layer applied after each linear layer
                t_init: first time point of the dataset
                t_final: last time point of the dataset
                output_activation_function (Activation, optional): activation applied to the output layer. Defaults to Activation.LINEAR.
                use_glorot_initialization (bool, optional): whether to use Glorot (Xavier) weight initialization. Defaults to False.
                use_t_scaled (bool, optional): whether the time input is normalized to [0, 1]. Defaults to True.
            """
            super(DINN.NN, self).__init__()

            if output_activation_function == Activation.LINEAR:
                self.out_activation = linear
            elif output_activation_function == Activation.POWER:
                self.out_activation = power
            else:
                print('Set output activation to default: linear')
                self.out_activation = linear

            self.input = torch.nn.Sequential(torch.nn.Linear(
                input_size, hidden_size), activation_layer)
            self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(
                hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
            self.output = torch.nn.Linear(hidden_size, output_size)

            if use_glorot_initialization:
                torch.nn.init.xavier_uniform_(self.input[0].weight)
                for i in range(hidden_layers):
                    torch.nn.init.xavier_uniform_(self.hidden[i][0].weight)
                torch.nn.init.xavier_uniform_(self.output.weight)

            self.__t_init = t_init
            self.__t_final = t_final
            self.__use_t_scaled = use_t_scaled

        def forward(self, t):
            # normalize the time input to [0, 1] if scaling is enabled
            if self.__use_t_scaled:
                t_forward = (t - self.__t_init) / \
                    (self.__t_final - self.__t_init)
            else:
                t_forward = t
            x = self.input(t_forward)
            x = self.hidden(x)
            x = self.output(x)
            return self.out_activation(x)
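    # Illustrative sketch (not part of the original module): how the nested
    # DINN.NN maps a column of time points to per-group predictions. The
    # sizes and time range below are assumptions chosen only for demonstration.
    #
    #   net = DINN.NN(output_size=3, input_size=1, hidden_size=20,
    #                 hidden_layers=7, activation_layer=torch.nn.ReLU(),
    #                 t_init=0.0, t_final=100.0)
    #   t = torch.linspace(0.0, 100.0, 101).reshape(-1, 1)
    #   prediction = net(t)  # shape (101, 3), one column per compartment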
    def __init__(self,
                 output_size: int,
                 data: PandemicDataset,
                 parameter_list: list,
                 problem: PandemicProblem,
                 plotter: Plotter,
                 state_variables=[],
                 parameter_regulator=torch.tanh,
                 input_size=1,
                 hidden_size=20,
                 hidden_layers=7,
                 activation_layer=torch.nn.ReLU(),
                 activation_output=Activation.LINEAR,
                 use_glorot_initialization=False) -> None:
        """Disease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve inverse problems and find the
        parameters of a specific mathematical model.

        Args:
            output_size (int): Number of output nodes of the NN.
            data (PandemicDataset): Data collected showing the course of the pandemic.
            parameter_list (list): List of the parameter names (strings) that are supposed to be found.
            problem (PandemicProblem): Problem class implementing the calculation of the residuals.
            plotter (Plotter): Plotter object to plot dataset curves.
            state_variables (list, optional): List of the names of state variables. Defaults to [].
            parameter_regulator (optional): Function to force the parameters into a certain range. Defaults to torch.tanh.
            input_size (int, optional): Number of input nodes of the NN. Defaults to 1.
            hidden_size (int, optional): Number of hidden nodes per layer of the NN. Defaults to 20.
            hidden_layers (int, optional): Number of hidden layers of the NN. Defaults to 7.
            activation_layer (optional): Instance of the activation function. Defaults to torch.nn.ReLU().
            activation_output (Activation, optional): Activation applied to the output layer. Defaults to Activation.LINEAR.
            use_glorot_initialization (bool, optional): Whether to use Glorot (Xavier) weight initialization. Defaults to False.
        """
        assert len(state_variables) + \
            data.number_groups == output_size, f'The number of groups plus the number of state variables must equal the output size\nGroups:\t{data.number_groups}\nState variables:\t{len(state_variables)}\noutput_size: {output_size}\n'

        self.device = torch.device(data.device_name)
        self.device_name = data.device_name
        self.plotter = plotter

        self.model = DINN.NN(output_size,
                             input_size,
                             hidden_size,
                             hidden_layers,
                             activation_layer,
                             data.t_init,
                             data.t_final,
                             activation_output,
                             use_glorot_initialization=use_glorot_initialization,
                             use_t_scaled=data.use_scaled_time)
        self.model = self.model.to(self.device)

        self.data = data
        self.parameter_regulator = parameter_regulator
        self.problem = problem
        self.problem.def_grad_matrix(output_size)

        # one trainable raw ("tilda") tensor per searched model parameter
        self.parameters_tilda = {}
        for parameter in parameter_list:
            self.parameters_tilda.update({parameter: torch.nn.Parameter(
                torch.rand(1, requires_grad=True, device=self.device_name))})

        # a new model has to be configured and then trained
        self.__is_configured = False
        self.__has_trained = False

        self.__state_variables = state_variables

        self.parameters = [np.zeros(1) for _ in range(len(parameter_list))]
        self.frames = []
    @property
    def number_state_variables(self):
        return len(self.__state_variables)

    def get_regulated_param(self, parameter_name: str):
        """Get a searched parameter, forced into a certain range by the parameter regulator.

        Args:
            parameter_name (str): Name of the parameter to be returned.

        Returns:
            torch.Tensor: Regulated value of the searched parameter.
        """
        return self.parameter_regulator(self.parameters_tilda[parameter_name])

    def get_parameters_tilda(self):
        """Get the raw parameter values (not forced into any range).

        Returns:
            list: List of the torch.nn.Parameter objects of the searched parameters.
        """
        return list(self.parameters_tilda.values())

    def get_regulated_param_list(self):
        """Get the list of regulated parameters (forced into a specific range).

        Returns:
            list: List of regulated parameters.
        """
        return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]

    def get_output(self, index):
        output = self.model(self.data.t_batch)
        return output[:, index]
    def configure_training(self,
                           lr: float,
                           epochs: int,
                           optimizer_class=Optimizer.ADAM,
                           scheduler_class=Scheduler.CYCLIC,
                           scheduler_factor=1,
                           lambda_obs=1,
                           lambda_physics=1,
                           verbose=False):
        """Set the optimizer, scheduler, learning rate, loss weights and number of epochs for the following training process.

        Args:
            lr (float): Learning rate for the optimizer.
            epochs (int): Number of epochs the NN is supposed to be trained for.
            optimizer_class (Optimizer, optional): Optimizer class that is supposed to be used. Defaults to Optimizer.ADAM.
            scheduler_class (Scheduler, optional): Scheduler class that is supposed to be used. Defaults to Scheduler.CYCLIC.
            scheduler_factor (optional): Divisor of the epoch count used as total_iters for the LINEAR and POLYNOMIAL schedulers. Defaults to 1.
            lambda_obs (optional): Weight of the observation (data) loss. Defaults to 1.
            lambda_physics (optional): Weight of the physics (residual) loss. Defaults to 1.
            verbose (bool, optional): Controls whether the configuration process is logged. Defaults to False.
        """
        parameter_list = list(self.model.parameters()) + \
            list(self.parameters_tilda.values())
        self.epochs = epochs
        self.lambda_obs = lambda_obs
        self.lambda_physics = lambda_physics

        match optimizer_class:
            case Optimizer.ADAM:
                self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
            case _:
                self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
                if verbose:
                    print('---------------------------------')
                    print(
                        f' Entered unknown optimizer name: {optimizer_class.name}\n Defaulted to ADAM.')
                    print('---------------------------------')
                optimizer_class = Optimizer.ADAM

        match scheduler_class:
            case Scheduler.CYCLIC:
                self.scheduler = torch.optim.lr_scheduler.CyclicLR(
                    self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
            case Scheduler.CONSTANT:
                self.scheduler = torch.optim.lr_scheduler.ConstantLR(
                    self.optimizer, factor=1, total_iters=4)
            case Scheduler.LINEAR:
                self.scheduler = torch.optim.lr_scheduler.LinearLR(
                    self.optimizer, start_factor=lr, total_iters=epochs / scheduler_factor)
            case Scheduler.POLYNOMIAL:
                self.scheduler = torch.optim.lr_scheduler.PolynomialLR(
                    self.optimizer, total_iters=epochs / scheduler_factor, power=1.0)
            case _:
                self.scheduler = torch.optim.lr_scheduler.CyclicLR(
                    self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
                if verbose:
                    print('---------------------------------')
                    print(
                        f' Entered unknown scheduler name: {scheduler_class.name}\n Defaulted to CYCLIC.')
                    print('---------------------------------')
                scheduler_class = Scheduler.CYCLIC

        if verbose:
            print(
                f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_class.name}\nScheduler:\t{scheduler_class.name}\n')

        self.__is_configured = True
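    # Illustrative sketch (not part of the original module): a typical call to
    # configure_training; the learning rate and epoch count are assumed values.
    #
    #   dinn.configure_training(lr=1e-6, epochs=50000,
    #                           scheduler_class=Scheduler.CYCLIC, verbose=True)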
    def train(self,
              create_animation=False,
              animation_sample_rate=500,
              verbose=False,
              do_split_training=False,
              start_split=10000):
        """Training routine for the DINN.

        Args:
            create_animation (bool, optional): Decides whether a prediction animation is created during training. Defaults to False.
            animation_sample_rate (int, optional): Sample rate of the prediction animation. Only used when create_animation=True. Defaults to 500.
            verbose (bool, optional): Controls whether the training process is logged. Defaults to False.
            do_split_training (bool, optional): If True, the first start_split epochs are trained on the observation loss only, before the physics loss is added. Defaults to False.
            start_split (int, optional): Epoch at which the physics loss is added when do_split_training=True. Defaults to 10000.
        """
        assert self.__is_configured, 'The model has to be configured before training through the use of self.configure_training.'
        if verbose:
            print(f'torch seed: {torch.seed()}')

        # arrays to hold values for plotting
        self.losses = np.zeros(self.epochs)
        self.obs_losses = np.zeros(self.epochs)
        self.physics_losses = np.zeros(self.epochs)
        self.parameters = [np.zeros(self.epochs) for _ in self.parameters]

        for epoch in range(self.epochs):
            # get the prediction and the fitting residuals
            prediction = self.model(self.data.t_batch)
            residuals = self.problem.residual(
                prediction, *self.get_regulated_param_list())
            self.optimizer.zero_grad()

            # calculate loss from the differential system
            loss_physics = 0
            for residual in residuals:
                loss_physics += torch.mean(torch.square(residual))
            loss_physics *= self.lambda_physics

            # calculate loss from the dataset
            loss_obs = 0
            for i, group in enumerate(self.data.group_names):
                loss_obs += torch.mean(torch.square(
                    self.data.get_norm(group) - prediction[:, i]))
            loss_obs *= self.lambda_obs

            if do_split_training:
                if epoch < start_split:
                    loss = loss_obs
                else:
                    loss = loss_obs + loss_physics
            else:
                loss = loss_obs + loss_physics

            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            # store values for plotting
            self.losses[epoch] = loss.item()
            self.obs_losses[epoch] = loss_obs.item()
            self.physics_losses[epoch] = loss_physics.item()
            for i, parameter in enumerate(self.parameters_tilda.items()):
                self.parameters[i][epoch] = self.get_regulated_param(
                    parameter[0]).item()

            # take a snapshot for the prediction animation
            if epoch % animation_sample_rate == 0 and create_animation:
                prediction = self.model(self.data.t_batch)
                t = torch.arange(
                    0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
                groups = self.data.get_denormalized_data(
                    [prediction[:, i] for i in range(self.data.number_groups)])
                plot_labels = ['I_pred', 'I_true']
                background_list = [0, 1]
                self.plotter.plot(t,
                                  [groups[1]] + [self.data.data[1]],
                                  plot_labels,
                                  'Frame',
                                  f'epoch {epoch}',
                                  figure_shape=(12, 6),
                                  is_frame=True,
                                  is_background=background_list,
                                  lw=3,
                                  legend_loc='upper right',
                                  xlabel='time / days',
                                  ylabel='amount of people')

            # print training progress
            if epoch % 1000 == 0 and verbose:
                print(
                    f'\nEpoch {epoch} | LR {self.scheduler.get_last_lr()[0]}')
                print(f'physics loss:\t\t{loss_physics.item()}')
                print(f'observation loss:\t{loss_obs.item()}')
                print(f'loss:\t\t\t{loss.item()}')
                print('---------------------------------')
                if len(self.parameters_tilda.items()) != 0:
                    for parameter in self.parameters_tilda.items():
                        print(
                            f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
                print('#################################')

        # create the prediction animation
        if create_animation:
            self.plotter.animate(self.data.name + '_animation')
            self.plotter.reset_animation()

        self.__has_trained = True
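    # Illustrative sketch (not part of the original module): split training first
    # fits the observation loss alone, then adds the physics residuals; the
    # start_split value below is an assumed example.
    #
    #   dinn.train(verbose=True, do_split_training=True, start_split=10000)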
    def plot_training_graphs(self, ground_truth=[]):
        """Plot the loss graph and the graphs of the evolution of the parameters over the epochs.

        Args:
            ground_truth (list, optional): List of the ground truth parameters. Defaults to [].
        """
        assert self.__has_trained, 'Model has to be trained before plotting the training graphs.'
        epochs = np.arange(0, self.epochs, 1)

        # plot losses
        self.plotter.plot(epochs, [self.losses, self.obs_losses, self.physics_losses], ['loss', 'observation loss',
                          'physics loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=True, xlabel='epochs')

        # plot parameters
        for i, parameter in enumerate(self.parameters):
            if len(ground_truth) > i:
                self.plotter.plot(epochs,
                                  [parameter,
                                   np.ones_like(epochs) * ground_truth[i]],
                                  ['prediction', 'ground truth'],
                                  self.data.name + '_' +
                                  list(self.parameters_tilda.items())[i][0],
                                  list(self.parameters_tilda.items())[i][0],
                                  (6, 6),
                                  is_background=[0, 1],
                                  xlabel='epochs')
            else:
                self.plotter.plot(epochs,
                                  [parameter],
                                  ['prediction'],
                                  self.data.name + '_' +
                                  list(self.parameters_tilda.items())[i][0],
                                  list(self.parameters_tilda.items())[i][0],
                                  (6, 6),
                                  xlabel='epochs',
                                  plot_legend=False)
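    # Illustrative sketch (not part of the original module): ground_truth is
    # aligned with parameter_list; the values below are assumed placeholders.
    #
    #   dinn.plot_training_graphs(ground_truth=[0.3, 0.1])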
    def save_training_process(self, title, save_predictions=True):
        """Save the loss curves, the parameter evolution and optionally the final predictions as CSV files.

        Args:
            title (str): Prefix used for the CSV file names.
            save_predictions (bool, optional): Whether the final predictions are saved as well. Defaults to True.
        """
        losses = {'loss': self.losses,
                  'obs_loss': self.obs_losses,
                  'physics_loss': self.physics_losses}
        for loss in losses.keys():
            with open(f'./results/training_metrics/{title}_{loss}.csv', 'w', newline='') as csvfile:
                writer = csv.writer(csvfile, delimiter=',')
                writer.writerow(losses[loss])

        for i, parameter in enumerate(self.parameters):
            with open(f'./results/training_metrics/{title}_{list(self.parameters_tilda.items())[i][0]}.csv', 'w', newline='') as csvfile:
                writer = csv.writer(csvfile, delimiter=',')
                writer.writerow(parameter)

        if save_predictions:
            prediction = self.model(self.data.t_batch)
            for i, group in enumerate(self.data.group_names):
                t = torch.linspace(
                    0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0]).detach().cpu().numpy()
                true = self.data.get_group(group).detach().cpu().numpy()
                pred = self.data.get_denormalized_data([prediction[:, i]])[
                    0].detach().cpu().numpy()
                print(t.shape, true.shape)
                with open(f'./results/I_predictions/{title}_I_prediction.csv', 'w', newline='') as csvfile:
                    writer = csv.writer(csvfile, delimiter=',')
                    writer.writerow(t)
                    writer.writerow(true)
                    writer.writerow(pred)
    def plot_state_variables(self):
        """Plot the predicted state variables (the outputs beyond the dataset groups) over time."""
        prediction = self.model(self.data.t_batch)
        for i in range(self.data.number_groups, self.data.number_groups + self.number_state_variables):
            t = torch.linspace(
                0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
            self.plotter.plot(t,
                              [prediction[:, i]],
                              [self.__state_variables[i - self.data.number_groups]],
                              f'{self.data.name}_{self.__state_variables[i - self.data.number_groups]}',
                              self.__state_variables[i - self.data.number_groups],
                              figure_shape=(12, 6),
                              plot_legend=True,
                              xlabel='time / days')
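# Illustrative end-to-end sketch (not part of the original module). The concrete
# constructors of PandemicDataset, the PandemicProblem subclass and Plotter are
# assumptions here; their real signatures live in the sibling modules of this package.
#
#   plotter = Plotter()
#   dataset = PandemicDataset(...)                 # observed course of the pandemic
#   problem = MySIRProblem(dataset)                # hypothetical PandemicProblem subclass
#   dinn = DINN(output_size=3,
#               data=dataset,
#               parameter_list=['alpha', 'beta'],  # assumed model parameters
#               problem=problem,
#               plotter=plotter)
#   dinn.configure_training(lr=1e-6, epochs=50000, verbose=True)
#   dinn.train(verbose=True)
#   dinn.plot_training_graphs()
#   dinn.save_training_process('sir_run')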