|
@@ -83,6 +83,8 @@ class DINN:
|
|
|
self.epochs = None
|
|
|
|
|
|
self.losses = np.zeros(1)
|
|
|
+ self.obs_losses = np.zeros(1)
|
|
|
+ self.physics_losses = np.zeros(1)
|
|
|
self.parameters = [np.zeros(1) for _ in range(len(parameter_list))]
|
|
|
|
|
|
self.frames = []
|
|
@@ -139,6 +141,8 @@ class DINN:
|
|
|
|
|
|
# arrays to hold values for plotting
|
|
|
self.losses = np.zeros(epochs)
|
|
|
+ self.obs_losses = np.zeros(epochs)
|
|
|
+ self.physics_losses = np.zeros(epochs)
|
|
|
self.parameters = [np.zeros(epochs) for _ in self.parameters]
|
|
|
|
|
|
for epoch in range(epochs):
|
|
@@ -155,7 +159,7 @@ class DINN:
|
|
|
|
|
|
# calculate loss from the dataset
|
|
|
loss_obs = 0
|
|
|
- for i, group in enumerate(self.data.get_group_names()):
|
|
|
+ for i, group in enumerate(self.data.group_names):
|
|
|
loss_obs += torch.mean(torch.square(self.data.get_norm(group) - prediction[:, i]))
|
|
|
|
|
|
loss = loss_physics + loss_obs
|
|
@@ -166,6 +170,8 @@ class DINN:
|
|
|
|
|
|
# append values for plotting
|
|
|
self.losses[epoch] = loss.item()
|
|
|
+ self.obs_losses[epoch] = loss_obs.item()
|
|
|
+ self.physics_losses[epoch] = loss_physics.item()
|
|
|
for i, parameter in enumerate(self.parameters_tilda.items()):
|
|
|
self.parameters[i][epoch] = self.get_regulated_param(parameter[0]).item()
|
|
|
|
|
@@ -174,11 +180,10 @@ class DINN:
|
|
|
# prediction
|
|
|
prediction = self.model(self.data.t_batch)
|
|
|
t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
|
|
|
-
|
|
|
- groups = self.problem.denormalization(prediction)
|
|
|
+ groups = self.data.get_denormalized_data([prediction[:, 0], prediction[:, 1], prediction[:, 2]])
|
|
|
self.plotter.plot(t,
|
|
|
- groups + tuple(self.data.get_data()),
|
|
|
- [name + '_pred' for name in self.data.get_group_names()] + [name + '_true' for name in self.data.get_group_names()],
|
|
|
+ list(groups) + list(self.data.data),
|
|
|
+ [name + '_pred' for name in self.data.group_names] + [name + '_true' for name in self.data.group_names],
|
|
|
'frame',
|
|
|
f'epoch {epoch}',
|
|
|
figure_shape=(12, 6),
|
|
@@ -215,7 +220,7 @@ class DINN:
|
|
|
epochs = np.arange(0, self.epochs, 1)
|
|
|
|
|
|
# plot loss
|
|
|
- self.plotter.plot(epochs, [self.losses], ['loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=False, xlabel='epochs')
|
|
|
+ self.plotter.plot(epochs, [self.losses, self.obs_losses, self.physics_losses], ['loss', 'observation loss', 'physics loss'], self.data.name + '_loss', 'Loss', (6, 6), y_log_scale=True, plot_legend=True, xlabel='epochs')
|
|
|
|
|
|
# plot parameters
|
|
|
for i, parameter in enumerate(self.parameters):
|