{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SIR parameter fitting with a DINN\n",
    "\n",
    "For every region listed in `states`, this notebook trains a `DINN` on `./datasets/SIR_RKI_<state>_1_14.csv` `N_RUNS` times. It then prints the mean relative error of the predicted S, I, R curves as `state & error` and writes the learned `alpha` / `beta` of every run to `./results/<state>_parameters.csv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from src.dataset import PandemicDataset\n",
    "from src.problem import SIRProblem\n",
    "from src.dinn import DINN, Scheduler\n",
    "from src.plotter import Plotter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "`state_lookup` maps each region to its population figure, which is passed to `PandemicDataset`. Everything a reader might want to tune lives in this cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "state_lookup = {\n",
    "    'Schleswig_Holstein': 2897000,\n",
    "    'Hamburg': 1841000,\n",
    "    'Niedersachsen': 7982000,\n",
    "    'Bremen': 569352,\n",
    "    'Nordrhein_Westfalen': 17930000,\n",
    "    'Hessen': 6266000,\n",
    "    'Rheinland_Pfalz': 4085000,\n",
    "    'Baden_Wuerttemberg': 11070000,\n",
    "    'Bayern': 13080000,\n",
    "    'Saarland': 990509,\n",
    "    'Berlin': 3645000,\n",
    "    'Brandenburg': 2641000,\n",
    "    'Mecklenburg_Vorpommern': 1610000,\n",
    "    'Sachsen': 4078000,\n",
    "    'Sachsen_Anhalt': 2208000,\n",
    "    'Thueringen': 2143000,\n",
    "    'Germany': 83100000,\n",
    "}\n",
    "\n",
    "# Regions to fit; every entry must be a key of `state_lookup`.\n",
    "states = [\"Bremen\"]\n",
    "# Independent training runs per region.\n",
    "N_RUNS = 5\n",
    "# Positional arguments of `DINN.configure_training`\n",
    "# (presumably learning rate and number of epochs -- confirm in src/dinn.py).\n",
    "LEARNING_RATE = 1e-3\n",
    "EPOCHS = 10000\n",
    "DATA_DIR = './datasets'\n",
    "RESULTS_DIR = './results'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Error metric\n",
    "\n",
    "For each run and each compartment $k$ the relative error is $\\lVert \\hat{y}_k - y_k \\rVert_2 / \\lVert y_k \\rVert_2$. These values are averaged over compartments and then over runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_error(y, y_ref):\n",
    "    \"\"\"Mean relative L2 error of several predictions against a reference.\n",
    "\n",
    "    For each prediction ``y[i]`` and each compartment ``k`` (row of\n",
    "    ``y_ref``), the relative error ``||y[i][k] - y_ref[k]|| / ||y_ref[k]||``\n",
    "    is computed. Errors are averaged over compartments, then over predictions.\n",
    "    A reference row with zero norm makes the result inf/nan.\n",
    "\n",
    "    Args:\n",
    "        y: sequence of predictions, each with the same layout as ``y_ref``.\n",
    "        y_ref: reference data, one row per compartment.\n",
    "\n",
    "    Returns:\n",
    "        The mean relative error (NumPy float).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``y`` contains no predictions.\n",
    "    \"\"\"\n",
    "    y_ref = np.asarray(y_ref)\n",
    "    if len(y) == 0:\n",
    "        raise ValueError(\"get_error needs at least one prediction\")\n",
    "    run_errors = []\n",
    "    for prediction in y:\n",
    "        diff = np.asarray(prediction) - y_ref\n",
    "        compartment_errors = [np.linalg.norm(diff[k]) / np.linalg.norm(y_ref[k])\n",
    "                              for k in range(len(y_ref))]\n",
    "        run_errors.append(np.mean(compartment_errors))\n",
    "    return np.mean(run_errors)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "Each run builds a fresh dataset, problem and `DINN`, then trains the model. It records the regulated `alpha` and `beta` together with the denormalized S, I, R predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bremen 0\n",
      "Bremen 1\n",
      "Bremen 2\n",
      "Bremen 3\n",
      "Bremen 4\n",
      "Bremen & 0.0910\n"
     ]
    }
   ],
   "source": [
    "state_params = {}\n",
    "for state in states:\n",
    "    # Learned (alpha, beta) pairs, one per run, for this region.\n",
    "    state_params.setdefault(state, [])\n",
    "    predictions = []\n",
    "    # All rows are passed positionally to PandemicDataset; every row after the\n",
    "    # first is the reference that get_error compares the predictions against.\n",
    "    covid_data = np.genfromtxt(f'{DATA_DIR}/SIR_RKI_{state}_1_14.csv', delimiter=',')\n",
    "    for i in range(N_RUNS):\n",
    "        print(state, i)\n",
    "\n",
    "        dataset = PandemicDataset(state, ['S', 'I', 'R'], state_lookup[state], *covid_data)\n",
    "        problem = SIRProblem(dataset)\n",
    "        plotter = Plotter()\n",
    "        dinn = DINN(3, dataset, ['alpha', 'beta'], problem, plotter)\n",
    "\n",
    "        dinn.configure_training(LEARNING_RATE, EPOCHS, scheduler_class=Scheduler.POLYNOMIAL)\n",
    "        dinn.train(create_animation=True)\n",
    "\n",
    "        # NOTE(review): every run saves under the same name, so presumably only\n",
    "        # the last run's training process survives -- confirm this is intended.\n",
    "        dinn.save_training_process(f\"SIR_{state}\", save_predictions=False)\n",
    "        state_params[state].append((dinn.get_regulated_param('alpha').item(),\n",
    "                                    dinn.get_regulated_param('beta').item()))\n",
    "        pred = (dinn.get_output(0), dinn.get_output(1), dinn.get_output(2))\n",
    "        predictions.append([d.detach().cpu().numpy() for d in dataset.get_denormalized_data(pred)])\n",
    "    print(state, \"&\", '{0:.4f}'.format(get_error(np.array(predictions), covid_data[1:])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save learned parameters\n",
    "\n",
    "One CSV per region; each row holds `alpha, beta` from one run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output directory so a fresh checkout does not fail on open().\n",
    "os.makedirs(RESULTS_DIR, exist_ok=True)\n",
    "for state in states:\n",
    "    with open(f'{RESULTS_DIR}/{state}_parameters.csv', 'w', newline='') as csvfile:\n",
    "        writer = csv.writer(csvfile, delimiter=',')\n",
    "        writer.writerows(state_params[state])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PINN",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}