phillip.rothenbeck преди 4 месеца
родител
ревизия
0cdc965e59
променени са 1 файла, в които са добавени 268 реда и са изтрити 0 реда
  1. 268 0
      paper_sir_dinn.ipynb

+ 268 - 0
paper_sir_dinn.ipynb

@@ -0,0 +1,268 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import torch\n",
+    "\n",
+    "from src.dataset import PandemicDataset, Norms\n",
+    "from src.problem import SIRAlphaProblem\n",
+    "from src.dinn import DINN, Scheduler, Activation, Optimizer\n",
+    "from src.plotter import Plotter\n",
+    "# torch.manual_seed(18361969907809597111)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Load Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "covid_data = np.genfromtxt(f'./datasets/SIR_Paper_Germany_14.csv', delimiter=',')\n",
+    "dataset = PandemicDataset(\"SIR_Paper_Germany_14\", \n",
+    "                          ['S', 'I', 'R'], \n",
+    "                          70000000, \n",
+    "                          *covid_data, \n",
+    "                          norm_name=Norms.MIN_MAX) \n",
+    "\n",
+    "ALPHA = 0.07"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Configure"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Learning Rate:\t0.001\n",
+      "Optimizer:\tADAM\n",
+      "Scheduler:\tPOLYNOMIAL\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "plotter = Plotter()\n",
+    "\n",
+    "problem = SIRAlphaProblem(dataset, ALPHA)\n",
+    "dinn = DINN(3, \n",
+    "            dataset, \n",
+    "            ['beta'], \n",
+    "            problem, \n",
+    "            plotter,\n",
+    "            hidden_size=64, \n",
+    "            hidden_layers=12, \n",
+    "            activation_layer=torch.nn.Tanh(),\n",
+    "            activation_output=Activation.POWER)\n",
+    "dinn.configure_training(1e-3, \n",
+    "                        15000, \n",
+    "                        scheduler_class=Scheduler.POLYNOMIAL, \n",
+    "                        optimizer_class=Optimizer.ADAM,\n",
+    "                        lambda_obs=1e1, \n",
+    "                        verbose=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch seed: 4462685050789905541\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Epoch 0 | LR 0.0009999333333333333\n",
+      "physics loss:\t\t4.160715663590364e-05\n",
+      "observation loss:\t9.556413235291215\n",
+      "loss:\t\t\t9.55645484244785\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.6011705994606018\n",
+      "#################################\n",
+      "\n",
+      "Epoch 1000 | LR 0.0009332666666666524\n",
+      "physics loss:\t\t0.0009664472242609449\n",
+      "observation loss:\t0.0017112362767341596\n",
+      "loss:\t\t\t0.0026776835009951045\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.2550368905067444\n",
+      "#################################\n",
+      "\n",
+      "Epoch 2000 | LR 0.0008665999999999833\n",
+      "physics loss:\t\t0.001247615250807723\n",
+      "observation loss:\t0.0040488860826972005\n",
+      "loss:\t\t\t0.005296501333504924\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.253487229347229\n",
+      "#################################\n",
+      "\n",
+      "Epoch 3000 | LR 0.000799933333333315\n",
+      "physics loss:\t\t0.00044841892062854347\n",
+      "observation loss:\t0.012716515333301339\n",
+      "loss:\t\t\t0.013164934253929882\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.25267982482910156\n",
+      "#################################\n",
+      "\n",
+      "Epoch 4000 | LR 0.000733266666666649\n",
+      "physics loss:\t\t0.001045761999967263\n",
+      "observation loss:\t0.00025607909237615716\n",
+      "loss:\t\t\t0.00130184109234342\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.24764902889728546\n",
+      "#################################\n",
+      "\n",
+      "Epoch 5000 | LR 0.0006665999999999824\n",
+      "physics loss:\t\t0.0014444880255704616\n",
+      "observation loss:\t0.00014856713535789037\n",
+      "loss:\t\t\t0.001593055160928352\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.2412654012441635\n",
+      "#################################\n",
+      "\n",
+      "Epoch 6000 | LR 0.0005999333333333129\n",
+      "physics loss:\t\t0.001605141155981103\n",
+      "observation loss:\t0.00013662428319793931\n",
+      "loss:\t\t\t0.0017417654391790425\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.23936264216899872\n",
+      "#################################\n",
+      "\n",
+      "Epoch 7000 | LR 0.0005332666666666435\n",
+      "physics loss:\t\t0.00190968670757966\n",
+      "observation loss:\t0.00012329566674995849\n",
+      "loss:\t\t\t0.0020329823743296185\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.23441843688488007\n",
+      "#################################\n",
+      "\n",
+      "Epoch 8000 | LR 0.0004665999999999755\n",
+      "physics loss:\t\t0.0019683097313949755\n",
+      "observation loss:\t0.00011957883463252119\n",
+      "loss:\t\t\t0.0020878885660274966\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.2345646470785141\n",
+      "#################################\n",
+      "\n",
+      "Epoch 9000 | LR 0.0003999333333333152\n",
+      "physics loss:\t\t0.0023533661719838616\n",
+      "observation loss:\t0.00011510529573689898\n",
+      "loss:\t\t\t0.0024684714677207604\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.22906526923179626\n",
+      "#################################\n",
+      "\n",
+      "Epoch 10000 | LR 0.00033326666666665735\n",
+      "physics loss:\t\t0.0025831233331882735\n",
+      "observation loss:\t0.00011076522614630127\n",
+      "loss:\t\t\t0.002693888559334575\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.22521229088306427\n",
+      "#################################\n",
+      "\n",
+      "Epoch 11000 | LR 0.00026659999999999987\n",
+      "physics loss:\t\t0.002667692420961139\n",
+      "observation loss:\t0.00010629412421269853\n",
+      "loss:\t\t\t0.0027739865451738375\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.22304710745811462\n",
+      "#################################\n",
+      "\n",
+      "Epoch 12000 | LR 0.0001999333333333354\n",
+      "physics loss:\t\t0.002831200195029232\n",
+      "observation loss:\t0.00015541454626586456\n",
+      "loss:\t\t\t0.0029866147412950968\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.2231626659631729\n",
+      "#################################\n",
+      "\n",
+      "Epoch 13000 | LR 0.0001332666666666689\n",
+      "physics loss:\t\t0.002496475121496725\n",
+      "observation loss:\t9.74420291549779e-05\n",
+      "loss:\t\t\t0.002593917150651703\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.22342386841773987\n",
+      "#################################\n",
+      "\n",
+      "Epoch 14000 | LR 6.660000000000055e-05\n",
+      "physics loss:\t\t0.0024228574550409178\n",
+      "observation loss:\t9.359204257759588e-05\n",
+      "loss:\t\t\t0.0025164494976185137\n",
+      "---------------------------------\n",
+      "beta:\t\t\t0.2237681895494461\n",
+      "#################################\n"
+     ]
+    }
+   ],
+   "source": [
+    "dinn.train(create_animation=True, \n",
+    "           plot_I_prediction=True, \n",
+    "           verbose=True) "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dinn.plot_training_graphs()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "PINN",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}