{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-title",
   "metadata": {},
   "source": [
    "# Separable PINNs for 3D time-dependent flow mixing\n",
    "\n",
    "Trains separable physics-informed neural networks with CP, tensor-train (TT) and Tucker low-rank outputs on the flow-mixing advection problem $u_t + a u_x + b u_y = 0$ for $t$ in $[0, 4]$ and $x, y$ in $[-4, 4]$, and reports the relative $L^2$ error against the analytic solution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59a27db9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import os\n",
    "\n",
    "# JAX reads these when its backend initialises (on the first computation), so\n",
    "# they must be set before any JAX work runs; setting them inside `main` is a\n",
    "# no-op once the backend is up.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # force JAX onto a single GPU\n",
    "os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # allocate GPU memory on demand\n",
    "\n",
    "import time\n",
    "from functools import partial\n",
    "from typing import Sequence\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import optax\n",
    "from celluloid import Camera\n",
    "from flax import linen as nn\n",
    "from jax import jvp, value_and_grad\n",
    "from tqdm import trange"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-models",
   "metadata": {},
   "source": [
    "## Models\n",
    "\n",
    "Each model passes the three coordinate axes through separate MLPs and combines the per-axis factors into the solution on the full tensor grid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa985589",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CPPINN(nn.Module):\n",
    "    \"\"\"Separable PINN with a rank-r CP (canonical polyadic) output.\n",
    "\n",
    "    Each coordinate array (shape (N, 1) in this notebook) goes through its own\n",
    "    MLP giving an (N, r) factor; the solution on the full tensor grid is\n",
    "    u[i, j, k] = sum_f T[f, i] X[f, j] Y[f, k].\n",
    "\n",
    "    Attributes:\n",
    "        features: hidden widths followed by the rank r (last entry).\n",
    "    \"\"\"\n",
    "    features: Sequence[int]\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x, y, z):\n",
    "        inputs, outputs = [x, y, z], []\n",
    "        init = nn.initializers.xavier_normal()\n",
    "        for X in inputs:\n",
    "            for fs in self.features[:-1]:\n",
    "                X = nn.Dense(fs, kernel_init=init)(X)\n",
    "                X = nn.activation.tanh(X)\n",
    "            X = nn.Dense(self.features[-1], kernel_init=init)(X)\n",
    "            outputs += [jnp.transpose(X, (1, 0))]  # (r, N)\n",
    "\n",
    "        xy = jnp.einsum('fx, fy->fxy', outputs[0], outputs[1])\n",
    "        return jnp.einsum('fxy, fz->xyz', xy, outputs[-1])\n",
    "\n",
    "\n",
    "class TTPINN(nn.Module):\n",
    "    \"\"\"Separable PINN with a tensor-train (TT) output.\n",
    "\n",
    "    The first coordinate's MLP ends in a DenseGeneral giving an (N, r, r) core;\n",
    "    the other two give (N, r) factors, contracted as\n",
    "    u[x, y, z] = sum_{f,k} A[x, f, k] B[y, f] C[z, k].\n",
    "\n",
    "    Attributes:\n",
    "        features: hidden widths followed by the rank r (last entry).\n",
    "    \"\"\"\n",
    "    features: Sequence[int]\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x, y, z):\n",
    "        inputs, outputs = [x, y, z], []\n",
    "        init = nn.initializers.xavier_uniform()\n",
    "        for i, X in enumerate(inputs):\n",
    "            for fs in self.features[:-1]:\n",
    "                X = nn.Dense(fs, kernel_init=init)(X)\n",
    "                X = nn.activation.tanh(X)\n",
    "            if i == 0:\n",
    "                X = nn.DenseGeneral((self.features[-1], self.features[-1]), kernel_init=init)(X)\n",
    "            else:\n",
    "                X = nn.Dense(self.features[-1], kernel_init=init)(X)\n",
    "            outputs += [X]\n",
    "        return jnp.einsum('xfk,yf,zk->xyz', outputs[0], outputs[1], outputs[-1])\n",
    "\n",
    "\n",
    "class TuckerPINN(nn.Module):\n",
    "    \"\"\"Separable PINN with a Tucker output.\n",
    "\n",
    "    Per-coordinate MLPs give (r, N) factor matrices that are contracted with a\n",
    "    learnable (r, r, r) core G: u[x, y, z] = sum_{k,l,m} G[k, l, m] T[k, x] X[l, y] Y[m, z].\n",
    "\n",
    "    Attributes:\n",
    "        features: hidden widths followed by the rank r (last entry).\n",
    "    \"\"\"\n",
    "    features: Sequence[int]\n",
    "\n",
    "    def setup(self):\n",
    "        # Learnable Tucker core of shape (r, r, r).\n",
    "        r = self.features[-1]\n",
    "        self.core = self.param('core', nn.initializers.orthogonal(), (r, r, r))\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x, y, z):\n",
    "        inputs, outputs = [x, y, z], []\n",
    "        init = nn.initializers.xavier_normal()\n",
    "        for X in inputs:\n",
    "            for fs in self.features[:-1]:\n",
    "                X = nn.Dense(fs, kernel_init=init)(X)\n",
    "                X = nn.activation.tanh(X)\n",
    "            X = nn.Dense(self.features[-1], kernel_init=init)(X)\n",
    "            outputs += [jnp.transpose(X, (1, 0))]  # (r, N)\n",
    "        return jnp.einsum('klm,kx,ly,mz->xyz', self.core, outputs[0], outputs[1], outputs[-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-loss",
   "metadata": {},
   "source": [
    "## Derivatives, loss and optimiser step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5bdd11a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hvp_fwdfwd(f, primals, tangents, return_primals=False):\n",
    "    \"\"\"Second directional derivative of `f` via forward-over-forward `jvp`.\n",
    "\n",
    "    `primals` and `tangents` are 1-tuples (one array each). Returns the second\n",
    "    derivative along `tangents`; with `return_primals=True` returns\n",
    "    (first derivative, second derivative). Not used elsewhere in this notebook.\n",
    "    \"\"\"\n",
    "    g = lambda primals: jvp(f, (primals,), tangents)[1]\n",
    "    primals_out, tangents_out = jvp(g, primals, tangents)\n",
    "    if return_primals:\n",
    "        return primals_out, tangents_out\n",
    "    return tangents_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbf6a3d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def relative_l2(u, u_gt):\n",
    "    \"\"\"Relative error ||u - u_gt|| / ||u_gt|| (2-norm over all entries).\"\"\"\n",
    "    return jnp.linalg.norm(u - u_gt) / jnp.linalg.norm(u_gt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "983f566f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_flowmixer(apply_fn, *train_data):\n",
    "    \"\"\"Build the scalar training loss for the flow-mixing problem.\n",
    "\n",
    "    Args:\n",
    "        apply_fn: model forward `apply_fn(params, t, x, y)` returning u on the\n",
    "            tensor grid spanned by the three coordinate arrays.\n",
    "        *train_data: the 13 values returned by `train_generator_flow_mixing3d`.\n",
    "\n",
    "    Returns:\n",
    "        loss_fn(params) = 10 * PDE residual MSE + initial-condition MSE\n",
    "        + boundary MSE summed over the four spatial edges.\n",
    "    \"\"\"\n",
    "    def residual_loss(params, t, x, y, a, b):\n",
    "        # For the separable models each coordinate entry only feeds its own grid\n",
    "        # slice, so a JVP with an all-ones tangent gives the partial derivative\n",
    "        # at every grid point in one forward pass.\n",
    "        v_t = jnp.ones(t.shape)\n",
    "        v_x = jnp.ones(x.shape)\n",
    "        v_y = jnp.ones(y.shape)\n",
    "        ut = jvp(lambda t: apply_fn(params, t, x, y), (t,), (v_t,))[1]\n",
    "        ux = jvp(lambda x: apply_fn(params, t, x, y), (x,), (v_x,))[1]\n",
    "        uy = jvp(lambda y: apply_fn(params, t, x, y), (y,), (v_y,))[1]\n",
    "        return jnp.mean((ut + a*ux + b*uy)**2)\n",
    "\n",
    "    def initial_loss(params, t, x, y, u):\n",
    "        return jnp.mean((apply_fn(params, t, x, y) - u)**2)\n",
    "\n",
    "    def boundary_loss(params, t, x, y, u):\n",
    "        loss = 0.\n",
    "        for i in range(4):\n",
    "            loss += jnp.mean((apply_fn(params, t[i], x[i], y[i]) - u[i])**2)\n",
    "        return loss\n",
    "\n",
    "    tc, xc, yc, ti, xi, yi, ui, tb, xb, yb, ub, a, b = train_data\n",
    "\n",
    "    # Close over the data so the returned function only takes the parameters.\n",
    "    def loss_fn(params):\n",
    "        return (10*residual_loss(params, tc, xc, yc, a, b)\n",
    "                + initial_loss(params, ti, xi, yi, ui)\n",
    "                + boundary_loss(params, tb, xb, yb, ub))\n",
    "\n",
    "    return loss_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41373518",
   "metadata": {},
   "outputs": [],
   "source": [
    "@partial(jax.jit, static_argnums=(0,))\n",
    "def update_model(optim, gradient, params, state):\n",
    "    \"\"\"Apply one optax update; `optim` is static, so it must be hashable.\"\"\"\n",
    "    updates, state = optim.update(gradient, state)\n",
    "    params = optax.apply_updates(params, updates)\n",
    "    return params, state"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-data",
   "metadata": {},
   "source": [
    "## Problem data\n",
    "\n",
    "Analytic solution, advection field, and the training/test data generators."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "861f631d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def flow_mixing3d_exact_u(t, x, y, omega):\n",
    "    \"\"\"Analytic solution u = -tanh((y/2) cos(omega t) - (x/2) sin(omega t)).\"\"\"\n",
    "    return -jnp.tanh((y/2)*jnp.cos(omega*t) - (x/2)*jnp.sin(omega*t))\n",
    "\n",
    "\n",
    "def flow_mixing3d_params(t, x, y, v_max=0.385, require_ab=False):\n",
    "    \"\"\"Angular velocity and advection field of the 3D flow-mixing problem.\n",
    "\n",
    "    With r = sqrt(x^2 + y^2) and v_t = sech(r)^2 tanh(r):\n",
    "        omega = v_t / (v_max r),  a = -(v_t / v_max) y / r,  b = (v_t / v_max) x / r.\n",
    "    At r = 0 the limits are used (v_t / r -> 1, so omega = 1 / v_max and\n",
    "    a = b = 0) instead of producing NaN.\n",
    "\n",
    "    Args:\n",
    "        t, x, y: broadcast-compatible arrays, typically meshgrids (t is unused).\n",
    "        v_max: velocity normalisation constant.\n",
    "        require_ab: also compute the advection coefficients a and b.\n",
    "\n",
    "    Returns:\n",
    "        (omega, a, b); a and b are None unless `require_ab` is True.\n",
    "    \"\"\"\n",
    "    r = jnp.sqrt(x**2 + y**2)\n",
    "    v_t = ((1/jnp.cosh(r))**2) * jnp.tanh(r)\n",
    "    nonzero = r > 0\n",
    "    # Use a harmless denominator at r == 0 so neither branch of `where` yields\n",
    "    # NaN (a NaN in the unused branch would still poison gradients).\n",
    "    r_safe = jnp.where(nonzero, r, 1.0)\n",
    "    omega = jnp.where(nonzero, (1/r_safe)*(v_t/v_max), 1.0/v_max)\n",
    "    a, b = None, None\n",
    "    if require_ab:\n",
    "        a = jnp.where(nonzero, -(v_t/v_max)*(y/r_safe), 0.0)\n",
    "        b = jnp.where(nonzero, (v_t/v_max)*(x/r_safe), 0.0)\n",
    "    return omega, a, b\n",
    "\n",
    "\n",
    "def train_generator_flow_mixing3d(nc, key, v_max=0.385):\n",
    "    \"\"\"Sample training data for the separable flow-mixing PINN.\n",
    "\n",
    "    Draws `nc` random values per axis (t in [0, 4], x and y in [-4, 4]); the\n",
    "    residual is enforced on their full nc x nc x nc tensor grid.\n",
    "\n",
    "    Returns:\n",
    "        (tc, xc, yc, ti, xi, yi, ui, tb, xb, yb, ub, a, b):\n",
    "        collocation axes of shape (nc, 1); initial axes with ti = [[0.]] and\n",
    "        targets ui of shape (1, nc, nc); per-edge lists tb/xb/yb/ub for the\n",
    "        boundaries x = -4, x = 4, y = -4, y = 4; advection coefficients a, b\n",
    "        on the collocation grid.\n",
    "    \"\"\"\n",
    "    keys = jax.random.split(key, 3)\n",
    "    # collocation points\n",
    "    tc = jax.random.uniform(keys[0], (nc, 1), minval=0., maxval=4.)\n",
    "    xc = jax.random.uniform(keys[1], (nc, 1), minval=-4., maxval=4.)\n",
    "    yc = jax.random.uniform(keys[2], (nc, 1), minval=-4., maxval=4.)\n",
    "    tc_mesh, xc_mesh, yc_mesh = jnp.meshgrid(tc.ravel(), xc.ravel(), yc.ravel(), indexing='ij')\n",
    "    _, a, b = flow_mixing3d_params(tc_mesh, xc_mesh, yc_mesh, v_max=v_max, require_ab=True)\n",
    "\n",
    "    # initial points (t = 0) reuse the collocation x/y axes\n",
    "    ti = jnp.zeros((1, 1))\n",
    "    xi = xc\n",
    "    yi = yc\n",
    "    ti_mesh, xi_mesh, yi_mesh = jnp.meshgrid(ti.ravel(), xi.ravel(), yi.ravel(), indexing='ij')\n",
    "    omega_i, _, _ = flow_mixing3d_params(ti_mesh, xi_mesh, yi_mesh, v_max=v_max)\n",
    "    ui = flow_mixing3d_exact_u(ti_mesh, xi_mesh, yi_mesh, omega_i)\n",
    "\n",
    "    # boundary points on the hard-coded edges x = -4, x = 4, y = -4, y = 4\n",
    "    tb = [tc, tc, tc, tc]\n",
    "    xb = [jnp.array([[-4.]]), jnp.array([[4.]]), xc, xc]\n",
    "    yb = [yc, yc, jnp.array([[-4.]]), jnp.array([[4.]])]\n",
    "    ub = []\n",
    "    for i in range(4):\n",
    "        tb_mesh, xb_mesh, yb_mesh = jnp.meshgrid(tb[i].ravel(), xb[i].ravel(), yb[i].ravel(), indexing='ij')\n",
    "        omega_b, _, _ = flow_mixing3d_params(tb_mesh, xb_mesh, yb_mesh, v_max=v_max)\n",
    "        ub += [flow_mixing3d_exact_u(tb_mesh, xb_mesh, yb_mesh, omega_b)]\n",
    "    return tc, xc, yc, ti, xi, yi, ui, tb, xb, yb, ub, a, b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24d0dbc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_generator_flow_mixing3d(nc_test, v_max=0.385):\n",
    "    \"\"\"Uniform evaluation grid and analytic solution for the flow-mixing problem.\n",
    "\n",
    "    Returns:\n",
    "        (t, x, y, tm, xm, ym, u_gt): axes of shape (nc_test, 1), their 'ij'\n",
    "        meshgrids, and the exact solution u_gt indexed [t, x, y].\n",
    "    \"\"\"\n",
    "    t = jnp.linspace(0, 4, nc_test)\n",
    "    x = jnp.linspace(-4, 4, nc_test)\n",
    "    y = jnp.linspace(-4, 4, nc_test)\n",
    "    tm, xm, ym = jnp.meshgrid(t, x, y, indexing='ij')\n",
    "\n",
    "    omega, _, _ = flow_mixing3d_params(tm, xm, ym, v_max)\n",
    "    u_gt = flow_mixing3d_exact_u(tm, xm, ym, omega)\n",
    "    return t.reshape(-1, 1), x.reshape(-1, 1), y.reshape(-1, 1), tm, xm, ym, u_gt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "586fb65d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plotter(xm, ym, zm, u):\n",
    "    \"\"\"3D scatter plot of u over the (t, x, y) points given by three meshgrids.\"\"\"\n",
    "    fig = plt.figure(figsize=(6, 6))\n",
    "    ax = fig.add_subplot(111, projection='3d')\n",
    "    im = ax.scatter(xm, ym, zm, c=u, s=0.5, cmap='magma')\n",
    "    ax.set_title('U(t, x, y)', fontsize=20)\n",
    "    ax.set_xlabel('t', fontsize=18, labelpad=10)\n",
    "    ax.set_ylabel('x', fontsize=18, labelpad=10)\n",
    "    ax.set_zlabel('y', fontsize=18, labelpad=10)\n",
    "    fig.colorbar(im, ax=ax)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-training",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "016a8901",
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(mode, NC, NI, NB, NC_TEST, SEED, LR, EPOCHS, N_LAYERS, FEATURES, LOG_ITER):\n",
    "    \"\"\"Train one separable PINN on the flow-mixing problem.\n",
    "\n",
    "    Args:\n",
    "        mode: 'CPPINN', 'TTPINN' or 'TuckerPINN'.\n",
    "        NC: sampled points per axis (the residual grid has NC^3 points).\n",
    "        NI, NB: unused; initial/boundary points reuse the NC axis samples.\n",
    "        NC_TEST: uniform test points per axis for the error evaluation.\n",
    "        SEED: PRNG seed.\n",
    "        LR: Adam learning rate.\n",
    "        EPOCHS: number of full-batch optimisation steps.\n",
    "        N_LAYERS: Dense layers per coordinate MLP (the last outputs the rank).\n",
    "        FEATURES: hidden width, also used as the rank.\n",
    "        LOG_ITER: evaluate and log every LOG_ITER steps (and at step 1).\n",
    "\n",
    "    Returns:\n",
    "        List of [epoch, loss, relative L2 error] rows.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `mode` is not a known model name.\n",
    "    \"\"\"\n",
    "    models = {'CPPINN': CPPINN, 'TTPINN': TTPINN, 'TuckerPINN': TuckerPINN}\n",
    "    if mode not in models:\n",
    "        raise ValueError(f'unknown mode {mode!r}; expected one of {sorted(models)}')\n",
    "\n",
    "    key = jax.random.PRNGKey(SEED)\n",
    "    key, subkey = jax.random.split(key, 2)\n",
    "\n",
    "    feat_sizes = tuple(FEATURES for _ in range(N_LAYERS))\n",
    "    model = models[mode](feat_sizes)\n",
    "    # The initialisers only depend on input shapes, so constant dummies suffice.\n",
    "    dummy = jnp.ones((NC, 1))\n",
    "    params = model.init(subkey, dummy, dummy, dummy)\n",
    "\n",
    "    optim = optax.adam(LR)\n",
    "    state = optim.init(params)\n",
    "\n",
    "    key, subkey = jax.random.split(key, 2)\n",
    "    train_data = train_generator_flow_mixing3d(NC, subkey)\n",
    "    t, x, y, tm, xm, ym, u_gt = test_generator_flow_mixing3d(NC_TEST)\n",
    "    logger = []\n",
    "\n",
    "    apply_fn = jax.jit(model.apply)\n",
    "    loss_fn = loss_flowmixer(apply_fn, *train_data)\n",
    "\n",
    "    @jax.jit\n",
    "    def train_one_step(params, state):\n",
    "        loss, gradient = value_and_grad(loss_fn)(params)\n",
    "        params, state = update_model(optim, gradient, params, state)\n",
    "        return loss, params, state\n",
    "\n",
    "    start = time.time()\n",
    "    for e in trange(1, EPOCHS + 1):\n",
    "        loss, params, state = train_one_step(params, state)\n",
    "        if e % LOG_ITER == 0 or e == 1:\n",
    "            u = apply_fn(params, t, x, y)\n",
    "            error = relative_l2(u, u_gt)\n",
    "            print(f'Epoch: {e}/{EPOCHS} --> loss: {loss:.8f}, error: {error:.8f}')\n",
    "            logger.append([e, float(loss), float(error)])\n",
    "    # JAX dispatches asynchronously; wait for the last step before timing.\n",
    "    jax.block_until_ready(params)\n",
    "    end = time.time()\n",
    "    print(f'Runtime: {((end-start)/max(EPOCHS, 1)*1000):.2f} ms/iter.')\n",
    "    return logger"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-experiments",
   "metadata": {},
   "source": [
    "## Experiments\n",
    "\n",
    "Runs 10 seeds per rank (80,000 Adam steps each, so this cell is expensive) and saves the log of every run whose best relative $L^2$ error is below 0.1 to `Rank{rank}/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2400611",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_POINTS = 128\n",
    "for model_name in ['TuckerPINN']:\n",
    "    for rank in [64, 128]:\n",
    "        out_dir = f'Rank{rank}'\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "        for run in range(10):\n",
    "            logs = main(mode=model_name, NC=NUM_POINTS, NI=NUM_POINTS, NB=NUM_POINTS, NC_TEST=128,\n",
    "                        SEED=444 + run, LR=1e-3, EPOCHS=80000, N_LAYERS=4, FEATURES=rank, LOG_ITER=5000)\n",
    "            logs = np.array(logs)\n",
    "            # Keep only runs that reached < 10% relative L2 error at some log step.\n",
    "            if np.min(logs[:, 2]) < 0.1:\n",
    "                np.savetxt(f'{out_dir}/model_{model_name}_run_{run}', logs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-ground-truth",
   "metadata": {},
   "source": [
    "## Ground-truth solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d3a31d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "t, x, y, tm, xm, ym, u_gt = test_generator_flow_mixing3d(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc360bbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "snapshot_ids = [0, 50, 75, -1]  # time indices shown in the 2 x 2 panel\n",
    "# One colorbar serves all panels, so they must share the same colour limits.\n",
    "vmin, vmax = float(u_gt.min()), float(u_gt.max())\n",
    "\n",
    "fig, axes = plt.subplots(2, 2, figsize=(6, 6))\n",
    "for ax, k in zip(axes.ravel(), snapshot_ids):\n",
    "    # u_gt is indexed [t, x, y]; after .T rows are y, and origin='lower' puts\n",
    "    # y = -4 at the bottom so the image matches the extent.\n",
    "    im = ax.imshow(u_gt[k, :, :].T, cmap='magma', extent=[-4, 4, -4, 4],\n",
    "                   origin='lower', vmin=vmin, vmax=vmax)\n",
    "    ax.set_title(f't={float(t[k, 0]):.2f}', fontsize=8)\n",
    "    ax.tick_params(axis='x', labelsize=8)\n",
    "    ax.tick_params(axis='y', labelsize=8)\n",
    "\n",
    "fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.8);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-animation",
   "metadata": {},
   "source": [
    "## Animation\n",
    "\n",
    "Saves one PNG per frame to `snaps/` and assembles the frames into `flow.gif`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b370cd89",
   "metadata": {},
   "outputs": [],
   "source": [
    "SNAP_DIR = 'snaps'\n",
    "os.makedirs(SNAP_DIR, exist_ok=True)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "camera = Camera(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9d7fe49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One frame every 5 time steps, plus the final time step.\n",
    "last = u_gt.shape[0] - 1\n",
    "frame_ids = list(range(0, last, 5)) + [last]\n",
    "vmin, vmax = float(u_gt.min()), float(u_gt.max())\n",
    "\n",
    "for i, k in enumerate(frame_ids):\n",
    "    ax.imshow(u_gt[k, :, :].T, cmap='magma', extent=[-4, 4, -4, 4],\n",
    "              origin='lower', vmin=vmin, vmax=vmax)\n",
    "    fig.savefig(f'{SNAP_DIR}/snap{i}')\n",
    "    camera.snap()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01b762ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "flow_anim = camera.animate()\n",
    "# 'pillow' ships with matplotlib; 'imagemagick' needs an external binary.\n",
    "flow_anim.save('flow.gif', writer='pillow')\n",
    "plt.close(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jaxdf",
   "language": "python",
   "name": "jaxdf"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}