|
@@ -158,6 +158,8 @@ class DINN:
|
|
|
match scheduler_name:
|
|
|
case 'CyclicLR':
|
|
|
self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=lr * 10, max_lr=lr * 1e3, step_size_up=1000, mode="exp_range", gamma=0.85, cycle_momentum=False)
|
|
|
+ case 'ConstantLR':
|
|
|
+ self.scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1, total_iters=4)
|
|
|
case 'LinearLR':
|
|
|
self.scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=lr, total_iters=epochs/scheduler_factor)
|
|
|
case 'PolynomialLR':
|
|
@@ -302,16 +304,17 @@ class DINN:
|
|
|
plot_legend=False)
|
|
|
|
|
|
def plot_state_variables(self):
|
|
|
+ prediction = self.model(self.data.t_batch)
|
|
|
+ groups = [prediction[:, i] for i in range(self.data.number_groups)]
|
|
|
+ fore_background = [0] + [1 for _ in groups]
|
|
|
for i in range(self.data.number_groups, self.data.number_groups+self.number_state_variables):
|
|
|
- prediction = self.model(self.data.t_batch)
|
|
|
- groups = [prediction[:, i] for i in range(self.data.number_groups)]
|
|
|
t = torch.arange(0, self.data.t_raw[-1].item(), (self.data.t_raw[-1] / self.data.t_raw.shape[0]).item())
|
|
|
self.plotter.plot(t,
|
|
|
[prediction[:, i]] + groups,
|
|
|
[self.__state_variables[i-self.data.number_groups]] + self.data.group_names,
|
|
|
f'{self.data.name}_{self.__state_variables[i-self.data.number_groups]}',
|
|
|
self.__state_variables[i-self.data.number_groups],
|
|
|
- is_background=[0, 1, 1],
|
|
|
+ is_background=fore_background,
|
|
|
figure_shape=(12, 6),
|
|
|
plot_legend=True,
|
|
|
xlabel='time / days')
|