|
@@ -0,0 +1,359 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 1,
|
|
|
+ "id": "9f47ff8c",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "%matplotlib inline\n",
|
|
|
+ "import os\n",
|
|
|
+ "import time\n",
|
|
|
+ "import jax\n",
|
|
|
+ "import jax.numpy as jnp\n",
|
|
|
+ "import optax\n",
|
|
|
+ "import matplotlib.pyplot as plt\n",
|
|
|
+ "from tqdm import trange\n",
|
|
|
+ "from jax import jvp, value_and_grad\n",
|
|
|
+ "from flax import linen as nn\n",
|
|
|
+ "from typing import Sequence\n",
|
|
|
+ "from functools import partial"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 2,
|
|
|
+ "id": "66d0f8e8",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "class CPPINN(nn.Module):\n",
|
|
|
+ " features: Sequence[int]\n",
|
|
|
+ " \n",
|
|
|
+ " #bases = [bases_x,bases_y,bases_z]\n",
|
|
|
+ " @nn.compact\n",
|
|
|
+ " def __call__(self, x, y, z):\n",
|
|
|
+ " inputs, outputs = [x, y, z], []\n",
|
|
|
+ " init = nn.initializers.xavier_normal()\n",
|
|
|
+ " for X in inputs:\n",
|
|
|
+ " for fs in self.features[:-1]:\n",
|
|
|
+ " X = nn.Dense(fs, kernel_init=init)(X)\n",
|
|
|
+ " X = nn.activation.tanh(X)\n",
|
|
|
+ " X = nn.Dense(self.features[-1], kernel_init=init)(X)\n",
|
|
|
+ " \n",
|
|
|
+ " outputs += [jnp.transpose(X, (1, 0))]\n",
|
|
|
+ "\n",
|
|
|
+ " xy = jnp.einsum('fx, fy->fxy', outputs[0], outputs[1])\n",
|
|
|
+ " return jnp.einsum('fxy, fz->xyz', xy, outputs[-1])\n",
|
|
|
+ " \n",
|
|
|
+ "class TTPINN(nn.Module):\n",
|
|
|
+ " features: Sequence[int]\n",
|
|
|
+ " \n",
|
|
|
+ " #bases = [bases_x,bases_y,bases_z]\n",
|
|
|
+ " @nn.compact\n",
|
|
|
+ " def __call__(self, x, y, z):\n",
|
|
|
+ " inputs, outputs = [x, y, z], []\n",
|
|
|
+ " init = nn.initializers.xavier_uniform()\n",
|
|
|
+ " for i,X in enumerate(inputs):\n",
|
|
|
+ " for fs in self.features[:-1]:\n",
|
|
|
+ " X = nn.Dense(fs, kernel_init=init)(X)\n",
|
|
|
+ " X = nn.activation.tanh(X)\n",
|
|
|
+ " if i==0:\n",
|
|
|
+ " X = nn.DenseGeneral((self.features[-1],self.features[-1]), kernel_init=init)(X)\n",
|
|
|
+ " \n",
|
|
|
+ " else:\n",
|
|
|
+ " X = nn.Dense(self.features[-1], kernel_init=init)(X)\n",
|
|
|
+ " outputs += [X]\n",
|
|
|
+ " \n",
|
|
|
+ " #mid = jnp.einsum('ij,kj->ikj', outputs[1][:self.features[-1]], outputs[1][self.features[-1]:])\n",
|
|
|
+ " #print(mid.shape)\n",
|
|
|
+ " #mid = jnp.einsum('fx,ky->fyk',outputs[0],outputs[1])\n",
|
|
|
+ " #xyz = jnp.einsum('fx, fy,fz->xyz', outputs[0], outputs[1],outputs[-1])\n",
|
|
|
+ " return jnp.einsum('xfk,yf,zk->xyz',outputs[0],outputs[1],outputs[-1])\n",
|
|
|
+ "\n",
|
|
|
+ " \n",
|
|
|
+ "class TuckerPINN(nn.Module):\n",
|
|
|
+ " features: Sequence[int]\n",
|
|
|
+ " \n",
|
|
|
+ " def setup(self):\n",
|
|
|
+ " # Initialize learnable parameters\n",
|
|
|
+ " #self.centres = self.param('centres', nn.initializers.uniform(1.01), (self.out_features, 1))\n",
|
|
|
+ " self.core = self.param(\"core\",nn.initializers.orthogonal(),(self.features[-1],self.features[-1],self.features[-1]))\n",
|
|
|
+ "\n",
|
|
|
+ " @nn.compact\n",
|
|
|
+ " def __call__(self, x, y, z):\n",
|
|
|
+ " inputs, outputs = [x, y, z], []\n",
|
|
|
+ " init = nn.initializers.xavier_normal()\n",
|
|
|
+ " for X in inputs:\n",
|
|
|
+ " for fs in self.features[:-1]:\n",
|
|
|
+ " X = nn.Dense(fs, kernel_init=init)(X)\n",
|
|
|
+ " X = nn.activation.tanh(X)\n",
|
|
|
+ " X = nn.Dense(self.features[-1], kernel_init=init)(X)\n",
|
|
|
+ " \n",
|
|
|
+ " outputs += [jnp.transpose(X, (1, 0))]\n",
|
|
|
+ " #mid = jnp.einsum(\"fx,fy->fxy\",outputs[0],outputs[1])\n",
|
|
|
+ " return jnp.einsum(\"klm,kx,ly,mz->xyz\",self.core,outputs[0],outputs[1],outputs[-1])\n",
|
|
|
+ " \n",
|
|
|
+ " \n",
|
|
|
+ " "
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 3,
|
|
|
+ "id": "4f92234e",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def hvp_fwdfwd(f, primals, tangents, return_primals=False):\n",
|
|
|
+ " g = lambda primals: jvp(f, (primals,), tangents)[1]\n",
|
|
|
+ " primals_out, tangents_out = jvp(g, primals, tangents)\n",
|
|
|
+ " if return_primals:\n",
|
|
|
+ " return primals_out, tangents_out\n",
|
|
|
+ " else:\n",
|
|
|
+ " return tangents_out\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 4,
|
|
|
+ "id": "271f17de",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "\n",
|
|
|
+ "def loss_helmholtz(apply_fn, *train_data):\n",
|
|
|
+ " def residual_loss(params, x, y, z, source_term, lda=1.):\n",
|
|
|
+ " # compute u\n",
|
|
|
+ " u = apply_fn(params, x, y, z)\n",
|
|
|
+ " # tangent vector dx/dx\n",
|
|
|
+ " v_x = jnp.ones(x.shape)\n",
|
|
|
+ " v_y = jnp.ones(y.shape)\n",
|
|
|
+ " v_z = jnp.ones(z.shape)\n",
|
|
|
+ " # 2nd derivatives of u\n",
|
|
|
+ " uxx = hvp_fwdfwd(lambda x: apply_fn(params, x, y, z), (x,), (v_x,))\n",
|
|
|
+ " uyy = hvp_fwdfwd(lambda y: apply_fn(params, x, y, z), (y,), (v_y,))\n",
|
|
|
+ " uzz = hvp_fwdfwd(lambda z: apply_fn(params, x, y, z), (z,), (v_z,))\n",
|
|
|
+ " return jnp.mean(((uzz + uyy + uxx + lda*u) - source_term)**2)\n",
|
|
|
+ "\n",
|
|
|
+ " def boundary_loss(params, x, y, z):\n",
|
|
|
+ " loss = 0.\n",
|
|
|
+ " for i in range(6):\n",
|
|
|
+ " loss += jnp.mean(apply_fn(params, x[i], y[i], z[i])**2)\n",
|
|
|
+ " return loss\n",
|
|
|
+ " print(\"Received\",len(train_data))\n",
|
|
|
+ "\n",
|
|
|
+ " # unpack data\n",
|
|
|
+ " xc, yc, zc, uc, xb, yb, zb = train_data\n",
|
|
|
+ "\n",
|
|
|
+ " # isolate loss func from redundant arguments\n",
|
|
|
+ "\n",
|
|
|
+ " loss_fn = lambda params: residual_loss(params, xc, yc, zc, uc) + boundary_loss(params, xb, yb, zb)\n",
|
|
|
+ " #loss, gradient = jax.value_and_grad(loss_fn)(params)\n",
|
|
|
+ "\n",
|
|
|
+ " return loss_fn"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 5,
|
|
|
+ "id": "e3ec2a80",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def helmholtz3d_exact_u(a1, a2, a3, x, y, z):\n",
|
|
|
+ " return jnp.sin(a1*jnp.pi*x) * jnp.sin(a2*jnp.pi*y) * jnp.sin(a3*jnp.pi*z)\n",
|
|
|
+ "\n",
|
|
|
+ "# 3d time-independent helmholtz source term\n",
|
|
|
+ "def helmholtz3d_source_term(a1, a2, a3, x, y, z, lda=1.):\n",
|
|
|
+ " u_gt = helmholtz3d_exact_u(a1, a2, a3, x, y, z)\n",
|
|
|
+ " uxx = -(a1*jnp.pi)**2 * u_gt\n",
|
|
|
+ " uyy = -(a2*jnp.pi)**2 * u_gt\n",
|
|
|
+ " uzz = -(a3*jnp.pi)**2 * u_gt\n",
|
|
|
+ " return uxx + uyy + uzz + lda*u_gt\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 6,
|
|
|
+ "id": "41074e9b",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def relative_l2(u, u_gt):\n",
|
|
|
+ " return jnp.linalg.norm(u-u_gt) / jnp.linalg.norm(u_gt)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 7,
|
|
|
+ "id": "1ce9b588",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def train_generator_helmholtz3d(a1, a2, a3, nc, key):\n",
|
|
|
+ " keys = jax.random.split(key, 3)\n",
|
|
|
+ " # collocation points\n",
|
|
|
+ " xc = jax.random.uniform(keys[0], (nc,), minval=-1., maxval=1.)\n",
|
|
|
+ " yc = jax.random.uniform(keys[1], (nc,), minval=-1., maxval=1.)\n",
|
|
|
+ " zc = jax.random.uniform(keys[2], (nc,), minval=-1., maxval=1.)\n",
|
|
|
+ " # source term\n",
|
|
|
+ " xcm, ycm, zcm = jnp.meshgrid(xc, yc, zc, indexing='ij')\n",
|
|
|
+ " uc = helmholtz3d_source_term(a1, a2, a3, xcm, ycm, zcm)\n",
|
|
|
+ " xc, yc, zc = xc.reshape(-1, 1), yc.reshape(-1, 1), zc.reshape(-1, 1)\n",
|
|
|
+ " # boundary (hard-coded)\n",
|
|
|
+ " xb = [jnp.array([[1.]]), jnp.array([[-1.]]), xc, xc, xc, xc]\n",
|
|
|
+ " yb = [yc, yc, jnp.array([[1.]]), jnp.array([[-1.]]), yc, yc]\n",
|
|
|
+ " zb = [zc, zc, zc, zc, jnp.array([[1.]]), jnp.array([[-1.]])]\n",
|
|
|
+ " return xc,yc,zc,uc,xb,yb,zb\n",
|
|
|
+ "\n",
|
|
|
+ "# optimizer step function\n",
|
|
|
+ "@partial(jax.jit, static_argnums=(0,))\n",
|
|
|
+ "def update_model(optim, gradient, params, state):\n",
|
|
|
+ " updates, state = optim.update(gradient, state)\n",
|
|
|
+ " params = optax.apply_updates(params, updates)\n",
|
|
|
+ " return params, state\n",
|
|
|
+ "\n",
|
|
|
+ "def _test_generator_helmholtz3d(a1, a2, a3, nc_test):\n",
|
|
|
+ " x = jnp.linspace(-1., 1., nc_test)\n",
|
|
|
+ " y = jnp.linspace(-1., 1., nc_test)\n",
|
|
|
+ " z = jnp.linspace(-1., 1., nc_test)\n",
|
|
|
+ " x = jax.lax.stop_gradient(x)\n",
|
|
|
+ " y = jax.lax.stop_gradient(y)\n",
|
|
|
+ " z = jax.lax.stop_gradient(z)\n",
|
|
|
+ " xm, ym, zm = jnp.meshgrid(x, y, z, indexing='ij')\n",
|
|
|
+ " u_gt = helmholtz3d_exact_u(a1, a2, a3, xm, ym, zm)\n",
|
|
|
+ " x = x.reshape(-1, 1)\n",
|
|
|
+ " y = y.reshape(-1, 1)\n",
|
|
|
+ " z = z.reshape(-1, 1)\n",
|
|
|
+ " return x, y, z, xm, ym, zm ,u_gt\n",
|
|
|
+ "\n",
|
|
|
+ "def plotter(xm,ym,zm,u):\n",
|
|
|
+ " #xm, ym, zm = jnp.meshgrid(x,y,z, indexing='ij')\n",
|
|
|
+ " fig = plt.figure(figsize=(6, 6))\n",
|
|
|
+ " ax = fig.add_subplot(111, projection='3d')\n",
|
|
|
+ " im = ax.scatter(xm,ym,zm, c=u, s=0.5, cmap='seismic',vmin=-1,vmax=+1)\n",
|
|
|
+ " #im2 = ax.scatter(0,0,0,c=-1,s=100,cmap='seismic')\n",
|
|
|
+ " ax.set_title('U(x, y, z)', fontsize=20)\n",
|
|
|
+ " ax.set_xlabel('x', fontsize=18, labelpad=10)\n",
|
|
|
+ " ax.set_ylabel('y', fontsize=18, labelpad=10)\n",
|
|
|
+ " ax.set_zlabel('z', fontsize=18, labelpad=10)\n",
|
|
|
+ " fig.colorbar(im,ax=ax)\n",
|
|
|
+ " plt.show()"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 8,
|
|
|
+ "id": "7d0919fa",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def main(mode,NC, NI, NB, NC_TEST, SEED, LR, EPOCHS, N_LAYERS, FEATURES, LOG_ITER):\n",
|
|
|
+ " # force jax to use one device\n",
|
|
|
+ " os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
|
|
|
+ " os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"]=\"false\"\n",
|
|
|
+ "\n",
|
|
|
+ " # random key\n",
|
|
|
+ " key = jax.random.PRNGKey(SEED)\n",
|
|
|
+ " key, subkey = jax.random.split(key, 2)\n",
|
|
|
+ "\n",
|
|
|
+ " # feature sizes\n",
|
|
|
+ " feat_sizes = tuple(FEATURES for _ in range(N_LAYERS))\n",
|
|
|
+ "\n",
|
|
|
+ " #model = RBFPINN(FEATURES,linear)\n",
|
|
|
+ " # make & init model\n",
|
|
|
+ " if mode == \"CPPINN\":\n",
|
|
|
+ " model = CPPINN(feat_sizes)\n",
|
|
|
+ " elif mode == \"TTPINN\":\n",
|
|
|
+ " model = TTPINN(feat_sizes)\n",
|
|
|
+ " elif mode == \"TuckerPINN\":\n",
|
|
|
+ " model = TuckerPINN(feat_sizes)\n",
|
|
|
+ " params = model.init(subkey, jax.random.uniform(key,(NC, 1)), jax.random.uniform(key,(NC, 1)), jax.random.uniform(key,(NC, 1)))\n",
|
|
|
+ " # optimizer\n",
|
|
|
+ " optim = optax.adam(LR)\n",
|
|
|
+ " state = optim.init(params)\n",
|
|
|
+ "\n",
|
|
|
+ " # dataset\n",
|
|
|
+ " key, subkey = jax.random.split(key, 2)\n",
|
|
|
+ " train_data = train_generator_helmholtz3d(4,4,3,NC, subkey)\n",
|
|
|
+ " print(len(train_data))\n",
|
|
|
+ " x, y, z, xm,ym,zm,u_gt = _test_generator_helmholtz3d(4,4,3,NC_TEST)\n",
|
|
|
+ " #print(t,x,y)\n",
|
|
|
+ " logger =[]\n",
|
|
|
+ "\n",
|
|
|
+ " # forward & loss function\n",
|
|
|
+ " apply_fn = jax.jit(model.apply)\n",
|
|
|
+ " #print(len(*train_data))\n",
|
|
|
+ " loss_fn = loss_helmholtz(apply_fn, *train_data)\n",
|
|
|
+ "\n",
|
|
|
+ " @jax.jit\n",
|
|
|
+ " def train_one_step(params, state):\n",
|
|
|
+ " # compute loss and gradient\n",
|
|
|
+ " loss, gradient = value_and_grad(loss_fn)(params)\n",
|
|
|
+ " # update state\n",
|
|
|
+ " params, state = update_model(optim, gradient, params, state)\n",
|
|
|
+ " return loss, params, state\n",
|
|
|
+ " \n",
|
|
|
+ " start = time.time()\n",
|
|
|
+ " for e in trange(1, EPOCHS+1):\n",
|
|
|
+ " # single run\n",
|
|
|
+ " loss, params, state = train_one_step(params, state)\n",
|
|
|
+ " if e % LOG_ITER == 0 or e == 1:\n",
|
|
|
+ " u = apply_fn(params, x,y,z)\n",
|
|
|
+ " error = relative_l2(u, u_gt)\n",
|
|
|
+ " print(f'Epoch: {e}/{EPOCHS} --> loss: {loss:.8f}, error: {error:.8f}')\n",
|
|
|
+ " logger.append([e,loss,error])\n",
|
|
|
+ " print('Solution:')\n",
|
|
|
+ " u = apply_fn(params, x,y,z)\n",
|
|
|
+ " plotter(xm,ym,zm,u)\n",
|
|
|
+ " \n",
|
|
|
+ " end = time.time()\n",
|
|
|
+ " print(f'Runtime: {((end-start)/EPOCHS*1000):.2f} ms/iter.')\n",
|
|
|
+ "\n",
|
|
|
+ " print('Solution:')\n",
|
|
|
+ " u = apply_fn(params, x,y,z)\n",
|
|
|
+ " plotter(xm,ym,zm,u)\n",
|
|
|
+ " return logger"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "298aad40",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": []
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "858faa8e",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": []
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "jaxdf",
|
|
|
+ "language": "python",
|
|
|
+ "name": "jaxdf"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.12.2"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 5
|
|
|
+}
|