import csv
import os

import numpy as np
import torch

from src.dataset import PandemicDataset, Norms
from src.dinn import DINN, Scheduler, Activation
from src.problem import ReducedSIRProblem
# Recovery rates to sweep: 1/14 (14-day infectious period) and 1/5 (5-day).
ALPHA = [1 / 14, 1 / 5]
# Normalization paired index-wise with ALPHA above.
NORM = [Norms.POPULATION, Norms.CONSTANT]
# Number of independent training repetitions per alpha (for run-to-run variance).
ITERATIONS = 10

# Output directory for the per-run R_t CSV files; create it up front so the
# final csv.writer call cannot fail after a full 25k-epoch training run.
os.makedirs('./results', exist_ok=True)

for iteration in range(ITERATIONS):
    for i, alpha in enumerate(ALPHA):
        print(f'training for Germany, alpha: {alpha}, iter: {iteration}')
        # round(), not int(): int() truncates, so a reciprocal that lands at
        # e.g. 13.999... would silently pick the wrong dataset file.
        covid_data = np.genfromtxt(
            f'./datasets/I_RKI_Germany_1_{round(1 / alpha)}.csv',
            delimiter=',')
        # Population of Germany: 83.1 million.
        dataset = PandemicDataset('Germany',
                                  ['I'],
                                  83100000,
                                  *covid_data,
                                  norm_name=NORM[i],
                                  C=10**6,
                                  use_scaled_time=True)
        problem = ReducedSIRProblem(dataset, alpha)
        # PINN with 4 hidden layers of 100 tanh units; single state variable
        # R_t with a power-activation output head.
        dinn = DINN(2,
                    dataset,
                    [],
                    problem,
                    None,
                    state_variables=['R_t'],
                    hidden_size=100,
                    hidden_layers=4,
                    activation_layer=torch.nn.Tanh(),
                    activation_output=Activation.POWER)
        # lr 1e-3 with polynomial decay over 25k epochs; observation loss is
        # weighted 10 orders of magnitude above the physics residual.
        dinn.configure_training(1e-3,
                                25000,
                                scheduler_class=Scheduler.POLYNOMIAL,
                                lambda_obs=1e4,
                                lambda_physics=1e-6,
                                verbose=True)
        # Split training: data-only fit first, physics loss joins at epoch 15000.
        dinn.train(verbose=True, do_split_training=True, start_split=15000)
        dinn.save_training_process(f'Germany_{i}_{iteration}')
        # Output index 1 is the learned R_t trajectory; dump it as one CSV row.
        r_t = dinn.get_output(1).detach().cpu().numpy()
        with open(f'./results/Germany_{i}_{iteration}.csv', 'w', newline='') as csvfile:
            writer = csv.writer(csvfile, delimiter=',')
            writer.writerow(r_t)