|
|
@@ -8,38 +8,43 @@ from .dataset import PandemicDataset
|
|
|
from .problem import PandemicProblem
|
|
|
from .plotter import Plotter
|
|
|
|
|
|
+
|
|
|
class Optimizer(Enum):
|
|
|
- ADAM=0
|
|
|
+ ADAM = 0
|
|
|
+
|
|
|
|
|
|
class Scheduler(Enum):
|
|
|
- CYCLIC=0
|
|
|
- CONSTANT=1
|
|
|
- LINEAR=2
|
|
|
- POLYNOMIAL=3
|
|
|
+ CYCLIC = 0
|
|
|
+ CONSTANT = 1
|
|
|
+ LINEAR = 2
|
|
|
+ POLYNOMIAL = 3
|
|
|
+
|
|
|
|
|
|
class Activation(Enum):
|
|
|
- LINEAR=0
|
|
|
- POWER=1
|
|
|
+ LINEAR = 0
|
|
|
+ POWER = 1
|
|
|
+
|
|
|
|
|
|
def linear(x):
|
|
|
return x
|
|
|
-
|
|
|
+
|
|
|
+
|
|
|
def power(x):
|
|
|
return torch.float_power(x, 2)
|
|
|
|
|
|
|
|
|
class DINN:
|
|
|
class NN(torch.nn.Module):
|
|
|
- def __init__(self,
|
|
|
+ def __init__(self,
|
|
|
output_size: int,
|
|
|
input_size: int,
|
|
|
hidden_size: int,
|
|
|
- hidden_layers: int,
|
|
|
+ hidden_layers: int,
|
|
|
activation_layer,
|
|
|
t_init,
|
|
|
t_final,
|
|
|
output_activation_function=Activation.LINEAR,
|
|
|
- use_glorot_initialization = False,
|
|
|
+ use_glorot_initialization=False,
|
|
|
use_t_scaled=True) -> None:
|
|
|
"""Neural Network
|
|
|
|
|
|
@@ -69,7 +74,7 @@ class DINN:
|
|
|
for i in range(hidden_layers):
|
|
|
torch.nn.init.xavier_uniform_(self.hidden[i][0].weight)
|
|
|
torch.nn.init.xavier_uniform_(self.output.weight)
|
|
|
-
|
|
|
+
|
|
|
self.__t_init = t_init
|
|
|
self.__t_final = t_final
|
|
|
self.__use_t_scaled = use_t_scaled
|
|
|
@@ -84,7 +89,7 @@ class DINN:
|
|
|
x = self.hidden(x)
|
|
|
x = self.output(x)
|
|
|
return self.out_activation(x)
|
|
|
-
|
|
|
+
|
|
|
def __init__(self,
|
|
|
output_size: int,
|
|
|
data: PandemicDataset,
|
|
|
@@ -93,12 +98,12 @@ class DINN:
|
|
|
plotter: Plotter,
|
|
|
state_variables=[],
|
|
|
parameter_regulator=torch.tanh,
|
|
|
- input_size=1,
|
|
|
- hidden_size=20,
|
|
|
- hidden_layers=7,
|
|
|
+ input_size=1,
|
|
|
+ hidden_size=20,
|
|
|
+ hidden_layers=7,
|
|
|
activation_layer=torch.nn.ReLU(),
|
|
|
activation_output=Activation.LINEAR,
|
|
|
- use_glorot_initialization = False) -> None:
|
|
|
+ use_glorot_initialization=False) -> None:
|
|
|
"""Desease Informed Neural Network. Uses the PandemicProblem, DINN.NN and PandemicDataset to solve Inverse Problems and find the
|
|
|
parameters of a specific mathematical model.
|
|
|
|
|
|
@@ -120,12 +125,12 @@ class DINN:
|
|
|
self.device_name = data.device_name
|
|
|
self.plotter = plotter
|
|
|
|
|
|
- self.model = DINN.NN(output_size,
|
|
|
- input_size,
|
|
|
- hidden_size,
|
|
|
- hidden_layers,
|
|
|
+ self.model = DINN.NN(output_size,
|
|
|
+ input_size,
|
|
|
+ hidden_size,
|
|
|
+ hidden_layers,
|
|
|
activation_layer,
|
|
|
- data.t_init,
|
|
|
+ data.t_init,
|
|
|
data.t_final,
|
|
|
activation_output,
|
|
|
use_glorot_initialization=use_glorot_initialization,
|
|
|
@@ -138,7 +143,7 @@ class DINN:
|
|
|
|
|
|
self.parameters_tilda = {}
|
|
|
for parameter in parameter_list:
|
|
|
- self.parameters_tilda.update({parameter : torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
|
|
|
+ self.parameters_tilda.update({parameter: torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
|
|
|
|
|
|
# new model has to be configured and then trained
|
|
|
self.__is_configured = False
|
|
|
@@ -164,7 +169,7 @@ class DINN:
|
|
|
torch.Parameter: Regulated parameter object of the search parameter.
|
|
|
"""
|
|
|
return self.parameter_regulator(self.parameters_tilda[parameter_name])
|
|
|
-
|
|
|
+
|
|
|
def get_parameters_tilda(self):
|
|
|
"""Function to get the original value (not forced into any range).
|
|
|
|
|
|
@@ -181,19 +186,18 @@ class DINN:
|
|
|
"""
|
|
|
return [self.parameter_regulator(parameter) for parameter in self.get_parameters_tilda()]
|
|
|
|
|
|
-
|
|
|
def get_output(self, index):
|
|
|
output = self.model(self.data.t_batch)
|
|
|
return output[:, index]
|
|
|
-
|
|
|
- def configure_training(self,
|
|
|
- lr:float,
|
|
|
- epochs:int,
|
|
|
- optimizer_class=Optimizer.ADAM,
|
|
|
- scheduler_class=Scheduler.CYCLIC,
|
|
|
- scheduler_factor = 1,
|
|
|
- lambda_obs = 1,
|
|
|
- lambda_physics = 1,
|
|
|
+
|
|
|
+ def configure_training(self,
|
|
|
+ lr: float,
|
|
|
+ epochs: int,
|
|
|
+ optimizer_class=Optimizer.ADAM,
|
|
|
+ scheduler_class=Scheduler.CYCLIC,
|
|
|
+ scheduler_factor=1,
|
|
|
+ lambda_obs=1,
|
|
|
+ lambda_physics=1,
|
|
|
verbose=False):
|
|
|
"""This method sets the optimizer, scheduler, learning rate and number of epochs for the following training process.
|
|
|
|
|
|
@@ -225,9 +229,9 @@ class DINN:
|
|
|
case Scheduler.CONSTANT:
|
|
|
self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1, total_iters=4)
|
|
|
case Scheduler.LINEAR:
|
|
|
- self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs/scheduler_factor)
|
|
|
+ self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs / scheduler_factor)
|
|
|
case Scheduler.POLYNOMIAL:
|
|
|
- self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs/scheduler_factor, power=1.0)
|
|
|
+ self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs / scheduler_factor, power=1.0)
|
|
|
case _:
|
|
|
self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
|
|
|
if verbose:
|
|
|
@@ -241,8 +245,8 @@ class DINN:
|
|
|
|
|
|
self.__is_configured = True
|
|
|
|
|
|
-
|
|
|
- def train(self,
|
|
|
+ def train(self,
|
|
|
+ plot_I_prediction=False,
|
|
|
create_animation=False,
|
|
|
animation_sample_rate=500,
|
|
|
verbose=False,
|
|
|
@@ -258,7 +262,7 @@ class DINN:
|
|
|
assert self.__is_configured, 'The model has to be configured before training through the use of self.configure training.'
|
|
|
if verbose:
|
|
|
print(f'torch seed: {torch.seed()}')
|
|
|
-
|
|
|
+
|
|
|
# arrays to hold values for plotting
|
|
|
self.losses = np.zeros(self.epochs)
|
|
|
self.obs_losses = np.zeros(self.epochs)
|
|
|
@@ -282,7 +286,7 @@ class DINN:
|
|
|
for i, group in enumerate(self.data.group_names):
|
|
|
loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
|
|
|
loss_obs *= self.lambda_obs
|
|
|
-
|
|
|
+
|
|
|
if do_split_training:
|
|
|
if epoch < start_split:
|
|
|
loss = loss_obs
|
|
|
@@ -305,14 +309,14 @@ class DINN:
|
|
|
# do snapshot for prediction animation
|
|
|
if epoch % animation_sample_rate == 0 and create_animation:
|
|
|
# prediction
|
|
|
- prediction = self.model(self.data.t_batch)
|
|
|
+ """prediction = self.model(self.data.t_batch)
|
|
|
t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
|
|
|
groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
|
|
|
|
|
|
plot_labels = [name + '_pred' for name in self.data.group_names] + [name + '_true' for name in self.data.group_names]
|
|
|
background_list = [0 for _ in self.data.group_names] + [1 for _ in self.data.group_names]
|
|
|
- self.plotter.plot(t,
|
|
|
- list(groups) + list(self.data.data),
|
|
|
+ self.plotter.plot(t,
|
|
|
+ list(groups) + list(self.data.data),
|
|
|
plot_labels,
|
|
|
'frame',
|
|
|
f'epoch {epoch}',
|
|
|
@@ -321,12 +325,30 @@ class DINN:
|
|
|
is_background=background_list,
|
|
|
lw=3,
|
|
|
legend_loc='upper right',
|
|
|
- ylim=(0, self.data.N),
|
|
|
+ ylim=(0, self.data.N),
|
|
|
+ xlabel='time / days',
|
|
|
+ ylabel='amount of people')"""
|
|
|
+ prediction = self.model(self.data.t_batch)
|
|
|
+ t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
|
|
|
+ groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
|
|
|
+
|
|
|
+ plot_labels = ['I_pred', 'I_true']
|
|
|
+ background_list = [0, 1]
|
|
|
+ self.plotter.plot(t,
|
|
|
+ [groups[1]] + [self.data.data[1]],
|
|
|
+ plot_labels,
|
|
|
+ 'Frame',
|
|
|
+ f'epoch {epoch}',
|
|
|
+ figure_shape=(12, 6),
|
|
|
+ is_frame=True,
|
|
|
+ is_background=background_list,
|
|
|
+ lw=3,
|
|
|
+ legend_loc='upper right',
|
|
|
xlabel='time / days',
|
|
|
ylabel='amount of people')
|
|
|
|
|
|
# print training advancements
|
|
|
- if epoch % 1000 == 0 and verbose:
|
|
|
+ if epoch % 1000 == 0 and verbose:
|
|
|
print(f'\nEpoch {epoch} | LR {self.scheduler.get_last_lr()[0]}')
|
|
|
print(f'physics loss:\t\t{loss_physics.item()}')
|
|
|
print(f'observation loss:\t{loss_obs.item()}')
|
|
|
@@ -342,6 +364,25 @@ class DINN:
|
|
|
self.plotter.animate(self.data.name + '_animation')
|
|
|
self.plotter.reset_animation()
|
|
|
|
|
|
+ if plot_I_prediction:
|
|
|
+ prediction = self.model(self.data.t_batch)
|
|
|
+ t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
|
|
|
+ groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
|
|
|
+
|
|
|
+ plot_labels = ['I_pred', 'I_true']
|
|
|
+ background_list = [0, 1]
|
|
|
+ self.plotter.plot(t,
|
|
|
+ [groups[1]] + [self.data.data[1]],
|
|
|
+ plot_labels,
|
|
|
+ 'Training_I_prediction',
|
|
|
+ f'Prediction of I on JH data',
|
|
|
+ figure_shape=(12, 6),
|
|
|
+ is_background=background_list,
|
|
|
+ lw=3,
|
|
|
+ legend_loc='upper right',
|
|
|
+ xlabel='time / days',
|
|
|
+ ylabel='amount of people')
|
|
|
+
|
|
|
self.__has_trained = True
|
|
|
|
|
|
def plot_training_graphs(self, ground_truth=[]):
|
|
|
@@ -355,32 +396,32 @@ class DINN:
|
|
|
|
|
|
# plot loss
|
|
|
self.plotter.plot(epochs, [self.losses, self.obs_losses, self.physics_losses], ['loss', 'observation loss', 'physics loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=True, xlabel='epochs')
|
|
|
-
|
|
|
+
|
|
|
# plot parameters
|
|
|
for i, parameter in enumerate(self.parameters):
|
|
|
if len(ground_truth) > i:
|
|
|
- self.plotter.plot(epochs,
|
|
|
- [parameter,
|
|
|
- np.ones_like(epochs) * ground_truth[i]],
|
|
|
- ['prediction', 'ground truth'],
|
|
|
- self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
|
|
|
- list(self.parameters_tilda.items())[i][0],
|
|
|
- (6,6),
|
|
|
- is_background=[0, 1],
|
|
|
- xlabel='epochs')
|
|
|
+ self.plotter.plot(epochs,
|
|
|
+ [parameter,
|
|
|
+ np.ones_like(epochs) * ground_truth[i]],
|
|
|
+ ['prediction', 'ground truth'],
|
|
|
+ self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
|
|
|
+ list(self.parameters_tilda.items())[i][0],
|
|
|
+ (6, 6),
|
|
|
+ is_background=[0, 1],
|
|
|
+ xlabel='epochs')
|
|
|
else:
|
|
|
- self.plotter.plot(epochs,
|
|
|
- [parameter],
|
|
|
- ['prediction'],
|
|
|
- self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
|
|
|
- list(self.parameters_tilda.items())[i][0], (6,6),
|
|
|
- xlabel='epochs',
|
|
|
+ self.plotter.plot(epochs,
|
|
|
+ [parameter],
|
|
|
+ ['prediction'],
|
|
|
+ self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
|
|
|
+ list(self.parameters_tilda.items())[i][0], (6, 6),
|
|
|
+ xlabel='epochs',
|
|
|
plot_legend=False)
|
|
|
-
|
|
|
- def save_training_process(self, title, save_predictions = True):
|
|
|
- losses = {'loss' : self.losses,
|
|
|
- 'obs_loss' : self.obs_losses,
|
|
|
- 'physics_loss' : self.physics_losses}
|
|
|
+
|
|
|
+ def save_training_process(self, title, save_predictions=True):
|
|
|
+ losses = {'loss': self.losses,
|
|
|
+ 'obs_loss': self.obs_losses,
|
|
|
+ 'physics_loss': self.physics_losses}
|
|
|
for loss in losses.keys():
|
|
|
with open(f'./results/training_metrics/{title}_{loss}.csv', 'w', newline='') as csvfile:
|
|
|
writer = csv.writer(csvfile, delimiter=',')
|
|
|
@@ -405,13 +446,13 @@ class DINN:
|
|
|
|
|
|
def plot_state_variables(self):
|
|
|
prediction = self.model(self.data.t_batch)
|
|
|
- for i in range(self.data.number_groups, self.data.number_groups+self.number_state_variables):
|
|
|
+ for i in range(self.data.number_groups, self.data.number_groups + self.number_state_variables):
|
|
|
t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
|
|
|
self.plotter.plot(t,
|
|
|
[prediction[:, i]],
|
|
|
- [self.__state_variables[i-self.data.number_groups]],
|
|
|
+ [self.__state_variables[i - self.data.number_groups]],
|
|
|
f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
|
|
|
- self.__state_variables[i-self.data.number_groups],
|
|
|
+ self.__state_variables[i - self.data.number_groups],
|
|
|
figure_shape=(12, 6),
|
|
|
plot_legend=True,
|
|
|
- xlabel='time / days')
|
|
|
+ xlabel='time / days')
|