In [1]:
import torch
import numpy as np

from src.dataset import PandemicDataset, Norms
from src.problem import ReducedSIRProblem
from src.dinn import DINN, Scheduler, Activation
from src.plotter import Plotter

In [2]:
# Fixed model constant handed to ReducedSIRProblem when the problem is built;
# presumably the recovery rate (1/3 per day, i.e. a ~3-day infectious
# period) — confirm against the ReducedSIRProblem definition.
alpha = 1.0 / 3.0

In [3]:
# Reproducibility: without an explicit seed every run starts from a fresh
# random torch seed (the training log prints it), so network initialisation
# and therefore the fitted R_t differ between "Restart & Run All" executions.
# NOTE(review): if DINN reseeds torch internally this has no effect — confirm
# that the printed "torch seed" now equals SEED.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Run configuration, previously inline magic numbers.
POPULATION = 7.6e6      # presumably the total population N used to normalise I — confirm in PandemicDataset
LEARNING_RATE = 1e-3    # echoed by the training log as "Learning Rate"
EPOCHS = 20000          # presumably the number of training epochs — confirm in DINN.configure_training
LAMBDA_PHYSICS = 1e-6   # presumably the weight of the physics-residual loss; at this value the
                        # physics term is negligible next to the observation loss early in training

plotter = Plotter()

# The CSV rows are unpacked as separate positional arguments
# (presumably the time row followed by the I series — verify against the file).
# NOTE(review): the data is synthetic ('synth_sir'); the variable name is kept
# for compatibility with other cells.
covid_data = np.genfromtxt('./datasets/I_data.csv', delimiter=',')
dataset = PandemicDataset('synth_sir',
                          ['I'],
                          POPULATION,
                          *covid_data,
                          norm_name=Norms.CONSTANT,
                          use_scaled_time=True)

problem = ReducedSIRProblem(dataset, alpha)

dinn = DINN(2,          # presumably the number of network outputs (I and R_t) — confirm
            dataset,
            [],         # presumably no extra learnable scalar parameters — confirm
            problem,
            plotter,
            state_variables=['R_t'],
            hidden_size=100,
            hidden_layers=4,
            activation_layer=torch.nn.Tanh(),
            activation_output=Activation.POWER,
            use_glorot_initialization=True)
dinn.configure_training(LEARNING_RATE,
                        EPOCHS,
                        lambda_physics=LAMBDA_PHYSICS,
                        scheduler_class=Scheduler.POLYNOMIAL,
                        verbose=True)

dinn.train(create_animation=True, verbose=True, do_split_training=True)
dinn.plot_training_graphs()
dinn.plot_state_variables()
dinn.plot_state_variables()



Learning Rate:	0.001
Optimizer:	ADAM
Scheduler:	POLYNOMIAL

torch seed: 9948651162532304809
dIdt (min | max): -4.3066202124464326e-06 | -4.301004537410336e-06, I(min | max): 0.0019655212257585625 | 0.0019698153581950706, R_t(min | max): 0.9462157915068161 | 0.9471622551059369
I_residual (min | max): 0.00515376093227558 | 0.0052576320047351445


Epoch 0 | LR 0.00099995
physics loss:		4.064907813105731e-09
observation loss:	0.5382289321642364
loss:			0.5382289321642364
---------------------------------
dIdt (min | max): 0.029735210628132336 | 0.03037917883193586, I(min | max): 0.5932272512312444 | 0.6232452147125436, R_t(min | max): 0.8309403277604908 | 0.8333196479677305
I_residual (min | max): 5.011096179934845 | 5.189888191104072

dIdt (min | max): 0.04699786892160773 | 0.048385180183686316, I(min | max): 0.5854795313251806 | 0.6329777429673413, R_t(min | max): 1.996863750612249 | 2.020590986922116
I_residual (min | max): -31.290912309605147 | -29.630473406013607


Epoch 1000 | LR 0.

In [5]:
# Reference R_t curve, plotted below as "true" against the network prediction;
# presumably the curve used to generate the synthetic I_data.csv.
# Evaluated on the integer day grid t = 0, 1, ..., N_SYNTH_DAYS - 1:
#     R_t(t) = 1.35 - 0.4 * tanh(0.05 * t - 2)
# i.e. a smooth decline from ~1.74 towards ~0.95, passing 1.35 at t = 40.
# Vectorised over the whole grid instead of filling an array element by element.
N_SYNTH_DAYS = 150
synth_days = np.arange(N_SYNTH_DAYS, dtype=np.float64)
synth_r_t = 1.35 - 0.4 * np.tanh(0.05 * synth_days - 2)
# Network prediction for R_t. Output index 1 presumably corresponds to the
# 'R_t' state variable (index 0 being I) — confirm against DINN.get_output.
r_t = dinn.get_output(1).detach().cpu().numpy()
t_axis = dataset.t_raw.detach().cpu().numpy()

# The reference curve's length is hard-coded independently of the dataset, so
# check that time axis, prediction and reference line up before plotting; a
# mismatch otherwise surfaces as an opaque error from the plotting backend.
assert len(t_axis) == len(r_t) == len(synth_r_t), (
    f"length mismatch: t_raw={len(t_axis)}, "
    f"prediction={len(r_t)}, reference={len(synth_r_t)}"
)

plotter.plot(t_axis, [r_t, synth_r_t], ["pred", "true"], "test", "R_t", (12, 6))