{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# DINN: learning a time-varying $R_t$ with a reduced SIR model\n", "\n", "Fits a `DINN` to the `I` series loaded from `./datasets/I_data.csv` (dataset name `synth_sir`), with `ReducedSIRProblem` supplying the physics model and `R_t` learned as a state variable. The final section compares the learned $R_t$ with an analytic reference curve." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "\n", "from src.dataset import PandemicDataset, Norms\n", "from src.problem import ReducedSIRProblem\n", "from src.dinn import DINN, Scheduler, Activation\n", "from src.plotter import Plotter" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Model parameter handed to ReducedSIRProblem. NOTE(review): presumably the recovery\n", "# rate (1/3 per time unit, i.e. a ~3-unit infectious period) - confirm in src/problem.py.\n", "alpha = 1/3\n", "\n", "# Seed the RNGs for repeatable runs.\n", "# NOTE(review): DINN logs a 'torch seed' when training starts; if it reseeds torch\n", "# internally this call is overridden - verify in src/dinn.py.\n", "SEED = 42\n", "torch.manual_seed(SEED)\n", "np.random.seed(SEED)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Train the DINN" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Learning Rate:\t0.001\n", "Optimizer:\tADAM\n", "Scheduler:\tPOLYNOMIAL\n", "\n", "torch seed: 9948651162532304809\n", "dIdt (min | max): -4.3066202124464326e-06 | -4.301004537410336e-06, I(min | max): 0.0019655212257585625 | 0.0019698153581950706, R_t(min | max): 0.9462157915068161 | 0.9471622551059369\n", "I_residual (min | max): 0.00515376093227558 | 0.0052576320047351445\n", "\n", "\n", "Epoch 0 | LR 0.00099995\n", "physics loss:\t\t4.064907813105731e-09\n", "observation loss:\t0.5382289321642364\n", "loss:\t\t\t0.5382289321642364\n", "---------------------------------\n", "dIdt (min | max): 0.029735210628132336 | 0.03037917883193586, I(min | max): 0.5932272512312444 | 0.6232452147125436, R_t(min | max): 0.8309403277604908 | 0.8333196479677305\n", "I_residual (min | max): 5.011096179934845 | 5.189888191104072\n", "\n", "dIdt (min | max): 0.04699786892160773 | 0.048385180183686316, I(min | max): 0.5854795313251806 | 0.6329777429673413, R_t(min | max): 1.996863750612249 | 2.020590986922116\n", "I_residual (min | max): -31.290912309605147 | -29.630473406013607\n", "\n", "\n", "Epoch 1000 | LR 0.0009499500000000102\n", "physics loss:\t\t0.13926813941205285\n", "observation loss:\t0.16752187842201924\n", "loss:\t\t\t0.16752187842201924\n", "---------------------------------\n", "dIdt (min | max): 0.17560703598428518 | 0.20268904085969552, I(min | max): 0.5197350518656663 | 0.7081941509225089, 
R_t(min | max): 2.6384550620405776 | 2.834613956049509\n", "I_residual (min | max): -64.32736697860197 | -42.11866512092712\n", "\n", "dIdt (min | max): 0.3005482160951942 | 0.39336065272800624, I(min | max): 0.4468832532038505 | 0.7942175840803998, R_t(min | max): 0.8456744830101712 | 0.933018778499676\n", "I_residual (min | max): 3.0355112940255 | 3.7997852719123584\n", "\n", "\n", "Epoch 2000 | LR 0.0008999500000000076\n", "physics loss:\t\t0.0019611244738854294\n", "observation loss:\t0.15231235274998953\n", "loss:\t\t\t0.15231235274998953\n", "---------------------------------\n", "dIdt (min | max): -0.1536848735413514 | 0.8572140112519264, I(min | max): 0.30179775986562163 | 0.7641454788118871, R_t(min | max): 1.0953468236426147 | 1.5376301789810896\n", "I_residual (min | max): -20.260114727890787 | -0.5937326715326567\n", "\n", "dIdt (min | max): -0.1426919352461482 | -0.02923016343265772, I(min | max): 0.45858985593713797 | 0.5332348570740386, R_t(min | max): 0.0866182904601942 | 0.11768594231978202\n", "I_residual (min | max): 20.661064865563176 | 23.33797348849647\n", "\n", "\n", "Epoch 3000 | LR 0.0008499500000000042\n", "physics loss:\t\t0.07539271839314121\n", "observation loss:\t0.18647917953013218\n", "loss:\t\t\t0.18647917953013218\n", "---------------------------------\n", "dIdt (min | max): -2.3556392604950815 | 3.6454559939011233, I(min | max): 3.2305881211006776e-05 | 1.0984531271015499, R_t(min | max): 1.9658075994244406e-06 | 0.7027194913162162\n", "I_residual (min | max): -0.08416777805219415 | 47.056998967269706\n", "\n", "dIdt (min | max): -2.459546227888495 | 6.572167123667896, I(min | max): 3.455589381978763e-05 | 1.241782212576382, R_t(min | max): 0.011134867177778285 | 2.3413964049263\n", "I_residual (min | max): -27.238923247828726 | 33.14330039264253\n", "\n", "\n", "Epoch 4000 | LR 0.0007999499999999993\n", "physics loss:\t\t0.06065713943952673\n", "observation loss:\t0.0005927743819179973\n", "loss:\t\t\t0.0005927743819179973\n", 
"---------------------------------\n", "dIdt (min | max): -2.3630973695544526 | 7.038932139053941, I(min | max): 2.068056294268761e-06 | 1.2486979634156938, R_t(min | max): 0.03647387921418588 | 2.5186859571263653\n", "I_residual (min | max): -32.180478066422396 | 26.48789410712128\n", "\n", "dIdt (min | max): -2.2560938742244616 | 7.245537573471665, I(min | max): 7.630110711309701e-06 | 1.2371890991747847, R_t(min | max): 0.0442799555145692 | 2.5579207257215018\n", "I_residual (min | max): -36.48192865407925 | 22.047335802905828\n", "\n", "\n", "Epoch 5000 | LR 0.0007499499999999949\n", "physics loss:\t\t0.08611380386770974\n", "observation loss:\t0.0023679101940069012\n", "loss:\t\t\t0.0023679101940069012\n", "---------------------------------\n", "dIdt (min | max): -2.198701551533304 | 7.332506489008665, I(min | max): 1.557987518997772e-05 | 1.2309771359230268, R_t(min | max): 0.05209980250893498 | 2.615410696122936\n", "I_residual (min | max): -38.869923733729834 | 20.120389743405912\n", "\n", "dIdt (min | max): -2.24998843044159 | 7.442951250821352, I(min | max): 3.080447031237471e-06 | 1.2422511863115346, R_t(min | max): 0.05885529208612539 | 2.656475625868225\n", "I_residual (min | max): -39.504625081728385 | 19.675317313300766\n", "\n", "\n", "Epoch 6000 | LR 0.0006999499999999886\n", "physics loss:\t\t0.10191596111113671\n", "observation loss:\t0.0005234928770837613\n", "loss:\t\t\t0.0005234928770837613\n", "---------------------------------\n", "dIdt (min | max): -2.2418147744610906 | 7.473376495530829, I(min | max): 1.7142058841932567e-05 | 1.2415008708345, R_t(min | max): 0.06435126966039562 | 2.698847370891272\n", "I_residual (min | max): -40.618602285760296 | 18.80886388660886\n", "\n", "dIdt (min | max): -2.2400910716096405 | 7.517381172743626, I(min | max): 4.114157521384332e-06 | 1.2417474084236915, R_t(min | max): 0.06851398960760946 | 2.7345679349041916\n", "I_residual (min | max): -41.494566500030665 | 18.10868361807657\n", "\n", "\n", "Epoch 
7000 | LR 0.0006499499999999803\n", "physics loss:\t\t0.11408178467372233\n", "observation loss:\t2.4275713194266008e-05\n", "loss:\t\t\t2.4275713194266008e-05\n", "---------------------------------\n", "dIdt (min | max): -2.230403194553219 | 7.539000155520625, I(min | max): 1.2204102542746897e-06 | 1.2411265949310888, R_t(min | max): 0.07289614346359574 | 2.768566045963098\n", "I_residual (min | max): -42.33065639118309 | 17.479356774769265\n", "\n", "dIdt (min | max): -2.2174056386575103 | 7.551401607692242, I(min | max): 7.314339322030128e-07 | 1.2428053005203452, R_t(min | max): 0.07948607067533775 | 2.8049403942079607\n", "I_residual (min | max): -43.18112773531085 | 16.967911292482903\n", "\n", "\n", "Epoch 8000 | LR 0.0005999499999999715\n", "physics loss:\t\t0.1267158037177045\n", "observation loss:\t0.001140470382112762\n", "loss:\t\t\t0.001140470382112762\n", "---------------------------------\n", "dIdt (min | max): -2.228236891212873 | 7.592347430996597, I(min | max): 2.4185795897274653e-08 | 1.2445215789721829, R_t(min | max): 0.08163291580844145 | 2.824335179622949\n", "I_residual (min | max): -43.10566239824683 | 16.95350331740783\n", "\n", "dIdt (min | max): -2.2245319766079774 | 7.607948694843799, I(min | max): 1.2692663000635207e-06 | 1.243967082244822, R_t(min | max): 0.07074258883993512 | 2.8044772225532597\n", "I_residual (min | max): -43.20043217177767 | 16.841671666429264\n", "\n", "\n", "Epoch 9000 | LR 0.000549949999999962\n", "physics loss:\t\t0.1252684577759251\n", "observation loss:\t1.4261171530064027e-05\n", "loss:\t\t\t1.4261171530064027e-05\n", "---------------------------------\n", "dIdt (min | max): -2.2230286567937583 | 7.591551184654236, I(min | max): 2.355196672718801e-06 | 1.241261529632439, R_t(min | max): 0.0747454788833899 | 2.8248059988369647\n", "I_residual (min | max): -43.05238593270748 | 16.70150245345828\n", "\n", "dIdt (min | max): -2.2181494790274883 | 7.584606601158157, I(min | max): 2.199203024189078e-06 | 
1.2406429621080406, R_t(min | max): 0.07164354646307913 | 2.820081555415527\n", "I_residual (min | max): -43.19962428627962 | 16.466317910479376\n", "\n", "\n", "Epoch 10000 | LR 0.0004999499999999524\n", "physics loss:\t\t0.12567281139280068\n", "observation loss:\t6.5288327456431e-05\n", "loss:\t\t\t0.1257380997202571\n", "---------------------------------\n", "dIdt (min | max): -2.2244977978334646 | 7.423023615243437, I(min | max): 1.6105586710148112e-06 | 1.2390946965226703, R_t(min | max): 0.940597352622504 | 1.346293085059088\n", "I_residual (min | max): -0.9566095449569341 | 1.5113435710459666\n", "\n", "dIdt (min | max): -2.214308594033355 | 7.443884279811755, I(min | max): 1.322075828925906e-05 | 1.2372654752483783, R_t(min | max): 0.9240679250115278 | 1.5767932481384292\n", "I_residual (min | max): -0.4079618209148981 | 0.5408993306711558\n", "\n", "\n", "Epoch 11000 | LR 0.0004499499999999541\n", "physics loss:\t\t1.0374903344387553e-05\n", "observation loss:\t0.0002900930372790617\n", "loss:\t\t\t0.00030046794062344926\n", "---------------------------------\n", "dIdt (min | max): -2.2048103295383044 | 7.550702021806501, I(min | max): 9.603051115428252e-06 | 1.238501344888391, R_t(min | max): 0.9243309370558315 | 1.6901306749807645\n", "I_residual (min | max): -0.28215162866274124 | 0.4204509832472303\n", "\n", "dIdt (min | max): -2.2075480042403797 | 7.603872123989277, I(min | max): 6.07450978407087e-10 | 1.239578027505118, R_t(min | max): 0.9268095553673028 | 1.753151517325037\n", "I_residual (min | max): -0.2587943329289 | 0.3736682627264123\n", "\n", "\n", "Epoch 12000 | LR 0.00039994999999996085\n", "physics loss:\t\t3.578199539784047e-06\n", "observation loss:\t5.472404057675049e-05\n", "loss:\t\t\t5.830224011653454e-05\n", "---------------------------------\n", "dIdt (min | max): -2.205835248070798 | 7.628314129600767, I(min | max): 3.041783728650188e-06 | 1.238847360949876, R_t(min | max): 0.9311481062789007 | 1.7893593502813019\n", "I_residual 
(min | max): -0.23249204688486858 | 0.3081083862875844\n", "\n", "dIdt (min | max): -2.18662476712052 | 7.578761222481262, I(min | max): 4.283968613197353e-06 | 1.2357654354108405, R_t(min | max): 0.9340264777261496 | 1.803714146416496\n", "I_residual (min | max): -0.2076371172166418 | 0.27328868070531676\n", "\n", "\n", "Epoch 13000 | LR 0.00034994999999996923\n", "physics loss:\t\t1.8832011925846729e-06\n", "observation loss:\t8.456478776864589e-05\n", "loss:\t\t\t8.644798896123056e-05\n", "---------------------------------\n", "dIdt (min | max): -2.179952934908215 | 7.6077519888058305, I(min | max): 5.651204785378505e-07 | 1.2325351459310383, R_t(min | max): 0.937997249635206 | 1.8167351987076472\n", "I_residual (min | max): -0.1660933356930101 | 0.2229823695947062\n", "\n", "dIdt (min | max): -2.210042197490111 | 7.590020227711648, I(min | max): 4.428153274647212e-06 | 1.2392759677732244, R_t(min | max): 0.9385707998089856 | 1.7918408104470132\n", "I_residual (min | max): -0.16755115754249061 | 0.16383575596832256\n", "\n", "\n", "Epoch 14000 | LR 0.00029994999999997913\n", "physics loss:\t\t1.3762284077226225e-06\n", "observation loss:\t0.0005259605573378371\n", "loss:\t\t\t0.0005273367857455598\n", "---------------------------------\n", "dIdt (min | max): -2.2051773993007373 | 7.643406911520287, I(min | max): 1.5833532565068253e-06 | 1.2385615758052069, R_t(min | max): 0.9395791809500373 | 1.7867935604626837\n", "I_residual (min | max): -0.16234917395222492 | 0.16399445054973882\n", "\n", "dIdt (min | max): -2.205072554606886 | 7.643023730503046, I(min | max): 3.2671426177278973e-06 | 1.2385339808421492, R_t(min | max): 0.9410153451505074 | 1.7716399386272883\n", "I_residual (min | max): -0.145981606071391 | 0.13828644586250682\n", "\n", "\n", "Epoch 15000 | LR 0.0002499499999999903\n", "physics loss:\t\t8.217994584763831e-07\n", "observation loss:\t7.000775491735788e-07\n", "loss:\t\t\t1.521877007649962e-06\n", "---------------------------------\n", "dIdt 
(min | max): -2.205418490590091 | 7.642543703746924, I(min | max): 1.6979653121897142e-06 | 1.2385737813676911, R_t(min | max): 0.9420944608218775 | 1.758883613641956\n", "I_residual (min | max): -0.13061930290693935 | 0.12318431082114945\n", "\n", "dIdt (min | max): -2.2061518276987044 | 7.6391230421941145, I(min | max): 2.247346062717037e-09 | 1.2386146439058763, R_t(min | max): 0.9432736462265581 | 1.7442180773151108\n", "I_residual (min | max): -0.11734928204305595 | 0.11063335839698762\n", "\n", "\n", "Epoch 16000 | LR 0.0001999499999999929\n", "physics loss:\t\t5.129720036526635e-07\n", "observation loss:\t5.002273247783877e-07\n", "loss:\t\t\t1.0131993284310512e-06\n", "---------------------------------\n", "dIdt (min | max): -2.2063585466057702 | 7.639161886370857, I(min | max): 1.5750585774521042e-07 | 1.2386379942303307, R_t(min | max): 0.94395061036181 | 1.7378831682492688\n", "I_residual (min | max): -0.10411140585319223 | 0.10095863128125426\n", "\n", "dIdt (min | max): -2.2078374401608016 | 7.631559064997418, I(min | max): 3.6293508345365544e-06 | 1.2386743468780423, R_t(min | max): 0.945272930362254 | 1.7203316936145399\n", "I_residual (min | max): -0.09478741558585946 | 0.0853070270949825\n", "\n", "\n", "Epoch 17000 | LR 0.00014994999999999307\n", "physics loss:\t\t3.179241116330299e-07\n", "observation loss:\t3.677285418033541e-07\n", "loss:\t\t\t6.856526534363839e-07\n", "---------------------------------\n", "dIdt (min | max): -2.207295828306087 | 7.634140691559878, I(min | max): 3.92869862858615e-05 | 1.2387093732913854, R_t(min | max): 0.9453598584163423 | 1.721340341432608\n", "I_residual (min | max): -0.08371848205797883 | 0.08258404776904094\n", "\n", "dIdt (min | max): -2.207882463833812 | 7.631950354523724, I(min | max): 8.869285619442571e-05 | 1.2387526262881607, R_t(min | max): 0.9458201800897683 | 1.720373284719301\n", "I_residual (min | max): -0.07665967714440036 | 0.07747114330001725\n", "\n", "\n", "Epoch 18000 | LR 
9.99499999999938e-05\n", "physics loss:\t\t2.1873778595545899e-07\n", "observation loss:\t2.6589031389273194e-07\n", "loss:\t\t\t4.846280998481909e-07\n", "---------------------------------\n", "dIdt (min | max): -2.2084692420503416 | 7.629564266404486, I(min | max): 0.00014522856013480037 | 1.238794022501466, R_t(min | max): 0.9461958466302036 | 1.7198241975511905\n", "I_residual (min | max): -0.07045001528330941 | 0.07231426003020136\n", "\n", "dIdt (min | max): -2.2089616211851535 | 7.627448219136568, I(min | max): 0.00019848300634001115 | 1.2388325003712453, R_t(min | max): 0.946483096329306 | 1.7198679712174112\n", "I_residual (min | max): -0.0654118791001439 | 0.06802954938183259\n", "\n", "\n", "Epoch 19000 | LR 4.9949999999997616e-05\n", "physics loss:\t\t1.6233895191831296e-07\n", "observation loss:\t1.982327980794814e-07\n", "loss:\t\t\t3.6057174999779433e-07\n", "---------------------------------\n", "dIdt (min | max): -2.2092989662996843 | 7.625929547473788, I(min | max): 0.00023808052155073478 | 1.238858506442341, R_t(min | max): 0.946676205508993 | 1.7202163043848486\n", "I_residual (min | max): -0.062211480635420036 | 0.06508937630454081\n", "\n" ] } ],
 "source": [
  "plotter = Plotter()\n",
  "covid_data = np.genfromtxt('./datasets/I_data.csv', delimiter=',')\n",
  "# NOTE(review): 7.6e6 is presumably the population size used by Norms.CONSTANT - confirm.\n",
  "dataset = PandemicDataset('synth_sir', ['I'], 7.6e6, *covid_data, norm_name=Norms.CONSTANT, use_scaled_time=True)\n",
  "\n",
  "problem = ReducedSIRProblem(dataset, alpha)\n",
  "\n",
  "# NOTE(review): the leading 2 is presumably the number of network outputs; the\n",
  "# comparison section below reads output index 1 as R_t - confirm against DINN.\n",
  "dinn = DINN(2,\n",
  "            dataset,\n",
  "            [],\n",
  "            problem,\n",
  "            plotter,\n",
  "            state_variables=['R_t'],\n",
  "            hidden_size=100,\n",
  "            hidden_layers=4,\n",
  "            activation_layer=torch.nn.Tanh(),\n",
  "            activation_output=Activation.POWER,\n",
  "            use_glorot_initialization=True)\n",
  "dinn.configure_training(1e-3,   # learning rate (logged as 'Learning Rate: 0.001')\n",
  "                        20000,  # presumably the number of epochs\n",
  "                        lambda_physics=1e-6,\n",
  "                        scheduler_class=Scheduler.POLYNOMIAL,\n",
  "                        verbose=True)\n",
  "\n",
  "# Judging by the training log, with do_split_training=True the total loss equals the\n",
  "# observation loss through epoch 9000 and adds the physics loss from epoch 10000 on.\n",
  "dinn.train(create_animation=True, verbose=True, do_split_training=True)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Training diagnostics" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [],
 "source": [
  "dinn.plot_training_graphs()\n",
  "dinn.plot_state_variables()"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Learned vs. reference $R_t$" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [],
 "source": [
  "# Reference curve R_t(t) = 1.35 - 0.4 * tanh(0.05 * t - 2) evaluated at day indices\n",
  "# 0..len(t)-1: a smooth decline from ~1.74 to ~0.95 centred at t = 40.\n",
  "# NOTE(review): presumably the R_t used to generate the synthetic data - confirm.\n",
  "t_days = dataset.t_raw.detach().cpu().numpy()\n",
  "# Sized from the dataset instead of a hard-coded 150 so the curve always matches t_days.\n",
  "day_index = np.arange(len(t_days), dtype=np.float64)\n",
  "synth_r_t = 1.35 - 0.4 * np.tanh(0.05 * day_index - 2)\n",
  "\n",
  "r_t = dinn.get_output(1).detach().cpu().numpy()\n",
  "plotter.plot(t_days, [r_t, synth_r_t], [\"pred\", \"true\"], \"test\", \"R_t\", (12, 6))"
 ] } ], "metadata": { "kernelspec": { "display_name": "PINN", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 2 }