|
@@ -0,0 +1,268 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 6,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import numpy as np\n",
|
|
|
+ "import torch\n",
|
|
|
+ "\n",
|
|
|
+ "from src.dataset import PandemicDataset, Norms\n",
|
|
|
+ "from src.problem import SIRAlphaProblem\n",
|
|
|
+ "from src.dinn import DINN, Scheduler, Activation, Optimizer\n",
|
|
|
+ "from src.plotter import Plotter\n",
|
|
|
+ "# torch.manual_seed(18361969907809597111)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "### Load Data"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 7,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "covid_data = np.genfromtxt(f'./datasets/SIR_Paper_Germany_14.csv', delimiter=',')\n",
|
|
|
+ "dataset = PandemicDataset(\"SIR_Paper_Germany_14\", \n",
|
|
|
+ " ['S', 'I', 'R'], \n",
|
|
|
+ " 70000000, \n",
|
|
|
+ " *covid_data, \n",
|
|
|
+ " norm_name=Norms.MIN_MAX) \n",
|
|
|
+ "\n",
|
|
|
+ "ALPHA = 0.07"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "### Configure"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 8,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\n",
|
|
|
+ "Learning Rate:\t0.001\n",
|
|
|
+ "Optimizer:\tADAM\n",
|
|
|
+ "Scheduler:\tPOLYNOMIAL\n",
|
|
|
+ "\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "plotter = Plotter()\n",
|
|
|
+ "\n",
|
|
|
+ "problem = SIRAlphaProblem(dataset, ALPHA)\n",
|
|
|
+ "dinn = DINN(3, \n",
|
|
|
+ " dataset, \n",
|
|
|
+ " ['beta'], \n",
|
|
|
+ " problem, \n",
|
|
|
+ " plotter,\n",
|
|
|
+ " hidden_size=64, \n",
|
|
|
+ " hidden_layers=12, \n",
|
|
|
+ " activation_layer=torch.nn.Tanh(),\n",
|
|
|
+ " activation_output=Activation.POWER)\n",
|
|
|
+ "dinn.configure_training(1e-3, \n",
|
|
|
+ " 15000, \n",
|
|
|
+ " scheduler_class=Scheduler.POLYNOMIAL, \n",
|
|
|
+ " optimizer_class=Optimizer.ADAM,\n",
|
|
|
+ " lambda_obs=1e1, \n",
|
|
|
+ " verbose=True)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": []
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 9,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "torch seed: 4462685050789905541\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\n",
|
|
|
+ "Epoch 0 | LR 0.0009999333333333333\n",
|
|
|
+ "physics loss:\t\t4.160715663590364e-05\n",
|
|
|
+ "observation loss:\t9.556413235291215\n",
|
|
|
+ "loss:\t\t\t9.55645484244785\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.6011705994606018\n",
|
|
|
+ "#################################\n",
|
|
|
+ "\n",
|
|
|
+ "Epoch 1000 | LR 0.0009332666666666524\n",
|
|
|
+ "physics loss:\t\t0.0009664472242609449\n",
|
|
|
+ "observation loss:\t0.0017112362767341596\n",
|
|
|
+ "loss:\t\t\t0.0026776835009951045\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.2550368905067444\n",
|
|
|
+ "#################################\n",
|
|
|
+ "\n",
|
|
|
+ "Epoch 2000 | LR 0.0008665999999999833\n",
|
|
|
+ "physics loss:\t\t0.001247615250807723\n",
|
|
|
+ "observation loss:\t0.0040488860826972005\n",
|
|
|
+ "loss:\t\t\t0.005296501333504924\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.253487229347229\n",
|
|
|
+ "#################################\n",
|
|
|
+ "\n",
|
|
|
+ "Epoch 3000 | LR 0.000799933333333315\n",
|
|
|
+ "physics loss:\t\t0.00044841892062854347\n",
|
|
|
+ "observation loss:\t0.012716515333301339\n",
|
|
|
+ "loss:\t\t\t0.013164934253929882\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.25267982482910156\n",
|
|
|
+ "#################################\n",
|
|
|
+ "\n",
|
|
|
+ "Epoch 4000 | LR 0.000733266666666649\n",
|
|
|
+ "physics loss:\t\t0.001045761999967263\n",
|
|
|
+ "observation loss:\t0.00025607909237615716\n",
|
|
|
+ "loss:\t\t\t0.00130184109234342\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.24764902889728546\n",
|
|
|
+ "#################################\n",
|
|
|
+ "\n",
|
|
|
+ "Epoch 5000 | LR 0.0006665999999999824\n",
|
|
|
+ "physics loss:\t\t0.0014444880255704616\n",
|
|
|
+ "observation loss:\t0.00014856713535789037\n",
|
|
|
+ "loss:\t\t\t0.001593055160928352\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.2412654012441635\n",
|
|
|
+ "#################################\n",
|
|
|
+ "\n",
|
|
|
+ "Epoch 6000 | LR 0.0005999333333333129\n",
|
|
|
+ "physics loss:\t\t0.001605141155981103\n",
|
|
|
+ "observation loss:\t0.00013662428319793931\n",
|
|
|
+ "loss:\t\t\t0.0017417654391790425\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.23936264216899872\n",
|
|
|
+ "#################################\n",
|
|
|
+ "\n",
|
|
|
+ "Epoch 7000 | LR 0.0005332666666666435\n",
|
|
|
+ "physics loss:\t\t0.00190968670757966\n",
|
|
|
+ "observation loss:\t0.00012329566674995849\n",
|
|
|
+ "loss:\t\t\t0.0020329823743296185\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.23441843688488007\n",
|
|
|
+ "#################################\n",
|
|
|
+ "\n",
|
|
|
+ "Epoch 8000 | LR 0.0004665999999999755\n",
|
|
|
+ "physics loss:\t\t0.0019683097313949755\n",
|
|
|
+ "observation loss:\t0.00011957883463252119\n",
|
|
|
+ "loss:\t\t\t0.0020878885660274966\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.2345646470785141\n",
|
|
|
+ "#################################\n",
|
|
|
+ "\n",
|
|
|
+ "Epoch 9000 | LR 0.0003999333333333152\n",
|
|
|
+ "physics loss:\t\t0.0023533661719838616\n",
|
|
|
+ "observation loss:\t0.00011510529573689898\n",
|
|
|
+ "loss:\t\t\t0.0024684714677207604\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.22906526923179626\n",
|
|
|
+ "#################################\n",
|
|
|
+ "\n",
|
|
|
+ "Epoch 10000 | LR 0.00033326666666665735\n",
|
|
|
+ "physics loss:\t\t0.0025831233331882735\n",
|
|
|
+ "observation loss:\t0.00011076522614630127\n",
|
|
|
+ "loss:\t\t\t0.002693888559334575\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.22521229088306427\n",
|
|
|
+ "#################################\n",
|
|
|
+ "\n",
|
|
|
+ "Epoch 11000 | LR 0.00026659999999999987\n",
|
|
|
+ "physics loss:\t\t0.002667692420961139\n",
|
|
|
+ "observation loss:\t0.00010629412421269853\n",
|
|
|
+ "loss:\t\t\t0.0027739865451738375\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.22304710745811462\n",
|
|
|
+ "#################################\n",
|
|
|
+ "\n",
|
|
|
+ "Epoch 12000 | LR 0.0001999333333333354\n",
|
|
|
+ "physics loss:\t\t0.002831200195029232\n",
|
|
|
+ "observation loss:\t0.00015541454626586456\n",
|
|
|
+ "loss:\t\t\t0.0029866147412950968\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.2231626659631729\n",
|
|
|
+ "#################################\n",
|
|
|
+ "\n",
|
|
|
+ "Epoch 13000 | LR 0.0001332666666666689\n",
|
|
|
+ "physics loss:\t\t0.002496475121496725\n",
|
|
|
+ "observation loss:\t9.74420291549779e-05\n",
|
|
|
+ "loss:\t\t\t0.002593917150651703\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.22342386841773987\n",
|
|
|
+ "#################################\n",
|
|
|
+ "\n",
|
|
|
+ "Epoch 14000 | LR 6.660000000000055e-05\n",
|
|
|
+ "physics loss:\t\t0.0024228574550409178\n",
|
|
|
+ "observation loss:\t9.359204257759588e-05\n",
|
|
|
+ "loss:\t\t\t0.0025164494976185137\n",
|
|
|
+ "---------------------------------\n",
|
|
|
+ "beta:\t\t\t0.2237681895494461\n",
|
|
|
+ "#################################\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "dinn.train(create_animation=True, \n",
|
|
|
+ " plot_I_prediction=True, \n",
|
|
|
+ " verbose=True) "
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 10,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "dinn.plot_training_graphs()"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "PINN",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.11.7"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 2
|
|
|
+}
|