{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Estimating `beta` with a DINN on `SIR_Paper_Germany_14`\n",
    "\n",
    "Fits a DINN to the `SIR_Paper_Germany_14` S/I/R data using `SIRAlphaProblem` with a fixed `ALPHA`, and learns the parameter `beta`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from src.dataset import PandemicDataset, Norms\n",
    "from src.problem import SIRAlphaProblem\n",
    "from src.dinn import DINN, Scheduler, Activation, Optimizer\n",
    "from src.plotter import Plotter\n",
    "\n",
    "# Seed torch's RNG so a fresh Restart & Run All is reproducible.\n",
    "# NOTE(review): the saved training output below was produced with a different seed (it reports 4462685050789905541);\n",
    "# also confirm DINN.train does not reseed torch on its own.\n",
    "SEED = 18361969907809597111\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_NAME = 'SIR_Paper_Germany_14'\n",
    "# NOTE(review): presumably the total population N for the SIR model (Germany has ~83M inhabitants);\n",
    "# confirm 70M is intended and matches what PandemicDataset expects.\n",
    "POPULATION = 70_000_000\n",
    "# Fixed alpha for SIRAlphaProblem; the DINN below is configured to learn only 'beta'.\n",
    "ALPHA = 0.07\n",
    "\n",
    "covid_data = np.genfromtxt(f'./datasets/{DATA_NAME}.csv', delimiter=',')\n",
    "dataset = PandemicDataset(DATA_NAME,\n",
    "                          ['S', 'I', 'R'],\n",
    "                          POPULATION,\n",
    "                          *covid_data,\n",
    "                          norm_name=Norms.MIN_MAX)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Learning Rate:\t0.001\n", "Optimizer:\tADAM\n", "Scheduler:\tPOLYNOMIAL\n", "\n" ] } ],
   "source": [
    "plotter = Plotter()\n",
    "\n",
    "problem = SIRAlphaProblem(dataset, ALPHA)\n",
    "dinn = DINN(3,\n",
    "            dataset,\n",
    "            ['beta'],\n",
    "            problem,\n",
    "            plotter,\n",
    "            hidden_size=64,\n",
    "            hidden_layers=12,\n",
    "            activation_layer=torch.nn.Tanh(),\n",
    "            activation_output=Activation.POWER)\n",
    "dinn.configure_training(1e-3,\n",
    "                        15000,\n",
    "                        scheduler_class=Scheduler.POLYNOMIAL,\n",
    "                        optimizer_class=Optimizer.ADAM,\n",
    "                        lambda_obs=1e1,\n",
    "                        verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    { "name": "stdout", "output_type": "stream", "text": [ "torch seed: 4462685050789905541\n" ] },
    { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 0 | LR 0.0009999333333333333\n", "physics loss:\t\t4.160715663590364e-05\n", "observation 
loss:\t9.556413235291215\n", "loss:\t\t\t9.55645484244785\n", "---------------------------------\n", "beta:\t\t\t0.6011705994606018\n", "#################################\n", "\n", "Epoch 1000 | LR 0.0009332666666666524\n", "physics loss:\t\t0.0009664472242609449\n", "observation loss:\t0.0017112362767341596\n", "loss:\t\t\t0.0026776835009951045\n", "---------------------------------\n", "beta:\t\t\t0.2550368905067444\n", "#################################\n", "\n", "Epoch 2000 | LR 0.0008665999999999833\n", "physics loss:\t\t0.001247615250807723\n", "observation loss:\t0.0040488860826972005\n", "loss:\t\t\t0.005296501333504924\n", "---------------------------------\n", "beta:\t\t\t0.253487229347229\n", "#################################\n", "\n", "Epoch 3000 | LR 0.000799933333333315\n", "physics loss:\t\t0.00044841892062854347\n", "observation loss:\t0.012716515333301339\n", "loss:\t\t\t0.013164934253929882\n", "---------------------------------\n", "beta:\t\t\t0.25267982482910156\n", "#################################\n", "\n", "Epoch 4000 | LR 0.000733266666666649\n", "physics loss:\t\t0.001045761999967263\n", "observation loss:\t0.00025607909237615716\n", "loss:\t\t\t0.00130184109234342\n", "---------------------------------\n", "beta:\t\t\t0.24764902889728546\n", "#################################\n", "\n", "Epoch 5000 | LR 0.0006665999999999824\n", "physics loss:\t\t0.0014444880255704616\n", "observation loss:\t0.00014856713535789037\n", "loss:\t\t\t0.001593055160928352\n", "---------------------------------\n", "beta:\t\t\t0.2412654012441635\n", "#################################\n", "\n", "Epoch 6000 | LR 0.0005999333333333129\n", "physics loss:\t\t0.001605141155981103\n", "observation loss:\t0.00013662428319793931\n", "loss:\t\t\t0.0017417654391790425\n", "---------------------------------\n", "beta:\t\t\t0.23936264216899872\n", "#################################\n", "\n", "Epoch 7000 | LR 0.0005332666666666435\n", "physics 
loss:\t\t0.00190968670757966\n", "observation loss:\t0.00012329566674995849\n", "loss:\t\t\t0.0020329823743296185\n", "---------------------------------\n", "beta:\t\t\t0.23441843688488007\n", "#################################\n", "\n", "Epoch 8000 | LR 0.0004665999999999755\n", "physics loss:\t\t0.0019683097313949755\n", "observation loss:\t0.00011957883463252119\n", "loss:\t\t\t0.0020878885660274966\n", "---------------------------------\n", "beta:\t\t\t0.2345646470785141\n", "#################################\n", "\n", "Epoch 9000 | LR 0.0003999333333333152\n", "physics loss:\t\t0.0023533661719838616\n", "observation loss:\t0.00011510529573689898\n", "loss:\t\t\t0.0024684714677207604\n", "---------------------------------\n", "beta:\t\t\t0.22906526923179626\n", "#################################\n", "\n", "Epoch 10000 | LR 0.00033326666666665735\n", "physics loss:\t\t0.0025831233331882735\n", "observation loss:\t0.00011076522614630127\n", "loss:\t\t\t0.002693888559334575\n", "---------------------------------\n", "beta:\t\t\t0.22521229088306427\n", "#################################\n", "\n", "Epoch 11000 | LR 0.00026659999999999987\n", "physics loss:\t\t0.002667692420961139\n", "observation loss:\t0.00010629412421269853\n", "loss:\t\t\t0.0027739865451738375\n", "---------------------------------\n", "beta:\t\t\t0.22304710745811462\n", "#################################\n", "\n", "Epoch 12000 | LR 0.0001999333333333354\n", "physics loss:\t\t0.002831200195029232\n", "observation loss:\t0.00015541454626586456\n", "loss:\t\t\t0.0029866147412950968\n", "---------------------------------\n", "beta:\t\t\t0.2231626659631729\n", "#################################\n", "\n", "Epoch 13000 | LR 0.0001332666666666689\n", "physics loss:\t\t0.002496475121496725\n", "observation loss:\t9.74420291549779e-05\n", "loss:\t\t\t0.002593917150651703\n", "---------------------------------\n", "beta:\t\t\t0.22342386841773987\n", "#################################\n", "\n", "Epoch 14000 
| LR 6.660000000000055e-05\n", "physics loss:\t\t0.0024228574550409178\n", "observation loss:\t9.359204257759588e-05\n", "loss:\t\t\t0.0025164494976185137\n", "---------------------------------\n", "beta:\t\t\t0.2237681895494461\n", "#################################\n" ] } ],
   "source": [
    "dinn.train(create_animation=True,\n",
    "           plot_I_prediction=True,\n",
    "           verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "dinn.plot_training_graphs()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PINN",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}