In [1]:
%matplotlib inline
import os
import time
import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt
from tqdm import trange
from jax import jvp, value_and_grad
from flax import linen as nn
from typing import Sequence
from functools import partial

In [2]:
class CPPINN(nn.Module):
    """Separable PINN whose output is a CP-format (sum of rank-one terms) 3-D tensor.

    Each of the three coordinates is passed through its own MLP
    (``features[:-1]`` hidden widths with tanh, then a linear layer of width
    ``features[-1]``).  The resulting factor matrices are combined as
    ``u[x, y, z] = sum_f X[f, x] * Y[f, y] * Z[f, z]``.

    Attributes:
        features: hidden layer widths followed by the rank (last entry).
    """
    features: Sequence[int]

    @nn.compact
    def __call__(self, x, y, z):
        """Evaluate the model on the tensor-product grid of ``x``, ``y``, ``z``.

        Args:
            x, y, z: 2-D per-axis coordinate arrays; the callers in this
                notebook pass column vectors of shape ``(N, 1)``.

        Returns:
            Array of shape ``(len(x), len(y), len(z))``.
        """
        kernel_init = nn.initializers.xavier_normal()
        hidden_widths, rank = self.features[:-1], self.features[-1]

        factors = []
        for coord in (x, y, z):
            h = coord
            for width in hidden_widths:
                h = nn.activation.tanh(nn.Dense(width, kernel_init=kernel_init)(h))
            h = nn.Dense(rank, kernel_init=kernel_init)(h)
            # (N, rank) -> (rank, N) so the rank index leads in the einsums below
            factors.append(jnp.transpose(h, (1, 0)))

        x_factor, y_factor, z_factor = factors
        outer_xy = jnp.einsum('fx, fy->fxy', x_factor, y_factor)
        return jnp.einsum('fxy, fz->xyz', outer_xy, z_factor)
    
class TTPINN(nn.Module):
    """Separable PINN whose per-axis networks are chained through two rank indices.

    Each coordinate goes through its own MLP (``features[:-1]`` hidden widths,
    tanh).  The x-network ends in a ``DenseGeneral`` head producing a
    ``(rank, rank)`` block per point; the y- and z-networks end in a ``Dense``
    head of width ``rank``.  The output is
    ``u[x, y, z] = sum_{f,k} A[x, f, k] * B[y, f] * C[z, k]``.

    Attributes:
        features: hidden layer widths followed by the rank (last entry).
    """
    features: Sequence[int]

    @nn.compact
    def __call__(self, x, y, z):
        """Evaluate the model on the tensor-product grid of ``x``, ``y``, ``z``.

        Args:
            x, y, z: per-axis coordinate arrays; the callers in this notebook
                pass column vectors of shape ``(N, 1)``.

        Returns:
            Array of shape ``(len(x), len(y), len(z))``.
        """
        kernel_init = nn.initializers.xavier_uniform()
        hidden_widths, rank = self.features[:-1], self.features[-1]

        factors = []
        for axis, coord in enumerate((x, y, z)):
            h = coord
            for width in hidden_widths:
                h = nn.activation.tanh(nn.Dense(width, kernel_init=kernel_init)(h))
            # only the x-axis head carries both contraction indices
            if axis == 0:
                head = nn.DenseGeneral((rank, rank), kernel_init=kernel_init)
            else:
                head = nn.Dense(rank, kernel_init=kernel_init)
            factors.append(head(h))

        x_block, y_factor, z_factor = factors
        return jnp.einsum('xfk,yf,zk->xyz', x_block, y_factor, z_factor)

    
class TuckerPINN(nn.Module):
    """Separable PINN in Tucker form: a learnable core contracted with three factor matrices.

    Each coordinate goes through its own MLP (``features[:-1]`` hidden widths,
    tanh, then a linear layer of width ``rank = features[-1]``).  The output is
    ``u[x, y, z] = sum_{k,l,m} core[k, l, m] * X[k, x] * Y[l, y] * Z[m, z]``.

    Attributes:
        features: hidden layer widths followed by the rank (last entry).
    """
    features: Sequence[int]

    def setup(self):
        # Learnable (rank, rank, rank) core tensor with orthogonal initialisation.
        rank = self.features[-1]
        self.core = self.param("core", nn.initializers.orthogonal(), (rank, rank, rank))

    # NOTE(review): `core` is defined in `setup` while `__call__` is
    # `@nn.compact`; confirm the pinned Flax version supports mixing the two styles.
    @nn.compact
    def __call__(self, x, y, z):
        """Evaluate the model on the tensor-product grid of ``x``, ``y``, ``z``.

        Args:
            x, y, z: 2-D per-axis coordinate arrays; the callers in this
                notebook pass column vectors of shape ``(N, 1)``.

        Returns:
            Array of shape ``(len(x), len(y), len(z))``.
        """
        kernel_init = nn.initializers.xavier_normal()
        hidden_widths, rank = self.features[:-1], self.features[-1]

        factors = []
        for coord in (x, y, z):
            h = coord
            for width in hidden_widths:
                h = nn.activation.tanh(nn.Dense(width, kernel_init=kernel_init)(h))
            h = nn.Dense(rank, kernel_init=kernel_init)(h)
            # (N, rank) -> (rank, N) to match the core's index layout
            factors.append(jnp.transpose(h, (1, 0)))

        x_factor, y_factor, z_factor = factors
        return jnp.einsum("klm,kx,ly,mz->xyz", self.core, x_factor, y_factor, z_factor)
    
    
    

In [3]:
def hvp_fwdfwd(f, primals, tangents, return_primals=False):
    """Second directional derivative of ``f`` via forward-over-forward JVP.

    The inner JVP gives the first directional derivative ``p -> Df(p)[v]``;
    differentiating that again along ``v`` gives ``D^2 f(p)[v, v]``.

    Args:
        f: single-argument function.
        primals: 1-tuple ``(p,)`` holding the evaluation point.
        tangents: 1-tuple ``(v,)`` holding the direction (same structure and
            dtype as ``p``, as required by ``jax.jvp``).
        return_primals: if True, also return the first directional derivative.

    Returns:
        The second directional derivative, or the pair
        ``(first_directional_derivative, second_directional_derivative)``
        when ``return_primals`` is True.
    """
    def first_directional(point):
        _, d1 = jvp(f, (point,), tangents)
        return d1

    d1_out, d2_out = jvp(first_directional, primals, tangents)
    if return_primals:
        return d1_out, d2_out
    return d2_out


In [4]:

def loss_helmholtz(apply_fn, *train_data):
    """Build the scalar PINN loss for the 3-D Helmholtz problem ``Δu + lda·u = f``.

    Args:
        apply_fn: ``apply_fn(params, x, y, z)`` returning ``u`` on the
            tensor-product grid of ``x``, ``y``, ``z``.
        *train_data: the 7 items ``(xc, yc, zc, uc, xb, yb, zb)`` as produced by
            ``train_generator_helmholtz3d``: collocation coordinates, the source
            term on their grid, and parallel per-face boundary coordinate lists.

    Returns:
        ``loss_fn(params)``: mean squared PDE residual on the collocation grid
        plus, summed over boundary faces, the mean of ``u**2`` (a zero-value
        boundary condition).

    Raises:
        ValueError: if ``train_data`` does not contain exactly 7 items.
    """
    def residual_loss(params, x, y, z, source_term, lda=1.):
        u = apply_fn(params, x, y, z)
        # Second derivatives via forward-over-forward JVP with an all-ones
        # tangent. The models in this notebook process each coordinate row
        # independently, so this yields the pointwise second derivative.
        # `ones_like` keeps the tangent dtype equal to the primal dtype, which
        # jax.jvp requires (jnp.ones would be float64 under x64 mode).
        uxx = hvp_fwdfwd(lambda x: apply_fn(params, x, y, z), (x,), (jnp.ones_like(x),))
        uyy = hvp_fwdfwd(lambda y: apply_fn(params, x, y, z), (y,), (jnp.ones_like(y),))
        uzz = hvp_fwdfwd(lambda z: apply_fn(params, x, y, z), (z,), (jnp.ones_like(z),))
        return jnp.mean(((uzz + uyy + uxx + lda*u) - source_term)**2)

    def boundary_loss(params, x, y, z):
        # x, y, z are parallel lists with one entry per boundary face
        loss = 0.
        for x_face, y_face, z_face in zip(x, y, z):
            loss += jnp.mean(apply_fn(params, x_face, y_face, z_face)**2)
        return loss

    if len(train_data) != 7:
        raise ValueError(
            "loss_helmholtz expects 7 training arrays (xc, yc, zc, uc, xb, yb, zb), "
            f"got {len(train_data)}"
        )
    xc, yc, zc, uc, xb, yb, zb = train_data

    # close over the data so the result depends only on params
    def loss_fn(params):
        return residual_loss(params, xc, yc, zc, uc) + boundary_loss(params, xb, yb, zb)

    return loss_fn

In [5]:
def helmholtz3d_exact_u(a1, a2, a3, x, y, z):
    """Exact solution ``sin(a1·π·x) · sin(a2·π·y) · sin(a3·π·z)`` (elementwise, broadcasting)."""
    sin_x = jnp.sin(a1 * jnp.pi * x)
    sin_y = jnp.sin(a2 * jnp.pi * y)
    sin_z = jnp.sin(a3 * jnp.pi * z)
    return sin_x * sin_y * sin_z

# 3d time-independent helmholtz source term
def helmholtz3d_source_term(a1, a2, a3, x, y, z, lda=1.):
    """Source term ``f = u_xx + u_yy + u_zz + lda·u`` for ``helmholtz3d_exact_u``.

    Because the exact solution is a product of sines, each second derivative
    is ``-(a·π)² · u`` along its own axis.
    """
    u_gt = helmholtz3d_exact_u(a1, a2, a3, x, y, z)
    uxx, uyy, uzz = (-(a * jnp.pi)**2 * u_gt for a in (a1, a2, a3))
    return uxx + uyy + uzz + lda*u_gt


In [6]:
def relative_l2(u, u_gt):
    """Relative L2 error ``||u - u_gt|| / ||u_gt||`` over all entries."""
    residual_norm = jnp.linalg.norm(u - u_gt)
    reference_norm = jnp.linalg.norm(u_gt)
    return residual_norm / reference_norm

In [7]:
def train_generator_helmholtz3d(a1, a2, a3, nc, key):
    """Sample collocation points and build boundary data for the 3-D Helmholtz problem.

    Args:
        a1, a2, a3: frequency multipliers of the exact solution
            ``sin(a1·π·x)·sin(a2·π·y)·sin(a3·π·z)``.
        nc: number of collocation points per axis.
        key: JAX PRNG key.

    Returns:
        ``(xc, yc, zc, uc, xb, yb, zb)``:
        - ``xc``, ``yc``, ``zc``: ``(nc, 1)`` columns of uniform samples in ``[-1, 1]``.
        - ``uc``: the source term on the ``(nc, nc, nc)`` ``'ij'`` grid.
        - ``xb``, ``yb``, ``zb``: six-element lists for the faces x=1, x=-1,
          y=1, y=-1, z=1, z=-1 (in that order).
    """
    key_x, key_y, key_z = jax.random.split(key, 3)
    xc, yc, zc = (
        jax.random.uniform(k, (nc,), minval=-1., maxval=1.)
        for k in (key_x, key_y, key_z)
    )
    # the source term is evaluated on the full tensor-product grid
    grid_x, grid_y, grid_z = jnp.meshgrid(xc, yc, zc, indexing='ij')
    uc = helmholtz3d_source_term(a1, a2, a3, grid_x, grid_y, grid_z)
    xc, yc, zc = (c.reshape(-1, 1) for c in (xc, yc, zc))

    # each face pins one coordinate to ±1 and reuses the collocation samples
    # along the other two axes
    plus, minus = jnp.array([[1.]]), jnp.array([[-1.]])
    xb = [plus, minus, xc, xc, xc, xc]
    yb = [yc, yc, plus, minus, yc, yc]
    zb = [zc, zc, zc, zc, plus, minus]
    return xc, yc, zc, uc, xb, yb, zb

# optimizer step function
@partial(jax.jit, static_argnums=(0,))
def update_model(optim, gradient, params, state):
    """Apply one optimizer step and return ``(new_params, new_state)``.

    Args:
        optim: an optax ``GradientTransformation``. It is a static argument to
            ``jax.jit``, so it must be hashable.
        gradient: gradient pytree with the same structure as ``params``.
        params: current parameters.
        state: optimizer state from ``optim.init`` or a previous step.
    """
    # Passing ``params`` lets transformations that need them (e.g. decoupled
    # weight decay) work; transformations such as Adam simply ignore it.
    updates, state = optim.update(gradient, state, params)
    params = optax.apply_updates(params, updates)
    return params, state

def _test_generator_helmholtz3d(a1, a2, a3, nc_test):
    """Build an evenly spaced evaluation grid on ``[-1, 1]^3`` and the exact solution on it.

    Returns:
        ``(x, y, z, xm, ym, zm, u_gt)``:
        - ``x``, ``y``, ``z``: ``(nc_test, 1)`` columns of ``linspace(-1, 1, nc_test)``.
        - ``xm``, ``ym``, ``zm``: the corresponding ``'ij'`` meshgrids.
        - ``u_gt``: ``helmholtz3d_exact_u`` evaluated on those meshgrids.
    """
    axes = [jax.lax.stop_gradient(jnp.linspace(-1., 1., nc_test)) for _ in range(3)]
    xm, ym, zm = jnp.meshgrid(*axes, indexing='ij')
    u_gt = helmholtz3d_exact_u(a1, a2, a3, xm, ym, zm)
    x, y, z = (axis.reshape(-1, 1) for axis in axes)
    return x, y, z, xm, ym, zm, u_gt

def plotter(xm,ym,zm,u):
    """Show a 3-D scatter plot of ``u`` at the grid points ``(xm, ym, zm)``.

    Point colours use the 'seismic' colormap with limits fixed to [-1, 1], and a
    colorbar is attached. The figure is displayed with ``plt.show()``.
    """
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection='3d')
    points = ax.scatter(xm, ym, zm, c=u, s=0.5, cmap='seismic', vmin=-1, vmax=+1)
    ax.set_title('U(x, y, z)', fontsize=20)
    for set_label, name in ((ax.set_xlabel, 'x'), (ax.set_ylabel, 'y'), (ax.set_zlabel, 'z')):
        set_label(name, fontsize=18, labelpad=10)
    fig.colorbar(points, ax=ax)
    plt.show()

In [8]:
def main(mode, NC, NI, NB, NC_TEST, SEED, LR, EPOCHS, N_LAYERS, FEATURES, LOG_ITER):
    """Train a separable PINN on the 3-D Helmholtz problem and plot the result.

    Args:
        mode: one of ``"CPPINN"``, ``"TTPINN"``, ``"TuckerPINN"``.
        NC: number of collocation points per axis.
        NI, NB: not used by this function (kept for call compatibility).
        NC_TEST: number of evaluation grid points per axis.
        SEED: PRNG seed.
        LR: Adam learning rate.
        EPOCHS: number of optimisation steps.
        N_LAYERS: number of Dense layers per axis network (the last one's
            width is the decomposition rank).
        FEATURES: width of every layer, including the rank.
        LOG_ITER: log and plot every ``LOG_ITER`` steps (and at step 1).

    Returns:
        List of ``[epoch, loss, relative_l2_error]`` for every logged step.

    Raises:
        ValueError: if ``mode`` is not one of the known model names.
    """
    # NOTE(review): these are presumably only read when JAX first initialises
    # its GPU backend, so they have no effect if JAX already ran a computation
    # in this kernel (e.g. on a second call to main) — verify.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

    key = jax.random.PRNGKey(SEED)
    key, subkey = jax.random.split(key, 2)

    # every layer has width FEATURES; the last entry doubles as the rank
    feat_sizes = tuple(FEATURES for _ in range(N_LAYERS))

    model_classes = {"CPPINN": CPPINN, "TTPINN": TTPINN, "TuckerPINN": TuckerPINN}
    if mode not in model_classes:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {sorted(model_classes)}")
    model = model_classes[mode](feat_sizes)
    # Only input shapes matter for parameter initialisation; using constant
    # dummies avoids consuming `key`, which is split again below.
    dummy = jnp.ones((NC, 1))
    params = model.init(subkey, dummy, dummy, dummy)

    optim = optax.adam(LR)
    state = optim.init(params)

    # dataset
    key, subkey = jax.random.split(key, 2)
    train_data = train_generator_helmholtz3d(4, 4, 3, NC, subkey)
    x, y, z, xm, ym, zm, u_gt = _test_generator_helmholtz3d(4, 4, 3, NC_TEST)
    logger = []

    # forward & loss function
    apply_fn = jax.jit(model.apply)
    loss_fn = loss_helmholtz(apply_fn, *train_data)

    @jax.jit
    def train_one_step(params, state):
        loss, gradient = value_and_grad(loss_fn)(params)
        params, state = update_model(optim, gradient, params, state)
        return loss, params, state

    log_time = 0.
    start = time.time()
    for e in trange(1, EPOCHS + 1):
        loss, params, state = train_one_step(params, state)
        if e % LOG_ITER == 0 or e == 1:
            # Finish the pending step before starting the logging clock so that
            # evaluation and plotting can be excluded from the ms/iter figure.
            jax.block_until_ready(params)
            log_start = time.time()
            u = apply_fn(params, x, y, z)
            error = relative_l2(u, u_gt)
            print(f'Epoch: {e}/{EPOCHS} --> loss: {loss:.8f}, error: {error:.8f}')
            logger.append([e, loss, error])
            print('Solution:')
            plotter(xm, ym, zm, u)
            log_time += time.time() - log_start
    # JAX dispatches asynchronously: wait for the last step before stopping the clock.
    jax.block_until_ready(params)
    end = time.time()
    if EPOCHS > 0:
        # NOTE: the first step still includes JIT compilation time.
        print(f'Runtime: {((end-start-log_time)/EPOCHS*1000):.2f} ms/iter.')

    print('Solution:')
    u = apply_fn(params, x, y, z)
    plotter(xm, ym, zm, u)
    return logger