Ver código fonte

Helmholtz 3D

vemuri 1 ano atrás
pai
commit
b0fdae072a
2 arquivos alterados com 543 adições e 0 exclusões
  1. 359 0
      Helmholtz/Helmholtz.ipynb
  2. 184 0
      Helmholtz/PINA_Helmholtz.ipynb

+ 359 - 0
Helmholtz/Helmholtz.ipynb

@@ -0,0 +1,359 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "9f47ff8c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%matplotlib inline\n",
+    "import os\n",
+    "import time\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "import optax\n",
+    "import matplotlib.pyplot as plt\n",
+    "from tqdm import trange\n",
+    "from jax import jvp, value_and_grad\n",
+    "from flax import linen as nn\n",
+    "from typing import Sequence\n",
+    "from functools import partial"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "66d0f8e8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class CPPINN(nn.Module):\n",
+    "    features: Sequence[int]\n",
+    " \n",
+    "    #bases = [bases_x,bases_y,bases_z]\n",
+    "    @nn.compact\n",
+    "    def __call__(self, x, y, z):\n",
+    "        inputs, outputs = [x, y, z], []\n",
+    "        init = nn.initializers.xavier_normal()\n",
+    "        for X in inputs:\n",
+    "            for fs in self.features[:-1]:\n",
+    "                X = nn.Dense(fs, kernel_init=init)(X)\n",
+    "                X = nn.activation.tanh(X)\n",
+    "            X = nn.Dense(self.features[-1], kernel_init=init)(X)\n",
+    "            \n",
+    "            outputs += [jnp.transpose(X, (1, 0))]\n",
+    "\n",
+    "        xy = jnp.einsum('fx, fy->fxy', outputs[0], outputs[1])\n",
+    "        return jnp.einsum('fxy, fz->xyz', xy, outputs[-1])\n",
+    "    \n",
+    "class TTPINN(nn.Module):\n",
+    "    features: Sequence[int]\n",
+    " \n",
+    "    #bases = [bases_x,bases_y,bases_z]\n",
+    "    @nn.compact\n",
+    "    def __call__(self, x, y, z):\n",
+    "        inputs, outputs = [x, y, z], []\n",
+    "        init = nn.initializers.xavier_uniform()\n",
+    "        for i,X in enumerate(inputs):\n",
+    "            for fs in self.features[:-1]:\n",
+    "                X = nn.Dense(fs, kernel_init=init)(X)\n",
+    "                X = nn.activation.tanh(X)\n",
+    "            if i==0:\n",
+    "                X = nn.DenseGeneral((self.features[-1],self.features[-1]), kernel_init=init)(X)\n",
+    "                \n",
+    "            else:\n",
+    "                X = nn.Dense(self.features[-1], kernel_init=init)(X)\n",
+    "            outputs += [X]\n",
+    "                \n",
+    "        #mid = jnp.einsum('ij,kj->ikj', outputs[1][:self.features[-1]], outputs[1][self.features[-1]:])\n",
+    "        #print(mid.shape)\n",
+    "        #mid = jnp.einsum('fx,ky->fyk',outputs[0],outputs[1])\n",
+    "        #xyz = jnp.einsum('fx, fy,fz->xyz', outputs[0], outputs[1],outputs[-1])\n",
+    "        return jnp.einsum('xfk,yf,zk->xyz',outputs[0],outputs[1],outputs[-1])\n",
+    "\n",
+    "    \n",
+    "class TuckerPINN(nn.Module):\n",
+    "    features: Sequence[int]\n",
+    " \n",
+    "    def setup(self):\n",
+    "        # Initialize learnable parameters\n",
+    "        #self.centres = self.param('centres', nn.initializers.uniform(1.01), (self.out_features, 1))\n",
+    "        self.core = self.param(\"core\",nn.initializers.orthogonal(),(self.features[-1],self.features[-1],self.features[-1]))\n",
+    "\n",
+    "    @nn.compact\n",
+    "    def __call__(self, x, y, z):\n",
+    "        inputs, outputs = [x, y, z], []\n",
+    "        init = nn.initializers.xavier_normal()\n",
+    "        for X in inputs:\n",
+    "            for fs in self.features[:-1]:\n",
+    "                X = nn.Dense(fs, kernel_init=init)(X)\n",
+    "                X = nn.activation.tanh(X)\n",
+    "            X = nn.Dense(self.features[-1], kernel_init=init)(X)\n",
+    "            \n",
+    "            outputs += [jnp.transpose(X, (1, 0))]\n",
+    "            #mid = jnp.einsum(\"fx,fy->fxy\",outputs[0],outputs[1])\n",
+    "        return jnp.einsum(\"klm,kx,ly,mz->xyz\",self.core,outputs[0],outputs[1],outputs[-1])\n",
+    "    \n",
+    "    \n",
+    "    "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "4f92234e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def hvp_fwdfwd(f, primals, tangents, return_primals=False):\n",
+    "    g = lambda primals: jvp(f, (primals,), tangents)[1]\n",
+    "    primals_out, tangents_out = jvp(g, primals, tangents)\n",
+    "    if return_primals:\n",
+    "        return primals_out, tangents_out\n",
+    "    else:\n",
+    "        return tangents_out\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "271f17de",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "def loss_helmholtz(apply_fn, *train_data):\n",
+    "    def residual_loss(params, x, y, z, source_term, lda=1.):\n",
+    "        # compute u\n",
+    "        u = apply_fn(params, x, y, z)\n",
+    "        # tangent vector dx/dx\n",
+    "        v_x = jnp.ones(x.shape)\n",
+    "        v_y = jnp.ones(y.shape)\n",
+    "        v_z = jnp.ones(z.shape)\n",
+    "        # 2nd derivatives of u\n",
+    "        uxx = hvp_fwdfwd(lambda x: apply_fn(params, x, y, z), (x,), (v_x,))\n",
+    "        uyy = hvp_fwdfwd(lambda y: apply_fn(params, x, y, z), (y,), (v_y,))\n",
+    "        uzz = hvp_fwdfwd(lambda z: apply_fn(params, x, y, z), (z,), (v_z,))\n",
+    "        return jnp.mean(((uzz + uyy + uxx + lda*u) - source_term)**2)\n",
+    "\n",
+    "    def boundary_loss(params, x, y, z):\n",
+    "        loss = 0.\n",
+    "        for i in range(6):\n",
+    "            loss += jnp.mean(apply_fn(params, x[i], y[i], z[i])**2)\n",
+    "        return loss\n",
+    "    print(\"Received\",len(train_data))\n",
+    "\n",
+    "    # unpack data\n",
+    "    xc, yc, zc, uc, xb, yb, zb = train_data\n",
+    "\n",
+    "    # isolate loss func from redundant arguments\n",
+    "\n",
+    "    loss_fn = lambda params: residual_loss(params, xc, yc, zc, uc) + boundary_loss(params, xb, yb, zb)\n",
+    "    #loss, gradient = jax.value_and_grad(loss_fn)(params)\n",
+    "\n",
+    "    return loss_fn"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "e3ec2a80",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def helmholtz3d_exact_u(a1, a2, a3, x, y, z):\n",
+    "    return jnp.sin(a1*jnp.pi*x) * jnp.sin(a2*jnp.pi*y) * jnp.sin(a3*jnp.pi*z)\n",
+    "\n",
+    "# 3d time-independent helmholtz source term\n",
+    "def helmholtz3d_source_term(a1, a2, a3, x, y, z, lda=1.):\n",
+    "    u_gt = helmholtz3d_exact_u(a1, a2, a3, x, y, z)\n",
+    "    uxx = -(a1*jnp.pi)**2 * u_gt\n",
+    "    uyy = -(a2*jnp.pi)**2 * u_gt\n",
+    "    uzz = -(a3*jnp.pi)**2 * u_gt\n",
+    "    return uxx + uyy + uzz + lda*u_gt\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "41074e9b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def relative_l2(u, u_gt):\n",
+    "    return jnp.linalg.norm(u-u_gt) / jnp.linalg.norm(u_gt)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "1ce9b588",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train_generator_helmholtz3d(a1, a2, a3, nc, key):\n",
+    "    keys = jax.random.split(key, 3)\n",
+    "    # collocation points\n",
+    "    xc = jax.random.uniform(keys[0], (nc,), minval=-1., maxval=1.)\n",
+    "    yc = jax.random.uniform(keys[1], (nc,), minval=-1., maxval=1.)\n",
+    "    zc = jax.random.uniform(keys[2], (nc,), minval=-1., maxval=1.)\n",
+    "    # source term\n",
+    "    xcm, ycm, zcm = jnp.meshgrid(xc, yc, zc, indexing='ij')\n",
+    "    uc = helmholtz3d_source_term(a1, a2, a3, xcm, ycm, zcm)\n",
+    "    xc, yc, zc = xc.reshape(-1, 1), yc.reshape(-1, 1), zc.reshape(-1, 1)\n",
+    "    # boundary (hard-coded)\n",
+    "    xb = [jnp.array([[1.]]), jnp.array([[-1.]]), xc, xc, xc, xc]\n",
+    "    yb = [yc, yc, jnp.array([[1.]]), jnp.array([[-1.]]), yc, yc]\n",
+    "    zb = [zc, zc, zc, zc, jnp.array([[1.]]), jnp.array([[-1.]])]\n",
+    "    return xc,yc,zc,uc,xb,yb,zb\n",
+    "\n",
+    "# optimizer step function\n",
+    "@partial(jax.jit, static_argnums=(0,))\n",
+    "def update_model(optim, gradient, params, state):\n",
+    "    updates, state = optim.update(gradient, state)\n",
+    "    params = optax.apply_updates(params, updates)\n",
+    "    return params, state\n",
+    "\n",
+    "def _test_generator_helmholtz3d(a1, a2, a3, nc_test):\n",
+    "    x = jnp.linspace(-1., 1., nc_test)\n",
+    "    y = jnp.linspace(-1., 1., nc_test)\n",
+    "    z = jnp.linspace(-1., 1., nc_test)\n",
+    "    x = jax.lax.stop_gradient(x)\n",
+    "    y = jax.lax.stop_gradient(y)\n",
+    "    z = jax.lax.stop_gradient(z)\n",
+    "    xm, ym, zm = jnp.meshgrid(x, y, z, indexing='ij')\n",
+    "    u_gt = helmholtz3d_exact_u(a1, a2, a3, xm, ym, zm)\n",
+    "    x = x.reshape(-1, 1)\n",
+    "    y = y.reshape(-1, 1)\n",
+    "    z = z.reshape(-1, 1)\n",
+    "    return x, y, z, xm, ym, zm ,u_gt\n",
+    "\n",
+    "def plotter(xm,ym,zm,u):\n",
+    "    #xm, ym, zm = jnp.meshgrid(x,y,z, indexing='ij')\n",
+    "    fig = plt.figure(figsize=(6, 6))\n",
+    "    ax = fig.add_subplot(111, projection='3d')\n",
+    "    im = ax.scatter(xm,ym,zm, c=u, s=0.5, cmap='seismic',vmin=-1,vmax=+1)\n",
+    "    #im2 = ax.scatter(0,0,0,c=-1,s=100,cmap='seismic')\n",
+    "    ax.set_title('U(x, y, z)', fontsize=20)\n",
+    "    ax.set_xlabel('x', fontsize=18, labelpad=10)\n",
+    "    ax.set_ylabel('y', fontsize=18, labelpad=10)\n",
+    "    ax.set_zlabel('z', fontsize=18, labelpad=10)\n",
+    "    fig.colorbar(im,ax=ax)\n",
+    "    plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "7d0919fa",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def main(mode,NC, NI, NB, NC_TEST, SEED, LR, EPOCHS, N_LAYERS, FEATURES, LOG_ITER):\n",
+    "    # force jax to use one device\n",
+    "    os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
+    "    os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"]=\"false\"\n",
+    "\n",
+    "    # random key\n",
+    "    key = jax.random.PRNGKey(SEED)\n",
+    "    key, subkey = jax.random.split(key, 2)\n",
+    "\n",
+    "    # feature sizes\n",
+    "    feat_sizes = tuple(FEATURES for _ in range(N_LAYERS))\n",
+    "\n",
+    "    #model = RBFPINN(FEATURES,linear)\n",
+    "    # make & init model\n",
+    "    if mode == \"CPPINN\":\n",
+    "        model = CPPINN(feat_sizes)\n",
+    "    elif mode == \"TTPINN\":\n",
+    "        model = TTPINN(feat_sizes)\n",
+    "    elif mode == \"TuckerPINN\":\n",
+    "        model = TuckerPINN(feat_sizes)\n",
+    "    params = model.init(subkey, jax.random.uniform(key,(NC, 1)), jax.random.uniform(key,(NC, 1)), jax.random.uniform(key,(NC, 1)))\n",
+    "    # optimizer\n",
+    "    optim = optax.adam(LR)\n",
+    "    state = optim.init(params)\n",
+    "\n",
+    "    # dataset\n",
+    "    key, subkey = jax.random.split(key, 2)\n",
+    "    train_data = train_generator_helmholtz3d(4,4,3,NC, subkey)\n",
+    "    print(len(train_data))\n",
+    "    x, y, z, xm,ym,zm,u_gt = _test_generator_helmholtz3d(4,4,3,NC_TEST)\n",
+    "    #print(t,x,y)\n",
+    "    logger =[]\n",
+    "\n",
+    "    # forward & loss function\n",
+    "    apply_fn = jax.jit(model.apply)\n",
+    "    #print(len(*train_data))\n",
+    "    loss_fn = loss_helmholtz(apply_fn, *train_data)\n",
+    "\n",
+    "    @jax.jit\n",
+    "    def train_one_step(params, state):\n",
+    "        # compute loss and gradient\n",
+    "        loss, gradient = value_and_grad(loss_fn)(params)\n",
+    "        # update state\n",
+    "        params, state = update_model(optim, gradient, params, state)\n",
+    "        return loss, params, state\n",
+    "    \n",
+    "    start = time.time()\n",
+    "    for e in trange(1, EPOCHS+1):\n",
+    "        # single run\n",
+    "        loss, params, state = train_one_step(params, state)\n",
+    "        if e % LOG_ITER == 0 or e == 1:\n",
+    "            u = apply_fn(params, x,y,z)\n",
+    "            error = relative_l2(u, u_gt)\n",
+    "            print(f'Epoch: {e}/{EPOCHS} --> loss: {loss:.8f}, error: {error:.8f}')\n",
+    "            logger.append([e,loss,error])\n",
+    "            print('Solution:')\n",
+    "            u = apply_fn(params, x,y,z)\n",
+    "            plotter(xm,ym,zm,u)\n",
+    "        \n",
+    "    end = time.time()\n",
+    "    print(f'Runtime: {((end-start)/EPOCHS*1000):.2f} ms/iter.')\n",
+    "\n",
+    "    print('Solution:')\n",
+    "    u = apply_fn(params, x,y,z)\n",
+    "    plotter(xm,ym,zm,u)\n",
+    "    return logger"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "298aad40",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "858faa8e",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "jaxdf",
+   "language": "python",
+   "name": "jaxdf"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

Diferenças do arquivo suprimidas por serem muito extensas
+ 184 - 0
Helmholtz/PINA_Helmholtz.ipynb


Alguns arquivos não foram mostrados porque muitos arquivos mudaram nesse diff