In [1]:
import csv
import os

import numpy as np

from src.dataset import PandemicDataset
from src.problem import SIRProblem
from src.dinn import DINN, Scheduler
from src.plotter import Plotter

In [2]:
# Region name -> head count. The value is passed as the third argument of
# PandemicDataset in the training cell; presumably this is the total population N
# of the SIR model (TODO confirm against PandemicDataset).
state_lookup = dict(
    Schleswig_Holstein=2_897_000,
    Hamburg=1_841_000,
    Niedersachsen=7_982_000,
    Bremen=569_352,
    Nordrhein_Westfalen=17_930_000,
    Hessen=6_266_000,
    Rheinland_Pfalz=4_085_000,
    Baden_Wuerttemberg=11_070_000,
    Bayern=13_080_000,
    Saarland=990_509,
    Berlin=3_645_000,
    Brandenburg=2_641_000,
    Mecklenburg_Vorpommern=1_610_000,
    Sachsen=4_078_000,
    Sachsen_Anhalt=2_208_000,
    Thueringen=2_143_000,
    Germany=83_100_000,
)

In [3]:
def get_error(y, y_ref, n_compartments=3):
    """Mean relative L2 error of several prediction runs against a reference.

    For every run ``y[i]`` the relative error of compartment ``k`` is
    ``||(y[i] - y_ref)[k]|| / ||y_ref[k]||``. These are averaged over the first
    ``n_compartments`` compartments of the run, and the per-run values are then
    averaged over all runs.

    Args:
        y: Sequence of prediction runs. Each run must be subtractable from
            ``y_ref`` (presumably shape (n_compartments, T) — TODO confirm).
        y_ref: Reference data, indexed by compartment along the first axis.
        n_compartments: Number of leading compartments to compare. The default
            of 3 matches the S, I, R compartments trained in this notebook.

    Returns:
        The mean relative error as a numpy float.

    Raises:
        ValueError: If ``y`` holds no runs or ``n_compartments`` < 1, because
            the mean would be undefined.
    """
    if len(y) == 0:
        raise ValueError("get_error needs at least one prediction run")
    if n_compartments < 1:
        raise ValueError("n_compartments must be at least 1")

    y_ref = np.asarray(y_ref)
    # The reference norms do not depend on the run, so compute them once.
    ref_norms = [np.linalg.norm(y_ref[k]) for k in range(n_compartments)]

    run_errors = []
    for run in y:
        # Subtract whole arrays so broadcasting behaves exactly like `y[i] - y_ref`.
        diff = np.asarray(run) - y_ref
        run_errors.append(
            np.mean([np.linalg.norm(diff[k]) / ref_norms[k] for k in range(n_compartments)])
        )
    return np.mean(run_errors)

In [4]:
# Train the SIR DINN N_RUNS times for every selected state, record the learned
# (alpha, beta) of each run, and print the mean relative error of the
# denormalized predictions against the data.
# TODO: training is stochastic and no seed is set, so results differ between
# kernel restarts; seed the RNG(s) used by DINN for reproducibility.
N_RUNS = 5
LEARNING_RATE = 1e-3  # presumably the optimizer learning rate — confirm against DINN.configure_training
N_EPOCHS = 10000      # presumably the number of training epochs — confirm against DINN.configure_training

state_params = {}
states = ["Bremen"]
for state in state_lookup.keys():
    if state not in states:
        continue
    predictions = []
    covid_data = np.genfromtxt(f'./datasets/SIR_RKI_{state}_1_14.csv', delimiter=',')
    state_params.setdefault(state, [])
    for i in range(N_RUNS):
        print(state, i)

        dataset = PandemicDataset(state, ['S', 'I', 'R'], state_lookup[state], *covid_data)
        problem = SIRProblem(dataset)
        plotter = Plotter()

        dinn = DINN(3, dataset, ['alpha', 'beta'], problem, plotter)
        dinn.configure_training(LEARNING_RATE, N_EPOCHS, scheduler_class=Scheduler.POLYNOMIAL)
        dinn.train(create_animation=True)

        # NOTE(review): this name does not include the run index `i`, so each run
        # presumably overwrites the previous run's saved training process — confirm intended.
        dinn.save_training_process(f"SIR_{state}", save_predictions=False)
        state_params[state].append((dinn.get_regulated_param('alpha').item(),
                                    dinn.get_regulated_param('beta').item()))
        pred = (dinn.get_output(0),
                dinn.get_output(1),
                dinn.get_output(2))
        predictions.append([d.detach().cpu().numpy() for d in dataset.get_denormalized_data(pred)])
    # covid_data[1:] drops the first row of the loaded array (presumably the time axis — TODO confirm)
    # and is compared against the S, I, R predictions.
    print(state, "&", '{0:.4f}'.format(get_error(np.array(predictions), covid_data[1:])))


Bremen 0
Bremen 1
Bremen 2
Bremen 3
Bremen 4
Bremen & 0.0910


In [5]:
# Persist the learned (alpha, beta) of every training run: one CSV per trained
# state, one row per run appended in the training cell.
# open(..., 'w') raises FileNotFoundError if the directory is missing, so create it first.
os.makedirs('./results', exist_ok=True)
for state in state_lookup.keys():
    if state not in states:
        continue
    state_matrix = np.array(state_params[state])
    with open(f'./results/{state}_parameters.csv', 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter=',')
        writer.writerows(state_matrix)