# germany_training.py
  1. import torch
  2. import numpy as np
  3. import csv
  4. from src.dataset import PandemicDataset, Norms
  5. from src.problem import ReducedSIRProblem
  6. from src.dinn import DINN, Scheduler, Activation
  7. ALPHA = [1/14, 1/5]
  8. NORM = [Norms.POPULATION, Norms.CONSTANT]
  9. ITERATIONS = 10
  10. for iteration in range(ITERATIONS):
  11. for i, alpha in enumerate(ALPHA):
  12. print(f'training for Germany, alpha: {alpha}, iter: {iteration}')
  13. covid_data = np.genfromtxt(f'./datasets/I_RKI_Germany_1_{int(1/alpha)}.csv', delimiter=',')
  14. dataset = PandemicDataset('Germany',
  15. ['I'],
  16. 83100000,
  17. *covid_data,
  18. norm_name=NORM[i],
  19. C=10**6,
  20. use_scaled_time=True)
  21. problem = ReducedSIRProblem(dataset, alpha)
  22. dinn = DINN(2,
  23. dataset,
  24. [],
  25. problem,
  26. None,
  27. state_variables=['R_t'],
  28. hidden_size=100,
  29. hidden_layers=4,
  30. activation_layer=torch.nn.Tanh(),
  31. activation_output=Activation.POWER)
  32. dinn.configure_training(1e-3,
  33. 25000,
  34. scheduler_class=Scheduler.POLYNOMIAL,
  35. lambda_obs=1e4,
  36. lambda_physics=1e-6,
  37. verbose=True)
  38. dinn.train(verbose=True, do_split_training=True, start_split=15000)
  39. dinn.save_training_process(f'Germany_{i}_{iteration}')
  40. r_t = dinn.get_output(1).detach().cpu().numpy()
  41. with open(f'./results/Germany_{i}_{iteration}.csv', 'w', newline='') as csvfile:
  42. writer = csv.writer(csvfile, delimiter=',')
  43. writer.writerow(r_t)