{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0a1b2c3d",
   "metadata": {},
   "source": [
    "# Tensor-decomposition PINNs for the 3D Helmholtz equation\n",
    "\n",
    "Separable physics-informed networks: one small MLP per coordinate axis, merged by a CP, tensor-train (TT) or Tucker contraction.\n",
    "The loss penalises the residual of $u_{xx} + u_{yy} + u_{zz} + \\lambda u = q$ on a collocation grid in $[-1, 1]^3$ plus $u^2$ on the six faces of the cube."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9f47ff8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import os\n",
    "import time\n",
    "from functools import partial\n",
    "from typing import Sequence\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "import optax\n",
    "from flax import linen as nn\n",
    "from jax import jvp, value_and_grad\n",
    "from tqdm import trange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "66d0f8e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CPPINN(nn.Module):\n",
    "    \"\"\"Separable PINN whose per-axis networks are merged by a CP-style contraction.\n",
    "\n",
    "    Attributes:\n",
    "        features: widths of the Dense layers of each per-axis MLP; the last\n",
    "            entry is the number of components summed over in the contraction.\n",
    "    \"\"\"\n",
    "    features: Sequence[int]\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x, y, z):\n",
    "        \"\"\"Evaluate the network on the grid spanned by x, y and z.\n",
    "\n",
    "        Args:\n",
    "            x, y, z: 2-D arrays of per-axis coordinates, one point per row.\n",
    "\n",
    "        Returns:\n",
    "            Array indexed as [x point, y point, z point].\n",
    "        \"\"\"\n",
    "        inputs, outputs = [x, y, z], []\n",
    "        init = nn.initializers.xavier_normal()\n",
    "        for X in inputs:\n",
    "            # each axis gets its own tanh MLP (layers are created in x, y, z order)\n",
    "            for fs in self.features[:-1]:\n",
    "                X = nn.Dense(fs, kernel_init=init)(X)\n",
    "                X = nn.activation.tanh(X)\n",
    "            X = nn.Dense(self.features[-1], kernel_init=init)(X)\n",
    "            # (points, components) -> (components, points) to match the einsum subscripts\n",
    "            outputs += [jnp.transpose(X, (1, 0))]\n",
    "\n",
    "        xy = jnp.einsum('fx, fy->fxy', outputs[0], outputs[1])\n",
    "        return jnp.einsum('fxy, fz->xyz', xy, outputs[-1])\n",
    "\n",
    "\n",
    "class TTPINN(nn.Module):\n",
    "    \"\"\"Separable PINN whose per-axis networks are merged by a tensor-train style contraction.\n",
    "\n",
    "    The x network emits a (r, r) matrix per point and the y and z networks emit\n",
    "    a length-r vector per point, where r = features[-1].\n",
    "\n",
    "    Attributes:\n",
    "        features: widths of the layers of each per-axis MLP; the last entry is r.\n",
    "    \"\"\"\n",
    "    features: Sequence[int]\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x, y, z):\n",
    "        \"\"\"Evaluate the network on the grid spanned by x, y and z.\n",
    "\n",
    "        Args:\n",
    "            x, y, z: arrays of per-axis coordinates, one point per row.\n",
    "\n",
    "        Returns:\n",
    "            Array indexed as [x point, y point, z point].\n",
    "        \"\"\"\n",
    "        inputs, outputs = [x, y, z], []\n",
    "        init = nn.initializers.xavier_uniform()\n",
    "        for i, X in enumerate(inputs):\n",
    "            for fs in self.features[:-1]:\n",
    "                X = nn.Dense(fs, kernel_init=init)(X)\n",
    "                X = nn.activation.tanh(X)\n",
    "            if i == 0:\n",
    "                # only the x network produces a matrix-valued output per point\n",
    "                X = nn.DenseGeneral((self.features[-1], self.features[-1]), kernel_init=init)(X)\n",
    "            else:\n",
    "                X = nn.Dense(self.features[-1], kernel_init=init)(X)\n",
    "            outputs += [X]\n",
    "\n",
    "        return jnp.einsum('xfk,yf,zk->xyz', outputs[0], outputs[1], outputs[-1])\n",
    "\n",
    "\n",
    "class TuckerPINN(nn.Module):\n",
    "    \"\"\"Separable PINN whose per-axis networks are merged through a learnable Tucker core.\n",
    "\n",
    "    Attributes:\n",
    "        features: widths of the Dense layers of each per-axis MLP; the last\n",
    "            entry r also fixes the (r, r, r) shape of the core tensor.\n",
    "    \"\"\"\n",
    "    features: Sequence[int]\n",
    "\n",
    "    # NOTE(review): this module combines setup() with an @nn.compact __call__;\n",
    "    # confirm the installed Flax version supports mixing the two.\n",
    "    def setup(self):\n",
    "        \"\"\"Register the orthogonally initialised core tensor as the parameter 'core'.\"\"\"\n",
    "        self.core = self.param(\"core\", nn.initializers.orthogonal(), (self.features[-1], self.features[-1], self.features[-1]))\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x, y, z):\n",
    "        \"\"\"Evaluate the network on the grid spanned by x, y and z.\n",
    "\n",
    "        Args:\n",
    "            x, y, z: 2-D arrays of per-axis coordinates, one point per row.\n",
    "\n",
    "        Returns:\n",
    "            Array indexed as [x point, y point, z point].\n",
    "        \"\"\"\n",
    "        inputs, outputs = [x, y, z], []\n",
    "        init = nn.initializers.xavier_normal()\n",
    "        for X in inputs:\n",
    "            for fs in self.features[:-1]:\n",
    "                X = nn.Dense(fs, kernel_init=init)(X)\n",
    "                X = nn.activation.tanh(X)\n",
    "            X = nn.Dense(self.features[-1], kernel_init=init)(X)\n",
    "            # (points, r) -> (r, points) to match the einsum subscripts\n",
    "            outputs += [jnp.transpose(X, (1, 0))]\n",
    "\n",
    "        return jnp.einsum(\"klm,kx,ly,mz->xyz\", self.core, outputs[0], outputs[1], outputs[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4f92234e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hvp_fwdfwd(f, primals, tangents, return_primals=False):\n",
    "    \"\"\"Second directional derivative of ``f`` via two nested forward-mode JVPs.\n",
    "\n",
    "    Args:\n",
    "        f: function of a single array argument.\n",
    "        primals: 1-tuple holding the point at which to differentiate.\n",
    "        tangents: 1-tuple holding the direction; it is used for both JVP passes.\n",
    "        return_primals: if True, also return the primal output of the outer\n",
    "            JVP, i.e. the first directional derivative.\n",
    "\n",
    "    Returns:\n",
    "        The second directional derivative, or the pair\n",
    "        ``(first_derivative, second_derivative)`` when ``return_primals`` is True.\n",
    "    \"\"\"\n",
    "    # inner pass: first directional derivative of f as a function of the point\n",
    "    first_derivative = lambda point: jvp(f, (point,), tangents)[1]\n",
    "    primals_out, tangents_out = jvp(first_derivative, primals, tangents)\n",
    "    if return_primals:\n",
    "        return primals_out, tangents_out\n",
    "    return tangents_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "271f17de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_helmholtz(apply_fn, *train_data):\n",
    "    \"\"\"Build the scalar training loss for the 3D Helmholtz problem.\n",
    "\n",
    "    Args:\n",
    "        apply_fn: callable ``apply_fn(params, x, y, z)`` returning u on the x-y-z grid.\n",
    "        *train_data: the seven items ``xc, yc, zc, uc, xb, yb, zb``, in the order\n",
    "            returned by ``train_generator_helmholtz3d``.\n",
    "\n",
    "    Returns:\n",
    "        ``loss_fn(params)``: PDE residual loss plus boundary loss.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``train_data`` does not hold exactly seven items.\n",
    "    \"\"\"\n",
    "    def residual_loss(params, x, y, z, source_term, lda=1.):\n",
    "        # network prediction on the collocation grid\n",
    "        u = apply_fn(params, x, y, z)\n",
    "        # all-ones tangent vectors for the forward-mode derivatives\n",
    "        v_x = jnp.ones(x.shape)\n",
    "        v_y = jnp.ones(y.shape)\n",
    "        v_z = jnp.ones(z.shape)\n",
    "        # 2nd derivatives of u with respect to each axis input\n",
    "        uxx = hvp_fwdfwd(lambda x: apply_fn(params, x, y, z), (x,), (v_x,))\n",
    "        uyy = hvp_fwdfwd(lambda y: apply_fn(params, x, y, z), (y,), (v_y,))\n",
    "        uzz = hvp_fwdfwd(lambda z: apply_fn(params, x, y, z), (z,), (v_z,))\n",
    "        return jnp.mean(((uzz + uyy + uxx + lda*u) - source_term)**2)\n",
    "\n",
    "    def boundary_loss(params, x, y, z):\n",
    "        # x, y, z are parallel lists with one entry per boundary face;\n",
    "        # u is driven towards zero on every face\n",
    "        loss = 0.\n",
    "        for x_face, y_face, z_face in zip(x, y, z):\n",
    "            loss += jnp.mean(apply_fn(params, x_face, y_face, z_face)**2)\n",
    "        return loss\n",
    "\n",
    "    # unpack data\n",
    "    xc, yc, zc, uc, xb, yb, zb = train_data\n",
    "\n",
    "    # close over the data so the returned loss depends on params only\n",
    "    def loss_fn(params):\n",
    "        return residual_loss(params, xc, yc, zc, uc) + boundary_loss(params, xb, yb, zb)\n",
    "\n",
    "    return loss_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e3ec2a80",
   "metadata": {},
   "outputs": [],
   "source": [
    "def helmholtz3d_exact_u(a1, a2, a3, x, y, z):\n",
    "    \"\"\"Reference solution sin(a1*pi*x) * sin(a2*pi*y) * sin(a3*pi*z).\"\"\"\n",
    "    return jnp.sin(a1*jnp.pi*x) * jnp.sin(a2*jnp.pi*y) * jnp.sin(a3*jnp.pi*z)\n",
    "\n",
    "\n",
    "# 3d time-independent helmholtz source term\n",
    "def helmholtz3d_source_term(a1, a2, a3, x, y, z, lda=1.):\n",
    "    \"\"\"Source term u_xx + u_yy + u_zz + lda*u evaluated for the reference solution.\n",
    "\n",
    "    Args:\n",
    "        a1, a2, a3: frequency multipliers of the reference solution along x, y, z.\n",
    "        x, y, z: coordinates at which to evaluate.\n",
    "        lda: coefficient of the zeroth-order term.\n",
    "\n",
    "    Returns:\n",
    "        Array with the broadcast shape of x, y and z.\n",
    "    \"\"\"\n",
    "    u_gt = helmholtz3d_exact_u(a1, a2, a3, x, y, z)\n",
    "    # each second derivative of the sine product is a multiple of the solution itself\n",
    "    uxx = -(a1*jnp.pi)**2 * u_gt\n",
    "    uyy = -(a2*jnp.pi)**2 * u_gt\n",
    "    uzz = -(a3*jnp.pi)**2 * u_gt\n",
    "    return uxx + uyy + uzz + lda*u_gt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "41074e9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def relative_l2(u, u_gt):\n",
    "    \"\"\"Norm of ``u - u_gt`` divided by the norm of ``u_gt`` (``jnp.linalg.norm`` defaults).\"\"\"\n",
    "    return jnp.linalg.norm(u-u_gt) / jnp.linalg.norm(u_gt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1ce9b588",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_generator_helmholtz3d(a1, a2, a3, nc, key):\n",
    "    \"\"\"Sample training data for the 3D Helmholtz problem on [-1, 1]^3.\n",
    "\n",
    "    Args:\n",
    "        a1, a2, a3: frequency multipliers passed to the source term.\n",
    "        nc: number of collocation points drawn per axis.\n",
    "        key: JAX PRNG key.\n",
    "\n",
    "    Returns:\n",
    "        ``xc, yc, zc``: per-axis collocation points reshaped to (nc, 1).\n",
    "        ``uc``: source term on the (nc, nc, nc) 'ij' meshgrid of those points.\n",
    "        ``xb, yb, zb``: lists of six per-axis inputs, one entry per cube face.\n",
    "    \"\"\"\n",
    "    keys = jax.random.split(key, 3)\n",
    "    # collocation points\n",
    "    xc = jax.random.uniform(keys[0], (nc,), minval=-1., maxval=1.)\n",
    "    yc = jax.random.uniform(keys[1], (nc,), minval=-1., maxval=1.)\n",
    "    zc = jax.random.uniform(keys[2], (nc,), minval=-1., maxval=1.)\n",
    "    # source term\n",
    "    xcm, ycm, zcm = jnp.meshgrid(xc, yc, zc, indexing='ij')\n",
    "    uc = helmholtz3d_source_term(a1, a2, a3, xcm, ycm, zcm)\n",
    "    xc, yc, zc = xc.reshape(-1, 1), yc.reshape(-1, 1), zc.reshape(-1, 1)\n",
    "    # boundary (hard-coded): faces x=+1, x=-1, y=+1, y=-1, z=+1, z=-1\n",
    "    xb = [jnp.array([[1.]]), jnp.array([[-1.]]), xc, xc, xc, xc]\n",
    "    yb = [yc, yc, jnp.array([[1.]]), jnp.array([[-1.]]), yc, yc]\n",
    "    zb = [zc, zc, zc, zc, jnp.array([[1.]]), jnp.array([[-1.]])]\n",
    "    return xc, yc, zc, uc, xb, yb, zb\n",
    "\n",
    "\n",
    "# optimizer step function\n",
    "@partial(jax.jit, static_argnums=(0,))\n",
    "def update_model(optim, gradient, params, state):\n",
    "    \"\"\"Apply one optimizer update and return the new ``(params, state)``.\n",
    "\n",
    "    ``optim`` is passed as a static argument to ``jax.jit``.\n",
    "    \"\"\"\n",
    "    updates, state = optim.update(gradient, state)\n",
    "    params = optax.apply_updates(params, updates)\n",
    "    return params, state\n",
    "\n",
    "\n",
    "def _test_generator_helmholtz3d(a1, a2, a3, nc_test):\n",
    "    \"\"\"Build a uniform test grid on [-1, 1]^3 and the reference solution on it.\n",
    "\n",
    "    Args:\n",
    "        a1, a2, a3: frequency multipliers of the reference solution.\n",
    "        nc_test: number of equally spaced points per axis.\n",
    "\n",
    "    Returns:\n",
    "        ``x, y, z``: per-axis points reshaped to (nc_test, 1).\n",
    "        ``xm, ym, zm``: 'ij' meshgrid of the per-axis points.\n",
    "        ``u_gt``: reference solution on the meshgrid.\n",
    "    \"\"\"\n",
    "    x = jnp.linspace(-1., 1., nc_test)\n",
    "    y = jnp.linspace(-1., 1., nc_test)\n",
    "    z = jnp.linspace(-1., 1., nc_test)\n",
    "    x = jax.lax.stop_gradient(x)\n",
    "    y = jax.lax.stop_gradient(y)\n",
    "    z = jax.lax.stop_gradient(z)\n",
    "    xm, ym, zm = jnp.meshgrid(x, y, z, indexing='ij')\n",
    "    u_gt = helmholtz3d_exact_u(a1, a2, a3, xm, ym, zm)\n",
    "    x = x.reshape(-1, 1)\n",
    "    y = y.reshape(-1, 1)\n",
    "    z = z.reshape(-1, 1)\n",
    "    return x, y, z, xm, ym, zm, u_gt\n",
    "\n",
    "\n",
    "def plotter(xm, ym, zm, u):\n",
    "    \"\"\"Show a 3-D scatter of ``u`` at the points (xm, ym, zm), colour range fixed to [-1, 1].\"\"\"\n",
    "    fig = plt.figure(figsize=(6, 6))\n",
    "    ax = fig.add_subplot(111, projection='3d')\n",
    "    im = ax.scatter(xm, ym, zm, c=u, s=0.5, cmap='seismic', vmin=-1, vmax=+1)\n",
    "    ax.set_title('U(x, y, z)', fontsize=20)\n",
    "    ax.set_xlabel('x', fontsize=18, labelpad=10)\n",
    "    ax.set_ylabel('y', fontsize=18, labelpad=10)\n",
    "    ax.set_zlabel('z', fontsize=18, labelpad=10)\n",
    "    fig.colorbar(im, ax=ax)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7d0919fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(mode, NC, NI, NB, NC_TEST, SEED, LR, EPOCHS, N_LAYERS, FEATURES, LOG_ITER):\n",
    "    \"\"\"Train one tensor-decomposition PINN on the 3D Helmholtz problem.\n",
    "\n",
    "    Args:\n",
    "        mode: model to train, one of 'CPPINN', 'TTPINN' or 'TuckerPINN'.\n",
    "        NC: number of collocation points per axis.\n",
    "        NI: unused; kept so existing calls keep working.\n",
    "        NB: unused; kept so existing calls keep working.\n",
    "        NC_TEST: number of test points per axis.\n",
    "        SEED: seed of the JAX PRNG key.\n",
    "        LR: Adam learning rate.\n",
    "        EPOCHS: number of optimisation steps.\n",
    "        N_LAYERS: number of layers in each per-axis MLP.\n",
    "        FEATURES: width of every layer, including the output layer.\n",
    "        LOG_ITER: evaluate, print and plot at epoch 1 and every LOG_ITER epochs.\n",
    "\n",
    "    Returns:\n",
    "        List of ``[epoch, loss, relative_l2_error]`` entries, one per logged epoch.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``mode`` is not one of the known model names.\n",
    "    \"\"\"\n",
    "    # fail early on an unknown model name instead of hitting an unbound `model` later\n",
    "    model_classes = {\"CPPINN\": CPPINN, \"TTPINN\": TTPINN, \"TuckerPINN\": TuckerPINN}\n",
    "    if mode not in model_classes:\n",
    "        raise ValueError(f'Unknown mode {mode!r}; expected one of {sorted(model_classes)}')\n",
    "\n",
    "    # force jax to use one device\n",
    "    os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "    os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n",
    "\n",
    "    # random key\n",
    "    key = jax.random.PRNGKey(SEED)\n",
    "    key, subkey = jax.random.split(key, 2)\n",
    "\n",
    "    # feature sizes\n",
    "    feat_sizes = tuple(FEATURES for _ in range(N_LAYERS))\n",
    "\n",
    "    # make & init model; the inputs only fix the parameter shapes, so a constant\n",
    "    # dummy batch is used and `key` is not consumed before it is split again below\n",
    "    model = model_classes[mode](feat_sizes)\n",
    "    dummy_input = jnp.ones((NC, 1))\n",
    "    params = model.init(subkey, dummy_input, dummy_input, dummy_input)\n",
    "\n",
    "    # optimizer\n",
    "    optim = optax.adam(LR)\n",
    "    state = optim.init(params)\n",
    "\n",
    "    # dataset\n",
    "    key, subkey = jax.random.split(key, 2)\n",
    "    train_data = train_generator_helmholtz3d(4, 4, 3, NC, subkey)\n",
    "    x, y, z, xm, ym, zm, u_gt = _test_generator_helmholtz3d(4, 4, 3, NC_TEST)\n",
    "    logger = []\n",
    "\n",
    "    # forward & loss function\n",
    "    apply_fn = jax.jit(model.apply)\n",
    "    loss_fn = loss_helmholtz(apply_fn, *train_data)\n",
    "\n",
    "    @jax.jit\n",
    "    def train_one_step(params, state):\n",
    "        # compute loss and gradient\n",
    "        loss, gradient = value_and_grad(loss_fn)(params)\n",
    "        # update state\n",
    "        params, state = update_model(optim, gradient, params, state)\n",
    "        return loss, params, state\n",
    "\n",
    "    start = time.time()\n",
    "    for e in trange(1, EPOCHS+1):\n",
    "        # single run\n",
    "        loss, params, state = train_one_step(params, state)\n",
    "        if e % LOG_ITER == 0 or e == 1:\n",
    "            u = apply_fn(params, x, y, z)\n",
    "            error = relative_l2(u, u_gt)\n",
    "            print(f'Epoch: {e}/{EPOCHS} --> loss: {loss:.8f}, error: {error:.8f}')\n",
    "            logger.append([e, loss, error])\n",
    "            print('Solution:')\n",
    "            plotter(xm, ym, zm, u)\n",
    "\n",
    "    # JAX dispatches work asynchronously: wait for the last step before stopping the clock\n",
    "    jax.block_until_ready(params)\n",
    "    end = time.time()\n",
    "    # note: the measured time also covers the logging evaluations and plots above\n",
    "    print(f'Runtime: {((end-start)/EPOCHS*1000):.2f} ms/iter.')\n",
    "\n",
    "    print('Solution:')\n",
    "    u = apply_fn(params, x, y, z)\n",
    "    plotter(xm, ym, zm, u)\n",
    "    return logger"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jaxdf",
   "language": "python",
   "name": "jaxdf"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}