{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-title",
   "metadata": {},
   "source": [
    "# Klein-Gordon equation with a PINN: comparing boundary-loss weighting rules\n",
    "\n",
    "Trains a physics-informed neural network (PINN) on the nonlinear Klein-Gordon equation\n",
    "$$u_{yy} - u_{xx} + u^3 = f(x, y), \\qquad (x, y) \\in [0, 1]^2,$$\n",
    "where $y$ plays the role of time. The forcing term $f$ and the boundary data come from the\n",
    "manufactured solution $u(x, y) = x\\cos(5\\pi y) + (xy)^3$.\n",
    "\n",
    "Six rules for the boundary-loss weight $\\lambda$ are compared (W1 uniform, W2 max/mean,\n",
    "W3 std / inverse Dirichlet, W4 mean + std, W5 mean x std, W6 kurtosis), each trained on five\n",
    "sizes of the (residual points, boundary points) training set. The relative $L_2$ error of\n",
    "every run (normalised by the exact solution) is printed at the end, one row per rule."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4852849e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, \"../\")\n",
    "import time\n",
    "import numpy as np\n",
    "# Project-local modules; PINN, loss_grad_stats and loss_grad_max_mean are presumably\n",
    "# provided by these star imports -- confirm against ../pinn.py and ../grad_stats.py.\n",
    "from pinn import *\n",
    "from grad_stats import *\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import grad\n",
    "from torch.optim import Adam\n",
    "from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR\n",
    "\n",
    "# tqdm_notebook is deprecated; tqdm.auto selects the notebook widget when available\n",
    "from tqdm.auto import tqdm\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-setup",
   "metadata": {},
   "source": [
    "## Problem setup\n",
    "\n",
    "Exact solution on a $100 \\times 100$ evaluation grid, plus its analytic derivatives, which\n",
    "define the forcing term $f$ used as the residual target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "697cf9bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# experiment setup: rectangular domain [lx, rx] x [ly, ry]\n",
    "lx = 0\n",
    "ly = 0\n",
    "rx = 1\n",
    "ry = 1\n",
    "\n",
    "seed = 1  # used for point sampling and network initialisation in the training sweep\n",
    "print(\"seed\", seed)\n",
    "\n",
    "\n",
    "def kg_equation(x, y):\n",
    "    \"\"\"Manufactured exact solution u(x, y) = x*cos(5*pi*y) + (x*y)**3, elementwise.\"\"\"\n",
    "    return x * np.cos(5 * np.pi * y) + (x * y) ** 3\n",
    "\n",
    "\n",
    "# 100 x 100 evaluation grid; X holds the flattened (x, y) pairs in meshgrid order\n",
    "x = np.linspace(lx, rx, 100)\n",
    "y = np.linspace(ly, ry, 100)\n",
    "xx, yy = np.meshgrid(x, y)\n",
    "u_sol = kg_equation(xx, yy)\n",
    "X = np.vstack([xx.ravel(), yy.ravel()]).T\n",
    "\n",
    "plt.imshow(u_sol, cmap=\"twilight\", origin=\"lower\", vmin=-1.5, vmax=1.5)\n",
    "plt.title(\"Exact solution\")\n",
    "plt.colorbar()\n",
    "\n",
    "\n",
    "def u(x):\n",
    "    \"\"\"Exact solution at an (N, 2) array of (x, y) points, returned with shape (N, 1).\"\"\"\n",
    "    return kg_equation(x[:, 0:1], x[:, 1:2])\n",
    "\n",
    "\n",
    "def u_tt(x):\n",
    "    \"\"\"Second derivative of the exact solution with respect to y (the time-like coordinate).\"\"\"\n",
    "    return -25 * np.pi**2 * x[:, 0:1] * np.cos(5 * np.pi * x[:, 1:2]) + 6 * x[:, 1:2] * x[:, 0:1] ** 3\n",
    "\n",
    "\n",
    "def u_xx(x):\n",
    "    \"\"\"Second derivative of the exact solution with respect to x.\"\"\"\n",
    "    return 6 * x[:, 0:1] * x[:, 1:2] ** 3\n",
    "\n",
    "\n",
    "def f(x, alpha=-1.0, beta=0.0, gamma=1.0, k=3.0):\n",
    "    \"\"\"Forcing term f = u_tt + alpha*u_xx + beta*u + gamma*u**k of the exact solution.\n",
    "\n",
    "    With the default coefficients this is u_yy - u_xx + u**3, the operator used in KG_res.\n",
    "    \"\"\"\n",
    "    return u_tt(x) + alpha * u_xx(x) + beta * u(x) + gamma * u(x) ** k"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-model",
   "metadata": {},
   "source": [
    "## Training data, PDE residual and plotting helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c631a98",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sampler(num_r=2500, num_b=100, lx=0, rx=1, ly=0, ry=1, seed=1):\n",
    "    \"\"\"Generate the training tensors for one run.\n",
    "\n",
    "    num_r residual (collocation) points are drawn without replacement from a 100 x 100\n",
    "    grid on [lx, rx] x [ly, ry]; each of the four edges gets num_b equally spaced points\n",
    "    whose targets are the exact solution. numpy and torch are both seeded with `seed`.\n",
    "\n",
    "    Returns, in order:\n",
    "        X_train, X_lb, X_rb, Y_lb, Y_rb -- input points (interior, x = lx, x = rx,\n",
    "            y = ly, y = ry); X_train and Y_lb require grad for autograd derivatives.\n",
    "        U_Train -- forcing term f at X_train, shape (num_r, 1).\n",
    "        U_X_lb, U_X_rb, U_Y_lb, U_Y_rb -- exact solution on the edges, shape (num_b, 1).\n",
    "        X_mean, X_std -- per-coordinate mean / std of all sampled points, shape (1, 2).\n",
    "    All tensors are float32 on `device`.\n",
    "    \"\"\"\n",
    "    # candidate grid for the residual points\n",
    "    xs = np.linspace(lx, rx, 100)\n",
    "    ys = np.linspace(ly, ry, 100)\n",
    "    gx, gy = np.meshgrid(xs, ys)\n",
    "    grid = np.vstack([gx.ravel(), gy.ravel()]).T\n",
    "\n",
    "    # equally spaced points on the four edges\n",
    "    xb = np.linspace(lx, rx, num_b)\n",
    "    yb = np.linspace(ly, ry, num_b)\n",
    "    Xlb = np.vstack((lx * np.ones(yb.shape), yb)).T  # x = lx\n",
    "    Xrb = np.vstack((rx * np.ones(yb.shape), yb)).T  # x = rx\n",
    "    Ylb = np.vstack((xb, ly * np.ones(xb.shape))).T  # y = ly\n",
    "    Yrb = np.vstack((xb, ry * np.ones(xb.shape))).T  # y = ry\n",
    "\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    idxs = np.random.choice(grid.shape[0], num_r, replace=False)\n",
    "    X_res = grid[idxs]\n",
    "\n",
    "    def to_tensor(arr, requires_grad=False):\n",
    "        return torch.tensor(arr, dtype=torch.float32, requires_grad=requires_grad, device=device)\n",
    "\n",
    "    def exact_on(points):\n",
    "        # exact solution on an edge as a (num_b, 1) tensor\n",
    "        return to_tensor(kg_equation(points[:, 0:1], points[:, 1:2])).reshape(num_b, 1)\n",
    "\n",
    "    X_train = to_tensor(X_res, requires_grad=True)  # needed for u_xx and u_yy in KG_res\n",
    "    X_lb = to_tensor(Xlb)\n",
    "    X_rb = to_tensor(Xrb)\n",
    "    Y_lb = to_tensor(Ylb, requires_grad=True)  # needed for the u_y condition at y = ly\n",
    "    Y_rb = to_tensor(Yrb)\n",
    "\n",
    "    # input-normalisation statistics over every sampled point\n",
    "    all_points = np.concatenate([X_res, Xrb, Xlb, Yrb, Ylb], 0)\n",
    "    X_mean = to_tensor(np.mean(all_points, axis=0, keepdims=True))\n",
    "    X_std = to_tensor(np.std(all_points, axis=0, keepdims=True))\n",
    "\n",
    "    # targets are constants, so they do not need gradients\n",
    "    U_Train = to_tensor(f(X_res))\n",
    "    U_X_lb = exact_on(Xlb)\n",
    "    U_X_rb = exact_on(Xrb)\n",
    "    U_Y_lb = exact_on(Ylb)\n",
    "    U_Y_rb = exact_on(Yrb)\n",
    "\n",
    "    return X_train, X_lb, X_rb, Y_lb, Y_rb, U_Train, U_X_lb, U_X_rb, U_Y_lb, U_Y_rb, X_mean, X_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff298331",
   "metadata": {},
   "outputs": [],
   "source": [
    "def KG_res(uhat, data):\n",
    "    \"\"\"Klein-Gordon operator u_yy - u_xx + u**3 applied to the network output via autograd.\n",
    "\n",
    "    uhat -- network output computed from `data` (one column).\n",
    "    data -- (N, 2) input points (x, y); must have requires_grad=True.\n",
    "    The caller compares the result with the forcing term f.\n",
    "    \"\"\"\n",
    "    ones = torch.ones_like(uhat)\n",
    "    du = grad(outputs=uhat, inputs=data, grad_outputs=ones, create_graph=True)[0]\n",
    "    dudx = du[:, 0:1]\n",
    "    dudy = du[:, 1:2]\n",
    "    dudxx = grad(outputs=dudx, inputs=data, grad_outputs=ones, create_graph=True)[0][:, 0:1]\n",
    "    dudyy = grad(outputs=dudy, inputs=data, grad_outputs=ones, create_graph=True)[0][:, 1:2]\n",
    "    return dudyy - dudxx + uhat**3\n",
    "\n",
    "\n",
    "def u_t(uhat, data):\n",
    "    \"\"\"First derivative du/dy of the network output (y is the time-like coordinate).\"\"\"\n",
    "    du = grad(outputs=uhat, inputs=data, grad_outputs=torch.ones_like(uhat), create_graph=True)[0]\n",
    "    return du[:, 1:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9ab58f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_function(lx, rx, ly, ry, u_sol, out, method, extras=None):\n",
    "    \"\"\"Plot ground truth, prediction and normalised point-wise error side by side.\n",
    "\n",
    "    A fresh figure is created on every call so repeated calls in a loop do not draw on\n",
    "    top of each other. The figure is saved as '<extras>KGEqn_<label>_based_Tanh'\n",
    "    (matplotlib's default format) and then shown.\n",
    "\n",
    "    u_sol, out -- 2-D arrays on the same grid; tick positions assume a 100 x 100 grid.\n",
    "    method -- index into the weighting-rule labels below.\n",
    "    extras -- filename prefix; None means no prefix.\n",
    "    \"\"\"\n",
    "    methods = [\"W1 (uniform)\", \"W2 (max_by_mean)\", \"W3 (std)\", \"W4 (mean+std)\", \"W5 (mean*std)\", \"W6 (kurtosis)\"]\n",
    "    prefix = \"\" if extras is None else extras\n",
    "    panels = [\n",
    "        (u_sol, \"Ground Truth\", dict(cmap=\"twilight\", vmin=-1.0, vmax=1.5)),\n",
    "        (out, \"Prediction\", dict(cmap=\"twilight\", vmin=-1.0, vmax=1.5)),\n",
    "        (np.abs(out - u_sol) / np.max(np.abs(u_sol)), \"Point-wise Error\",\n",
    "         dict(cmap=\"nipy_spectral\", vmin=0, vmax=0.2)),\n",
    "    ]\n",
    "\n",
    "    fig = plt.figure(figsize=(15, 5))\n",
    "    for k, (img, title, style) in enumerate(panels, start=1):\n",
    "        plt.subplot(1, 3, k)\n",
    "        plt.imshow(img, origin=\"lower\", **style)\n",
    "        # ticks at grid indices 0, 50, 100 labelled with physical coordinates\n",
    "        plt.xticks(np.arange(0, 101, 50), np.linspace(lx, rx, 3), fontsize=12)\n",
    "        plt.yticks(np.arange(0, 101, 50), np.linspace(ly, ry, 3), fontsize=12)\n",
    "        plt.xlabel(r\"$x$\", fontsize=15)\n",
    "        plt.ylabel(r\"$y$\", fontsize=15)\n",
    "        plt.title(title, fontsize=18)\n",
    "        plt.colorbar(fraction=0.046, pad=0.04)\n",
    "\n",
    "    # suptitle before tight_layout so the layout leaves room for it\n",
    "    plt.suptitle(\"Klein-Gordon Equation using PINN_{}\".format(methods[method]), fontsize=25)\n",
    "    fig.tight_layout()\n",
    "    plt.savefig(prefix + \"KGEqn_{}_based_Tanh\".format(methods[method]), dpi=800)\n",
    "    plt.show()\n",
    "    return None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-sweep",
   "metadata": {},
   "source": [
    "## Training sweep\n",
    "\n",
    "For every weighting rule and every training-set size, a fresh network is trained with Adam.\n",
    "Every `MM` epochs the boundary weight $\\lambda$ is re-estimated from gradient statistics of\n",
    "the residual and boundary losses, smoothed with factor `ALPHA_ANN`, and clamped below at 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sweep-config",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sweep configuration\n",
    "SETS = [[300, 25], [500, 25], [600, 30], [750, 50], [1000, 100]]  # [num_r, num_b] per run\n",
    "N_METHODS = 6  # weighting rules W1..W6 (indices 0..5, see plot_function labels)\n",
    "MM = 5  # re-estimate lambda every MM epochs\n",
    "ALPHA_ANN = 0.5  # smoothing factor of the lambda moving average\n",
    "N_EPOCHS = 40001\n",
    "LR = 1e-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lambda-update",
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_lambda(method, lambd, alpha_ann, *, stdr, kurtr, maxr, meanr, stdb, kurtb, meanb):\n",
    "    \"\"\"Return the smoothed boundary-loss weight for weighting rule `method`.\n",
    "\n",
    "    Suffix r / b marks gradient statistics of the residual / boundary loss as returned by\n",
    "    loss_grad_stats (std, kurt) and loss_grad_max_mean (max, mean). The new estimate\n",
    "    lamb_hat is blended as (1 - alpha_ann) * lambd + alpha_ann * lamb_hat and clamped\n",
    "    below at 1. Method 0 (uniform) always returns 1.\n",
    "    \"\"\"\n",
    "    if method == 0:\n",
    "        return 1\n",
    "    if method == 1:  # W2: max of residual gradients / mean of boundary gradients\n",
    "        lamb_hat = maxr / meanb\n",
    "    elif method == 2:  # W3: inverse Dirichlet, ratio of gradient stds\n",
    "        lamb_hat = stdr / stdb\n",
    "    elif method == 3:  # W4: ratio of (std + mean); residual side uses meanr like the boundary side\n",
    "        lamb_hat = (stdr + meanr) / (stdb + meanb)\n",
    "    elif method == 4:  # W5: ratio of (std * mean)\n",
    "        lamb_hat = (stdr * meanr) / (stdb * meanb)\n",
    "    elif method == 5:  # W6: ratio of (std / kurtosis)\n",
    "        lamb_hat = (stdr / kurtr) / (stdb / kurtb)\n",
    "    else:\n",
    "        raise ValueError(\"unknown weighting method: {}\".format(method))\n",
    "    lambd = (1 - alpha_ann) * lambd + alpha_ann * lamb_hat\n",
    "    if lambd < 1:\n",
    "        lambd = torch.tensor(1.0, dtype=torch.float32, device=device)\n",
    "    return lambd\n",
    "\n",
    "\n",
    "def relative_l2_error(pred, ref):\n",
    "    \"\"\"Relative L2 error ||pred - ref|| / ||ref|| over all entries, normalised by the reference.\"\"\"\n",
    "    pred = np.asarray(pred).reshape(-1)\n",
    "    ref = np.asarray(ref).reshape(-1)\n",
    "    return np.linalg.norm(pred - ref) / np.linalg.norm(ref)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c9a1e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "losses_boundary_global = []\n",
    "losses_residual_global = []\n",
    "lambdas_global = []\n",
    "list_of_l2_Errors = []  # relative L2 errors, method-major: N_METHODS rows of len(SETS)\n",
    "\n",
    "for method in range(N_METHODS):\n",
    "    for j, (num_r, num_b) in enumerate(SETS):\n",
    "        extras = str(num_r) + \"+\" + str(num_b)\n",
    "        print(\"#######Training with#####\\n\", extras)\n",
    "        (X_train, X_lb, X_rb, Y_lb, Y_rb, U_Train, U_X_lb, U_X_rb, U_Y_lb, U_Y_rb,\n",
    "         X_mean, X_std) = sampler(num_r, num_b, lx, rx, ly, ry, seed=seed)\n",
    "        net = PINN(sizes=[2, 50, 50, 50, 50, 50, 1], mean=X_mean, std=X_std, seed=seed,\n",
    "                   activation=torch.nn.Tanh()).to(device)\n",
    "        lambd = 1\n",
    "        lambds = []\n",
    "        losses_boundary = []\n",
    "        losses_residual = []\n",
    "        optimizer = Adam([{\"params\": net.parameters(), \"lr\": LR}])\n",
    "        print(\"training with shape\", X_train.size())\n",
    "        start_time = time.time()\n",
    "        for epoch in range(N_EPOCHS):\n",
    "            # PDE residual loss on the interior collocation points\n",
    "            uhat = net(X_train)\n",
    "            res = KG_res(uhat, X_train)\n",
    "            l_reg = torch.mean((res - U_Train) ** 2)\n",
    "\n",
    "            # boundary loss: exact values on x = lx, x = rx and y = ly, plus u_y = 0 on y = ly\n",
    "            # (the y = ry edge is sampled but not used)\n",
    "            pred_xl = net(X_lb)\n",
    "            pred_xr = net(X_rb)\n",
    "            l_bc = torch.mean((pred_xl - U_X_lb) ** 2, dim=0)\n",
    "            l_bc += torch.mean((pred_xr - U_X_rb) ** 2, dim=0)\n",
    "            pred_yl = net(Y_lb)\n",
    "            l_bc += torch.mean((pred_yl - U_Y_lb) ** 2, dim=0)\n",
    "            l_bc += torch.mean(u_t(pred_yl, Y_lb) ** 2)\n",
    "\n",
    "            with torch.no_grad():\n",
    "                if epoch % MM == 0:\n",
    "                    stdr, kurtr = loss_grad_stats(l_reg, net)\n",
    "                    stdb, kurtb = loss_grad_stats(l_bc, net)\n",
    "                    maxr, meanr = loss_grad_max_mean(l_reg, net)\n",
    "                    maxb, meanb = loss_grad_max_mean(l_bc, net, lambg=lambd)\n",
    "                    lambd = update_lambda(method, lambd, ALPHA_ANN,\n",
    "                                          stdr=stdr, kurtr=kurtr, maxr=maxr, meanr=meanr,\n",
    "                                          stdb=stdb, kurtb=kurtb, meanb=meanb)\n",
    "\n",
    "            # method 0 keeps lambd == 1, i.e. the plain sum of both losses\n",
    "            loss = l_reg + lambd * l_bc\n",
    "            if epoch % 100 == 0:\n",
    "                losses_boundary.append(l_bc.item())\n",
    "                losses_residual.append(l_reg.item())\n",
    "                if method != 0:\n",
    "                    lambds.append(float(lambd))\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            print(\"epoch {}/{}, loss={:.10f}, lambda={:.4f}, lr={:,.5f}\\t\\t\\t\"\n",
    "                  .format(epoch + 1, N_EPOCHS, loss.item(), float(lambd), optimizer.param_groups[0][\"lr\"]),\n",
    "                  end=\"\\r\")\n",
    "        elapsed_time = time.time() - start_time\n",
    "\n",
    "        # evaluate on the full 100 x 100 grid without building an autograd graph\n",
    "        with torch.no_grad():\n",
    "            inp = torch.tensor(X, dtype=torch.float32, device=device)\n",
    "            out = net(inp).cpu().numpy().reshape(u_sol.shape)\n",
    "        rel_l2 = relative_l2_error(out, u_sol)  # normalised by the exact solution\n",
    "        print(\"\\n.....\\n\")\n",
    "        print(\"Method:\", method)\n",
    "        print(\"pred rel. l2-error = {:e}\".format(rel_l2))\n",
    "        print(\"\\n.....\\n\")\n",
    "        if j == 0:\n",
    "            plot_function(lx, rx, ly, ry, u_sol, out, method, extras=extras)\n",
    "        list_of_l2_Errors.append(rel_l2)\n",
    "        losses_boundary_global.append(losses_boundary)\n",
    "        losses_residual_global.append(losses_residual)\n",
    "        lambdas_global.append(lambds)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-results",
   "metadata": {},
   "source": [
    "## Results\n",
    "\n",
    "Relative $L_2$ errors: one row per weighting rule (W1..W6), one column per training-set size in `SETS`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c538ca05",
   "metadata": {},
   "outputs": [],
   "source": [
    "def split(list_a, chunk_size):\n",
    "    \"\"\"Yield consecutive chunks of `list_a` of length `chunk_size` (the last may be shorter).\"\"\"\n",
    "    for start in range(0, len(list_a), chunk_size):\n",
    "        yield list_a[start:start + chunk_size]\n",
    "\n",
    "\n",
    "for row in split(list_of_l2_Errors, len(SETS)):\n",
    "    print(np.array(row))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}