@@ -1,10 +1,33 @@
 import torch
+import csv
 import numpy as np

+from enum import Enum
+
 from .dataset import PandemicDataset
 from .problem import PandemicProblem
 from .plotter import Plotter

+class Optimizer(Enum):
+    ADAM=0
+
+class Scheduler(Enum):
+    CYCLIC=0
+    CONSTANT=1
+    LINEAR=2
+    POLYNOMIAL=3
+
+class Activation(Enum):
+    LINEAR=0
+    POWER=1
+
+def linear(x):
+    return x
+
+def power(x):
+    return torch.float_power(x, 2)
+
+
 class DINN:
     class NN(torch.nn.Module):
         def __init__(self,
@@ -12,9 +35,12 @@ class DINN:
                      input_size: int,
                      hidden_size: int,
                      hidden_layers: int,
-                     activation_layer,
+                     activation_layer,
                      t_init,
-                     t_final) -> None:
+                     t_final,
+                     output_activation_function=Activation.LINEAR,
+                     use_glorot_initialization = False,
+                     use_t_scaled=True) -> None:
             """Neural Network

             Args:
@@ -26,21 +52,39 @@ class DINN:
             """
             super(DINN.NN, self).__init__()

+            if output_activation_function == Activation.LINEAR:
+                self.out_activation = linear
+            elif output_activation_function == Activation.POWER:
+                self.out_activation = power
+            else:
+                print('Set output activation to default: linear')
+                self.out_activation = linear
+
             self.input = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size), activation_layer)
             self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
             self.output = torch.nn.Linear(hidden_size, output_size)
+
+            if use_glorot_initialization:
+                torch.nn.init.xavier_uniform_(self.input[0].weight)
+                for i in range(hidden_layers):
+                    torch.nn.init.xavier_uniform_(self.hidden[i][0].weight)
+                torch.nn.init.xavier_uniform_(self.output.weight)

             self.__t_init = t_init
             self.__t_final = t_final
+            self.__use_t_scaled = use_t_scaled

         def forward(self, t):
             # normalize input
-            t_scaled = (t - self.__t_init) / (self.__t_final - self.__t_init)
-            x = self.input(t_scaled)
+            if self.__use_t_scaled:
+                t_forward = (t - self.__t_init) / (self.__t_final - self.__t_init)
+            else:
+                t_forward = t
+            x = self.input(t_forward)
             x = self.hidden(x)
             x = self.output(x)
-            return x
-
+            return self.out_activation(x)
+
     def __init__(self,
                  output_size: int,
                  data: PandemicDataset,
@@ -52,7 +96,9 @@ class DINN:
                  input_size=1,
                  hidden_size=20,
                  hidden_layers=7,
-                 activation_layer=torch.nn.ReLU()) -> None:
+                 activation_layer=torch.nn.ReLU(),
+                 activation_output=Activation.LINEAR,
+                 use_glorot_initialization = False) -> None:
         """Disease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve inverse problems and find the
         parameters of a specific mathematical model.

@@ -78,9 +124,12 @@ class DINN:
                              input_size,
                              hidden_size,
                              hidden_layers,
-                             activation_layer,
+                             activation_layer,
                              data.t_init,
-                             data.t_final)
+                             data.t_final,
+                             activation_output,
+                             use_glorot_initialization=use_glorot_initialization,
+                             use_t_scaled=data.use_scaled_time)
         self.model = self.model.to(self.device)
         self.data = data
         self.parameter_regulator = parameter_regulator
@@ -131,8 +180,21 @@ class DINN:
             list: list of regulated parameters
         """
         return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]
+

-    def configure_training(self, lr:float, epochs:int, optimizer_name='Adam', scheduler_name='CyclicLR', scheduler_factor = 1, verbose=False):
+    def get_output(self, index):
+        output = self.model(self.data.t_batch)
+        return output[:, index]
+
+    def configure_training(self,
+                           lr:float,
+                           epochs:int,
+                           optimizer_class=Optimizer.ADAM,
+                           scheduler_class=Scheduler.CYCLIC,
+                           scheduler_factor = 1,
+                           lambda_obs = 1,
+                           lambda_physics = 1,
+                           verbose=False):
         """This method sets the optimizer, scheduler, learning rate and number of epochs for the following training process.

         Args:
@@ -144,36 +206,38 @@ class DINN:
         """
         parameter_list = list(self.model.parameters()) + list(self.parameters_tilda.values())
         self.epochs = epochs
-        match optimizer_name:
-            case 'Adam':
+        self.lambda_obs = lambda_obs
+        self.lambda_physics = lambda_physics
+        match optimizer_class:
+            case Optimizer.ADAM:
                 self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
             case _:
                 self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
                 if verbose:
                     print('---------------------------------')
-                    print(f' Entered unknown optimizer name: {optimizer_name}\n Defaulted to Adam.')
+                    print(f' Entered unknown optimizer name: {optimizer_class.name}\n Defaulted to ADAM.')
                     print('---------------------------------')
-                optimizer_name = 'Adam'
+                optimizer_class = Optimizer.ADAM

-        match scheduler_name:
-            case 'CyclicLR':
+        match scheduler_class:
+            case Scheduler.CYCLIC:
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
-            case 'ConstantLR':
+            case Scheduler.CONSTANT:
                 self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1, total_iters=4)
-            case 'LinearLR':
+            case Scheduler.LINEAR:
                 self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs/scheduler_factor)
-            case 'PolynomialLR':
+            case Scheduler.POLYNOMIAL:
                 self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs/scheduler_factor, power=1.0)
             case _:
                 self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
                 if verbose:
                     print('---------------------------------')
-                    print(f' Entered unknown scheduler name: {scheduler_name}\n Defaulted to CyclicLR.')
+                    print(f' Entered unknown scheduler name: {scheduler_class.name}\n Defaulted to CYCLIC.')
                     print('---------------------------------')
-                scheduler_name = 'CyclicLR'
+                scheduler_class = Scheduler.CYCLIC

         if verbose:
-            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_name}\nScheduler:\t{scheduler_name}\n')
+            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_class.name}\nScheduler:\t{scheduler_class.name}\n')

         self.__is_configured = True

@@ -181,7 +245,9 @@ class DINN:
|
|
|
def train(self,
|
|
|
create_animation=False,
|
|
|
animation_sample_rate=500,
|
|
|
- verbose=False):
|
|
|
+ verbose=False,
|
|
|
+ do_split_training=False,
|
|
|
+ start_split=10000):
|
|
|
"""Training routine for the DINN.
|
|
|
|
|
|
Args:
|
|
@@ -203,20 +269,27 @@ class DINN:
             # get the prediction and the fitting residuals
             prediction = self.model(self.data.t_batch)
             residuals = self.problem.residual(prediction, *self.get_regulated_param_list())
-
             self.optimizer.zero_grad()

             # calculate loss from the differential system
             loss_physics = 0
             for residual in residuals:
                 loss_physics += torch.mean(torch.square(residual))
+            loss_physics *= self.lambda_physics

             # calculate loss from the dataset
             loss_obs = 0
             for i, group in enumerate(self.data.group_names):
                 loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
+            loss_obs *= self.lambda_obs

-            loss = loss_obs + loss_physics
+            if do_split_training:
+                if epoch < start_split:
+                    loss = loss_obs
+                else:
+                    loss = loss_obs + loss_physics
+            else:
+                loss = loss_obs + loss_physics

             loss.backward()
             self.optimizer.step()
@@ -291,7 +364,8 @@ class DINN:
                                   np.ones_like(epochs) * ground_truth[i]],
                                  ['prediction', 'ground truth'],
                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
-                                 list(self.parameters_tilda.items())[i][0], (6,6),
+                                 list(self.parameters_tilda.items())[i][0],
+                                 (6,6),
                                  is_background=[0, 1],
                                  xlabel='epochs')
             else:
@@ -302,19 +376,42 @@ class DINN:
                                  list(self.parameters_tilda.items())[i][0], (6,6),
                                  xlabel='epochs',
                                  plot_legend=False)
+
+    def save_training_process(self, title, save_predictions = True):
+        losses = {'loss' : self.losses,
+                  'obs_loss' : self.obs_losses,
+                  'physics_loss' : self.physics_losses}
+        for loss in losses.keys():
+            with open(f'./results/training_metrics/{title}_{loss}.csv', 'w', newline='') as csvfile:
+                writer = csv.writer(csvfile, delimiter=',')
+                writer.writerow(losses[loss])
+
+        for i, parameter in enumerate(self.parameters):
+            with open(f'./results/training_metrics/{title}_{list(self.parameters_tilda.items())[i][0]}.csv', 'w', newline='') as csvfile:
+                writer = csv.writer(csvfile, delimiter=',')
+                writer.writerow(parameter)
+        if save_predictions:
+            prediction = self.model(self.data.t_batch)
+            for i, group in enumerate(self.data.group_names):
+                t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0]).detach().cpu().numpy()
+                true = self.data.get_group(group).detach().cpu().numpy()
+                pred = self.data.get_denormalized_data([prediction[:, i]])[0].detach().cpu().numpy()
+                print(t.shape, true.shape)
+                with open(f'./results/I_predictions/{title}_I_prediction.csv', 'w', newline='') as csvfile:
+                    writer = csv.writer(csvfile, delimiter=',')
+                    writer.writerow(t)
+                    writer.writerow(true)
+                    writer.writerow(pred)

     def plot_state_variables(self):
         prediction = self.model(self.data.t_batch)
-        groups = [prediction[:, i] for i in range(self.data.number_groups)]
-        fore_background = [0] + [1 for _ in groups]
         for i in range(self.data.number_groups, self.data.number_groups+self.number_state_variables):
-            t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+            t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
             self.plotter.plot(t,
-                              [prediction[:, i]] + groups,
-                              [self.__state_variables[i-self.data.number_groups]] + self.data.group_names,
+                              [prediction[:, i]],
+                              [self.__state_variables[i-self.data.number_groups]],
                              f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
                              self.__state_variables[i-self.data.number_groups],
-                              is_background=fore_background,
                              figure_shape=(12, 6),
                              plot_legend=True,
                              xlabel='time / days')
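
Note (not part of the diff): a minimal usage sketch of the configuration surface added above. It assumes an already constructed DINN instance named `dinn` and that `Optimizer` and `Scheduler` are importable from this module; all numeric values are illustrative only.

    # hypothetical usage of the new configure_training / train / save_training_process API
    dinn.configure_training(lr=1e-6,
                            epochs=25000,
                            optimizer_class=Optimizer.ADAM,        # unknown values fall back to ADAM
                            scheduler_class=Scheduler.POLYNOMIAL,  # CYCLIC, CONSTANT, LINEAR, POLYNOMIAL
                            scheduler_factor=1,
                            lambda_obs=1,          # weight of the observation (data) loss
                            lambda_physics=1e-4,   # weight of the ODE-residual loss
                            verbose=True)

    # optionally fit the observations alone for the first 10000 epochs,
    # then add the physics loss (split training)
    dinn.train(do_split_training=True, start_split=10000, verbose=True)

    # write losses, parameter trajectories and predictions to ./results/
    dinn.save_training_process('example_run')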