@@ -65,8 +65,10 @@ class DINN:
                 print('Set output activation to default: linear')
                 self.out_activation = self.linear
 
-            self.input = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size), activation_layer)
-            self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
+            self.input = torch.nn.Sequential(torch.nn.Linear(
+                input_size, hidden_size), activation_layer)
+            self.hidden = torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(
+                hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
             self.output = torch.nn.Linear(hidden_size, output_size)
 
             if use_glorot_initialization:
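
The two wrapped statements above only split the construction of a plain fully connected network across several lines; the architecture itself is unchanged. As a reference, a minimal stand-alone sketch of the same layer layout (the sizes and the activation are assumed example values, not taken from the diff):

import torch

# assumed example sizes; the real values come from the constructor arguments
input_size, hidden_size, output_size, hidden_layers = 1, 20, 4, 7
activation_layer = torch.nn.ReLU()

input_block = torch.nn.Sequential(
    torch.nn.Linear(input_size, hidden_size), activation_layer)
hidden_blocks = torch.nn.Sequential(*[torch.nn.Sequential(
    torch.nn.Linear(hidden_size, hidden_size), activation_layer) for _ in range(hidden_layers)])
output_block = torch.nn.Linear(hidden_size, output_size)

model = torch.nn.Sequential(input_block, hidden_blocks, output_block)
print(model(torch.rand(5, input_size)).shape)  # torch.Size([5, 4])
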
@@ -82,7 +84,8 @@ class DINN:
         def forward(self, t):
             # normalize input
             if self.__use_t_scaled:
-                t_forward = (t - self.__t_init) / (self.__t_final - self.__t_init)
+                t_forward = (t - self.__t_init) / \
+                    (self.__t_final - self.__t_init)
             else:
                 t_forward = t
             x = self.input(t_forward)
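
The backslash continuation introduced here leaves the arithmetic untouched: the time input is still min-max scaled onto [0, 1] before it enters the network. A minimal stand-alone illustration of that scaling (the time grid is an assumed example):

import torch

t = torch.linspace(0.0, 150.0, 151)            # assumed raw time grid in days
t_init, t_final = t[0], t[-1]
t_scaled = (t - t_init) / (t_final - t_init)   # maps [t_init, t_final] onto [0, 1]
print(t_scaled.min().item(), t_scaled.max().item())  # 0.0 1.0
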
@@ -120,7 +123,8 @@ class DINN:
             hidden_layers (int, optional): Number of the hidden layers for the NN. Defaults to 7.
             activation_layer (optional): Class of the activation function. Defaults to torch.nn.ReLU().
         """
-        assert len(state_variables) + data.number_groups == output_size, f'The number of groups plus the number of state variable must result in the output size\nGroups:\t{data.number_groups}\nState variables:\t{len(state_variables)}\noutput_size: {output_size}\n'
+        assert len(state_variables) + \
+            data.number_groups == output_size, f'The number of groups plus the number of state variable must result in the output size\nGroups:\t{data.number_groups}\nState variables:\t{len(state_variables)}\noutput_size: {output_size}\n'
         self.device = torch.device(data.device_name)
         self.device_name = data.device_name
         self.plotter = plotter
@@ -143,7 +147,8 @@ class DINN:
 
         self.parameters_tilda = {}
         for parameter in parameter_list:
-            self.parameters_tilda.update({parameter: torch.nn.Parameter(torch.rand(1, requires_grad=True, device=self.device_name))})
+            self.parameters_tilda.update({parameter: torch.nn.Parameter(
+                torch.rand(1, requires_grad=True, device=self.device_name))})
 
         # new model has to be configured and then trained
         self.__is_configured = False
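
Each unknown model parameter is stored as an unconstrained torch.nn.Parameter (the 'tilda' value) and only mapped into a usable range by get_regulated_param / parameter_regulator, which are not shown in this diff. A minimal sketch of that pattern; the sigmoid regulator below is an assumption for illustration, not the project's actual choice:

import torch

beta_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))  # free, unconstrained value
beta = torch.sigmoid(beta_tilda)       # regulated value, here squeezed into (0, 1)
loss = ((beta - 0.3) ** 2).sum()       # any loss through the regulated value ...
loss.backward()                        # ... still yields gradients for beta_tilda
print(beta_tilda.grad)
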
@@ -208,7 +213,8 @@ class DINN:
             scheduler_name (str, optional): Name of the scheduler class that is supposed to be used. Defaults to 'CyclicLR'.
             verbose (bool, optional): Controles if the configuration process, is to be verbosed. Defaults to False.
         """
-        parameter_list = list(self.model.parameters()) + list(self.parameters_tilda.values())
+        parameter_list = list(self.model.parameters()) + \
+            list(self.parameters_tilda.values())
         self.epochs = epochs
         self.lambda_obs = lambda_obs
         self.lambda_physics = lambda_physics
@@ -219,34 +225,41 @@ class DINN:
                 self.optimizer = torch.optim.Adam(parameter_list, lr=lr)
                 if verbose:
                     print('---------------------------------')
-                    print(f' Entered unknown optimizer name: {optimizer_class.name}\n Defaulted to ADAM.')
+                    print(
+                        f' Entered unknown optimizer name: {optimizer_class.name}\n Defaulted to ADAM.')
                     print('---------------------------------')
                 optimizer_class = Optimizer.ADAM
 
         match scheduler_class:
             case Scheduler.CYCLIC:
-                self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
+                self.scheduler = torch.optim.lr_scheduler.CyclicLR(
+                    self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
             case Scheduler.CONSTANT:
-                self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1, total_iters=4)
+                self.scheduler = torch.optim.lr_scheduler.ConstantLR(
+                    self.optimizer, factor=1, total_iters=4)
             case Scheduler.LINEAR:
-                self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs / scheduler_factor)
+                self.scheduler = torch.optim.lr_scheduler.LinearLR(
+                    self.optimizer, start_factor=lr, total_iters=epochs / scheduler_factor)
             case Scheduler.POLYNOMIAL:
-                self.scheduler = torch.optim.lr_scheduler.PolynomialLR(self.optimizer, total_iters=epochs / scheduler_factor, power=1.0)
+                self.scheduler = torch.optim.lr_scheduler.PolynomialLR(
+                    self.optimizer, total_iters=epochs / scheduler_factor, power=1.0)
             case _:
-                self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
+                self.scheduler = torch.optim.lr_scheduler.CyclicLR(
+                    self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
                 if verbose:
                     print('---------------------------------')
-                    print(f' Entered unknown scheduler name: {scheduler_class.name}\n Defaulted to CYCLIC.')
+                    print(
+                        f' Entered unknown scheduler name: {scheduler_class.name}\n Defaulted to CYCLIC.')
                     print('---------------------------------')
                 scheduler_class = Scheduler.CYCLIC
 
         if verbose:
-            print(f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_class.name}\nScheduler:\t{scheduler_class.name}\n')
+            print(
+                f'\nLearning Rate:\t{lr}\nOptimizer:\t{optimizer_class.name}\nScheduler:\t{scheduler_class.name}\n')
 
         self.__is_configured = True
 
     def train(self,
-              plot_I_prediction=False,
               create_animation=False,
               animation_sample_rate=500,
               verbose=False,
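
The scheduler constructions above are only re-wrapped; their hyperparameters are unchanged. For reference, a minimal stand-alone loop wiring Adam to the default CyclicLR configuration from this hunk (the base learning rate is an assumed example value):

import torch

params = [torch.nn.Parameter(torch.rand(1))]
lr = 1e-6                                               # assumed example value
optimizer = torch.optim.Adam(params, lr=lr)
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000,
    mode="exp_range", gamma=0.85, cycle_momentum=False)  # cycle_momentum must be False for Adam

for step in range(5):
    optimizer.zero_grad()
    loss = params[0].square().sum()
    loss.backward()
    optimizer.step()
    scheduler.step()                      # advance the learning-rate cycle once per step
print(scheduler.get_last_lr())
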
@@ -272,7 +285,8 @@ class DINN:
         for epoch in range(self.epochs):
             # get the prediction and the fitting residuals
             prediction = self.model(self.data.t_batch)
-            residuals = self.problem.residual(prediction, *self.get_regulated_param_list())
+            residuals = self.problem.residual(
+                prediction, *self.get_regulated_param_list())
             self.optimizer.zero_grad()
 
             # calculate loss from the differential system
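
self.problem.residual is defined outside this diff; in a physics-informed setup it typically builds the ODE residuals from autograd time derivatives of the network output. A generic sketch of that pattern, using the arbitrary toy equation dS/dt = -0.5 * S rather than the project's epidemic model:

import torch

t = torch.linspace(0.0, 1.0, 50).reshape(-1, 1).requires_grad_(True)
net = torch.nn.Sequential(torch.nn.Linear(1, 20), torch.nn.Tanh(), torch.nn.Linear(20, 1))

S = net(t)
dS_dt = torch.autograd.grad(S, t, grad_outputs=torch.ones_like(S), create_graph=True)[0]
residual = dS_dt + 0.5 * S                        # residual of dS/dt = -0.5 * S
loss_physics = torch.mean(torch.square(residual))
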
@@ -284,7 +298,8 @@ class DINN:
             # calculate loss from the dataset
             loss_obs = 0
             for i, group in enumerate(self.data.group_names):
-                loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
+                loss_obs += torch.mean(torch.square(
+                    self.data.get_norm(group) - prediction[:, i]))
             loss_obs *= self.lambda_obs
 
             if do_split_training:
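
The observation loss is the mean squared error between each group's normalized data and the corresponding network output, weighted by lambda_obs; the physics loss is assembled the same way from the residuals. A compact stand-alone sketch of this weighting, with toy stand-ins for the real tensors:

import torch

prediction = torch.rand(100, 3, requires_grad=True)     # network output for 3 groups
observed = torch.rand(100, 3)                            # normalized observations
residuals = [0.1 * prediction[:, i] for i in range(3)]   # placeholder ODE residuals

lambda_obs, lambda_physics = 1.0, 1.0
loss_obs = lambda_obs * sum(torch.mean(torch.square(observed[:, i] - prediction[:, i]))
                            for i in range(3))
loss_physics = lambda_physics * sum(torch.mean(torch.square(r)) for r in residuals)
loss = loss_obs + loss_physics
loss.backward()
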
@@ -304,33 +319,17 @@ class DINN:
             self.obs_losses[epoch] = loss_obs.item()
             self.physics_losses[epoch] = loss_physics.item()
             for i, parameter in enumerate(self.parameters_tilda.items()):
-                self.parameters[i][epoch] = self.get_regulated_param(parameter[0]).item()
+                self.parameters[i][epoch] = self.get_regulated_param(
+                    parameter[0]).item()
 
             # do snapshot for prediction animation
             if epoch % animation_sample_rate == 0 and create_animation:
                 # prediction
-                """prediction = self.model(self.data.t_batch)
-                t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
-                groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
-
-                plot_labels = [name + '_pred' for name in self.data.group_names] + [name + '_true' for name in self.data.group_names]
-                background_list = [0 for _ in self.data.group_names] + [1 for _ in self.data.group_names]
-                self.plotter.plot(t,
-                                  list(groups) + list(self.data.data),
-                                  plot_labels,
-                                  'frame',
-                                  f'epoch {epoch}',
-                                  figure_shape=(12, 6),
-                                  is_frame=True,
-                                  is_background=background_list,
-                                  lw=3,
-                                  legend_loc='upper right',
-                                  ylim=(0, self.data.N),
-                                  xlabel='time / days',
-                                  ylabel='amount of people')"""
                 prediction = self.model(self.data.t_batch)
-                t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
-                groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
+                t = torch.arange(
+                    0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
+                groups = self.data.get_denormalized_data(
+                    [prediction[:, i] for i in range(self.data.number_groups)])
 
                 plot_labels = ['I_pred', 'I_true']
                 background_list = [0, 1]
@@ -349,14 +348,16 @@ class DINN:
 
             # print training advancements
             if epoch % 1000 == 0 and verbose:
-                print(f'\nEpoch {epoch} | LR {self.scheduler.get_last_lr()[0]}')
+                print(
+                    f'\nEpoch {epoch} | LR {self.scheduler.get_last_lr()[0]}')
                 print(f'physics loss:\t\t{loss_physics.item()}')
                 print(f'observation loss:\t{loss_obs.item()}')
                 print(f'loss:\t\t\t{loss.item()}')
                 print('---------------------------------')
                 if len(self.parameters_tilda.items()) != 0:
                     for parameter in self.parameters_tilda.items():
-                        print(f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
+                        print(
+                            f'{parameter[0]}:\t\t\t{self.parameter_regulator(parameter[1]).item()}')
                 print('#################################')
 
             # create prediction animation
@@ -364,25 +365,6 @@ class DINN:
             self.plotter.animate(self.data.name + '_animation')
             self.plotter.reset_animation()
 
-        if plot_I_prediction:
-            prediction = self.model(self.data.t_batch)
-            t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
-            groups = self.data.get_denormalized_data([prediction[:, i] for i in range(self.data.number_groups)])
-
-            plot_labels = ['I_pred', 'I_true']
-            background_list = [0, 1]
-            self.plotter.plot(t,
-                              [groups[1]] + [self.data.data[1]],
-                              plot_labels,
-                              'Training_I_prediction',
-                              f'Prediction of I on JH data',
-                              figure_shape=(12, 6),
-                              is_background=background_list,
-                              lw=3,
-                              legend_loc='upper right',
-                              xlabel='time / days',
-                              ylabel='amount of people')
-
         self.__has_trained = True
 
     def plot_training_graphs(self, ground_truth=[]):
@@ -395,7 +377,8 @@ class DINN:
         epochs = np.arange(0, self.epochs, 1)
 
         # plot loss
-        self.plotter.plot(epochs, [self.losses, self.obs_losses, self.physics_losses], ['loss', 'observation loss', 'physics loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=True, xlabel='epochs')
+        self.plotter.plot(epochs, [self.losses, self.obs_losses, self.physics_losses], ['loss', 'observation loss',
+                          'physics loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=True, xlabel='epochs')
 
         # plot parameters
         for i, parameter in enumerate(self.parameters):
@@ -404,7 +387,8 @@ class DINN:
                                   [parameter,
                                    np.ones_like(epochs) * ground_truth[i]],
                                   ['prediction', 'ground truth'],
-                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
+                                  self.data.name + '_' +
+                                  list(self.parameters_tilda.items())[i][0],
                                   list(self.parameters_tilda.items())[i][0],
                                   (6, 6),
                                   is_background=[0, 1],
@@ -413,8 +397,10 @@ class DINN:
                 self.plotter.plot(epochs,
                                   [parameter],
                                   ['prediction'],
-                                  self.data.name + '_' + list(self.parameters_tilda.items())[i][0],
-                                  list(self.parameters_tilda.items())[i][0], (6, 6),
+                                  self.data.name + '_' +
+                                  list(self.parameters_tilda.items())[i][0],
+                                  list(self.parameters_tilda.items())[
+                                      i][0], (6, 6),
                                   xlabel='epochs',
                                   plot_legend=False)
 
@@ -434,9 +420,11 @@ class DINN:
         if save_predictions:
             prediction = self.model(self.data.t_batch)
             for i, group in enumerate(self.data.group_names):
-                t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0]).detach().cpu().numpy()
+                t = torch.linspace(
+                    0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0]).detach().cpu().numpy()
                 true = self.data.get_group(group).detach().cpu().numpy()
-                pred = self.data.get_denormalized_data([prediction[:, i]])[0].detach().cpu().numpy()
+                pred = self.data.get_denormalized_data([prediction[:, i]])[
+                    0].detach().cpu().numpy()
                 print(t.shape, true.shape)
                 with open(f'./results/I_predictions/{title}_I_prediction.csv', 'w', newline='') as csvfile:
                     writer = csv.writer(csvfile, delimiter=',')
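
The export above converts the time grid, the raw group data, and the denormalized prediction to NumPy arrays and hands them to csv.writer; the exact row layout is not visible in this hunk. A minimal sketch of one plausible layout (file name and column order are assumptions):

import csv
import numpy as np

t = np.linspace(0, 150, 151)   # time grid in days (assumed)
true = np.random.rand(151)     # observed group values (stand-in)
pred = np.random.rand(151)     # denormalized prediction (stand-in)

with open('example_I_prediction.csv', 'w', newline='') as csvfile:
    writer = csv.writer(csvfile, delimiter=',')
    writer.writerow(['t', 'true', 'pred'])
    for row in zip(t, true, pred):
        writer.writerow(row)
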
@@ -447,12 +435,14 @@ class DINN:
     def plot_state_variables(self):
         prediction = self.model(self.data.t_batch)
         for i in range(self.data.number_groups, self.data.number_groups + self.number_state_variables):
-            t = torch.linspace(0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
+            t = torch.linspace(
+                0, self.data.t_raw[-1].item(), self.data.t_raw.shape[0])
             self.plotter.plot(t,
                               [prediction[:, i]],
                               [self.__state_variables[i - self.data.number_groups]],
                               f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
-                              self.__state_variables[i - self.data.number_groups],
+                              self.__state_variables[i -
+                                                     self.data.number_groups],
                               figure_shape=(12, 6),
                               plot_legend=True,
                               xlabel='time / days')