Browse Source

seperatly train R_t for Germany

phillip.rothenbeck 4 months ago
parent
commit
c005bdab3e
1 changed files with 52 additions and 0 deletions
  1. 52 0
      germany_training.py

+ 52 - 0
germany_training.py

@@ -0,0 +1,52 @@
+import torch
+import numpy as np
+import csv
+
+from src.dataset import PandemicDataset, Norms
+from src.problem import ReducedSIRProblem
+from src.dinn import DINN, Scheduler, Activation
+
+ALPHA = [1/14, 1/5]
+NORM = [Norms.POPULATION, Norms.CONSTANT]
+
+ITERATIONS = 10
+
+for iteration in range(ITERATIONS):
+    for i, alpha in enumerate(ALPHA):
+        print(f'training for Germany, alpha: {alpha}, iter: {iteration}')
+
+        covid_data = np.genfromtxt(f'./datasets/I_RKI_Germany_1_{int(1/alpha)}.csv', delimiter=',')
+        dataset = PandemicDataset('Germany', 
+                                  ['I'], 
+                                  83100000, 
+                                  *covid_data, 
+                                  norm_name=NORM[i], 
+                                  C=10**6, 
+                                  use_scaled_time=True)
+        problem = ReducedSIRProblem(dataset, alpha)
+
+        dinn = DINN(2, 
+                    dataset, 
+                    [], 
+                    problem, 
+                    None, 
+                    state_variables=['R_t'], 
+                    hidden_size=100, 
+                    hidden_layers=4, 
+                    activation_layer=torch.nn.Tanh(),
+                    activation_output=Activation.POWER)
+
+        dinn.configure_training(1e-3, 
+                                25000, 
+                                scheduler_class=Scheduler.POLYNOMIAL, 
+                                lambda_obs=1e4,
+                                lambda_physics=1e-6, 
+                                verbose=True)
+        dinn.train(verbose=True, do_split_training=True, start_split=15000)
+
+        dinn.save_training_process(f'Germany_{i}_{iteration}')
+
+        r_t = dinn.get_output(1).detach().cpu().numpy()
+        with open(f'./results/Germany_{i}_{iteration}.csv', 'w', newline='') as csvfile:
+            writer = csv.writer(csvfile, delimiter=',')
+            writer.writerow(r_t)