states_training.py

import csv
import sys

import numpy as np
import torch

from src.dataset import PandemicDataset, Norms
from src.dinn import DINN, Scheduler, Activation
from src.problem import ReducedSIRProblem
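
# Candidate recovery rates for the reduced SIR model; 1 / alpha is the assumed
# infectious period in days and selects the matching dataset file.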
ALPHA = [1 / 14, 1 / 5]
DO_STATES = True
DO_SYNTHETIC = False
ITERATIONS = 13
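
# When "1" is passed on the command line, start at the second half of the 16
# federal states, so two processes can split the workload between them.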
state_starting_index = 0
if "1" in sys.argv:
    state_starting_index = 8
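
# Population of each German federal state, used as the dataset's total
# population size.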
STATE_LOOKUP = {'Schleswig_Holstein': 2897000,
                'Hamburg': 1841000,
                'Niedersachsen': 7982000,
                'Bremen': 569352,
                'Nordrhein_Westfalen': 17930000,
                'Hessen': 6266000,
                'Rheinland_Pfalz': 4085000,
                'Baden_Wuerttemberg': 11070000,
                'Bayern': 13080000,
                'Saarland': 990509,
                'Berlin': 3645000,
                'Brandenburg': 2641000,
                'Mecklenburg_Vorpommern': 1610000,
                'Sachsen': 4078000,
                'Sachsen_Anhalt': 2208000,
                'Thueringen': 2143000}
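
# Optional run on synthetic infection data with a fixed recovery rate of 1/3.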
if DO_SYNTHETIC:
    alpha = 1 / 3
    covid_data = np.genfromtxt('./datasets/I_data.csv', delimiter=',')
    for i in range(ITERATIONS):
        dataset = PandemicDataset('Synthetic I',
                                  ['I'],
                                  7.6e6,
                                  *covid_data,
                                  norm_name=Norms.CONSTANT,
                                  use_scaled_time=True)
        problem = ReducedSIRProblem(dataset, alpha)
        dinn = DINN(2,
                    dataset,
                    [],
                    problem,
                    None,
                    state_variables=['R_t'],
                    hidden_size=100,
                    hidden_layers=4,
                    activation_layer=torch.nn.Tanh(),
                    activation_output=Activation.POWER)
        dinn.configure_training(1e-3,
                                20000,
                                scheduler_class=Scheduler.POLYNOMIAL,
                                lambda_physics=1e-6,
                                verbose=True)
        dinn.train(verbose=True, do_split_training=True)
        dinn.save_training_process(f'synthetic_{i}')
        # r_t = dinn.get_output(1).detach().cpu().numpy()
        # with open(f'./results/synthetic_{i}.csv', 'w', newline='') as csvfile:
        #     writer = csv.writer(csvfile, delimiter=',')
        #     writer.writerow(r_t)
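
# Per-state training: each remaining iteration trains all eight assigned
# states with both candidate recovery rates.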
for iteration in range(ITERATIONS):
    if iteration <= 2:
        print('skipping the first three iterations, as they were already done')
        continue
    if DO_STATES:
        for state_idx in range(state_starting_index, state_starting_index + 8):
            state = list(STATE_LOOKUP.keys())[state_idx]
            exclude = ['Schleswig_Holstein', 'Hamburg', 'Niedersachsen']
            if iteration == 3 and state in exclude:
                print(f'skipping {state} in the third iteration, as it was already done')
                continue
            for i, alpha in enumerate(ALPHA):
                print(f'training for {state} ({state_idx}), alpha: {alpha}, iter: {iteration}')
                covid_data = np.genfromtxt(f'./datasets/I_RKI_{state}_1_{int(1 / alpha)}.csv', delimiter=',')
                dataset = PandemicDataset(state,
                                          ['I'],
                                          STATE_LOOKUP[state],
                                          *covid_data,
                                          norm_name=Norms.CONSTANT,
                                          use_scaled_time=True)
                problem = ReducedSIRProblem(dataset, alpha)
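                # Physics-informed network (DINN) whose learned state variable
                # is the time-dependent reproduction number R_t.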
                dinn = DINN(2,
                            dataset,
                            [],
                            problem,
                            None,
                            state_variables=['R_t'],
                            hidden_size=100,
                            hidden_layers=4,
                            activation_layer=torch.nn.Tanh(),
                            activation_output=Activation.POWER)
                dinn.configure_training(1e-3,
                                        25000,
                                        scheduler_class=Scheduler.POLYNOMIAL,
                                        lambda_obs=1e2,
                                        lambda_physics=1e-6,
                                        verbose=True)
                dinn.train(verbose=True, do_split_training=True)
                dinn.save_training_process(f'{state}_{i}_{iteration}')
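                # Export the learned R_t trajectory (network output 1) as a
                # single CSV row.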
                r_t = dinn.get_output(1).detach().cpu().numpy()
                with open(f'./results/{state}_{i}_{iteration}.csv', 'w', newline='') as csvfile:
                    writer = csv.writer(csvfile, delimiter=',')
                    writer.writerow(r_t)