{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "59a27db9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import os\n",
    "\n",
    "# GPU selection / memory flags are read when JAX initialises its backend,\n",
    "# so they must be in the environment before any JAX computation runs.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'\n",
    "\n",
    "import time\n",
    "from functools import partial\n",
    "from typing import Sequence\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import optax\n",
    "from flax import linen as nn\n",
    "from jax import jvp, value_and_grad\n",
    "from tqdm import trange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fa985589",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CPPINN(nn.Module):\n",
    "    \"\"\"Separable PINN with a CP (canonical polyadic) output.\n",
    "\n",
    "    Each of x, y, z goes through its own MLP ending in r = features[-1] units;\n",
    "    the grid output is u[i, j, k] = sum_f X[f, i] * Y[f, j] * Z[f, k].\n",
    "    \"\"\"\n",
    "    features: Sequence[int]\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x, y, z):\n",
    "        init = nn.initializers.xavier_normal()\n",
    "        outputs = []\n",
    "        for X in (x, y, z):\n",
    "            for fs in self.features[:-1]:\n",
    "                X = nn.Dense(fs, kernel_init=init)(X)\n",
    "                X = nn.activation.tanh(X)\n",
    "            X = nn.Dense(self.features[-1], kernel_init=init)(X)\n",
    "            outputs.append(jnp.transpose(X, (1, 0)))  # (r, n)\n",
    "        xy = jnp.einsum('fx,fy->fxy', outputs[0], outputs[1])\n",
    "        return jnp.einsum('fxy,fz->xyz', xy, outputs[-1])\n",
    "\n",
    "\n",
    "class TTPINN(nn.Module):\n",
    "    \"\"\"Separable PINN with a tensor-train style output.\n",
    "\n",
    "    The x branch ends in a DenseGeneral with (r, r) features, giving an\n",
    "    (n_x, r, r) core; the y and z branches end in r units. The grid output is\n",
    "    u[i, j, k] = sum_{f, m} X[i, f, m] * Y[j, f] * Z[k, m].\n",
    "    \"\"\"\n",
    "    features: Sequence[int]\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x, y, z):\n",
    "        init = nn.initializers.xavier_uniform()\n",
    "        r = self.features[-1]\n",
    "        outputs = []\n",
    "        for i, X in enumerate((x, y, z)):\n",
    "            for fs in self.features[:-1]:\n",
    "                X = nn.Dense(fs, kernel_init=init)(X)\n",
    "                X = nn.activation.tanh(X)\n",
    "            if i == 0:\n",
    "                X = nn.DenseGeneral((r, r), kernel_init=init)(X)\n",
    "            else:\n",
    "                X = nn.Dense(r, kernel_init=init)(X)\n",
    "            outputs.append(X)\n",
    "        return jnp.einsum('xfk,yf,zk->xyz', outputs[0], outputs[1], outputs[-1])\n",
    "\n",
    "\n",
    "class TuckerPINN(nn.Module):\n",
    "    \"\"\"Separable PINN with a Tucker output.\n",
    "\n",
    "    A learnable (r, r, r) core, r = features[-1], is contracted with the\n",
    "    (r, n) factor matrices produced by the x, y and z branches:\n",
    "    u[i, j, k] = sum_{l, m, n} C[l, m, n] * X[l, i] * Y[m, j] * Z[n, k].\n",
    "    \"\"\"\n",
    "    features: Sequence[int]\n",
    "\n",
    "    def setup(self):\n",
    "        r = self.features[-1]\n",
    "        self.core = self.param('core', nn.initializers.orthogonal(), (r, r, r))\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x, y, z):\n",
    "        init = nn.initializers.xavier_normal()\n",
    "        outputs = []\n",
    "        for X in (x, y, z):\n",
    "            for fs in self.features[:-1]:\n",
    "                X = nn.Dense(fs, kernel_init=init)(X)\n",
    "                X = nn.activation.tanh(X)\n",
    "            X = nn.Dense(self.features[-1], kernel_init=init)(X)\n",
    "            outputs.append(jnp.transpose(X, (1, 0)))  # (r, n)\n",
    "        return jnp.einsum('klm,kx,ly,mz->xyz', self.core, outputs[0], outputs[1], outputs[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a5bdd11a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hvp_fwdfwd(f, primals, tangents, return_primals=False):\n",
    "    \"\"\"Forward-over-forward directional second derivative of f.\n",
    "\n",
    "    g(p) is the directional derivative of f along `tangents`; differentiating g\n",
    "    along the same tangents gives the second directional derivative. With\n",
    "    return_primals=True, also returns g(primals) (the first derivative).\n",
    "    Not used by the first-order flow-mixing loss below.\n",
    "    \"\"\"\n",
    "    g = lambda primals: jvp(f, (primals,), tangents)[1]\n",
    "    primals_out, tangents_out = jvp(g, primals, tangents)\n",
    "    if return_primals:\n",
    "        return primals_out, tangents_out\n",
    "    else:\n",
    "        return tangents_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bbf6a3d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def relative_l2(u, u_gt):\n",
    "    \"\"\"Relative L2 error ||u - u_gt|| / ||u_gt|| over all entries.\"\"\"\n",
    "    return jnp.linalg.norm(u - u_gt) / jnp.linalg.norm(u_gt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "983f566f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_flowmixer(apply_fn, *train_data, residual_weight=10.):\n",
    "    \"\"\"Build the PINN loss for the 2-D flow-mixing advection problem.\n",
    "\n",
    "    Loss = residual_weight * mean((u_t + a u_x + b u_y)^2) on the collocation grid\n",
    "         + MSE against the exact solution at t = 0\n",
    "         + sum of MSEs on each boundary segment.\n",
    "\n",
    "    Args:\n",
    "        apply_fn: callable apply_fn(params, t, x, y) returning u on the (t, x, y) grid.\n",
    "        *train_data: the 13-tuple returned by train_generator_flow_mixing3d.\n",
    "        residual_weight: weight of the PDE residual term (10 in the original runs).\n",
    "\n",
    "    Returns:\n",
    "        loss_fn(params) -> scalar loss.\n",
    "    \"\"\"\n",
    "    def residual_loss(params, t, x, y, a, b):\n",
    "        # For the separable models above, u[i, j, k] depends on t only through t[i]\n",
    "        # (likewise for x and y), so a JVP with an all-ones tangent gives the\n",
    "        # pointwise partial derivative on the whole grid.\n",
    "        ut = jvp(lambda t: apply_fn(params, t, x, y), (t,), (jnp.ones(t.shape),))[1]\n",
    "        ux = jvp(lambda x: apply_fn(params, t, x, y), (x,), (jnp.ones(x.shape),))[1]\n",
    "        uy = jvp(lambda y: apply_fn(params, t, x, y), (y,), (jnp.ones(y.shape),))[1]\n",
    "        return jnp.mean((ut + a*ux + b*uy)**2)\n",
    "\n",
    "    def initial_loss(params, t, x, y, u):\n",
    "        return jnp.mean((apply_fn(params, t, x, y) - u)**2)\n",
    "\n",
    "    def boundary_loss(params, t, x, y, u):\n",
    "        # one MSE term per boundary segment in the parallel lists t, x, y, u\n",
    "        return sum(jnp.mean((apply_fn(params, t_i, x_i, y_i) - u_i)**2)\n",
    "                   for t_i, x_i, y_i, u_i in zip(t, x, y, u))\n",
    "\n",
    "    tc, xc, yc, ti, xi, yi, ui, tb, xb, yb, ub, a, b = train_data\n",
    "\n",
    "    def loss_fn(params):\n",
    "        return (residual_weight * residual_loss(params, tc, xc, yc, a, b)\n",
    "                + initial_loss(params, ti, xi, yi, ui)\n",
    "                + boundary_loss(params, tb, xb, yb, ub))\n",
    "\n",
    "    return loss_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "41373518",
   "metadata": {},
   "outputs": [],
   "source": [
    "@partial(jax.jit, static_argnums=(0,))\n",
    "def update_model(optim, gradient, params, state):\n",
    "    \"\"\"Apply one optax update; `optim` is a static (compile-time) argument.\"\"\"\n",
    "    updates, state = optim.update(gradient, state)\n",
    "    params = optax.apply_updates(params, updates)\n",
    "    return params, state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "861f631d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def flow_mixing3d_exact_u(t, x, y, omega):\n",
    "    \"\"\"Exact solution u = -tanh((y/2) cos(omega t) - (x/2) sin(omega t)).\"\"\"\n",
    "    return -jnp.tanh((y/2)*jnp.cos(omega*t) - (x/2)*jnp.sin(omega*t))\n",
    "\n",
    "\n",
    "def flow_mixing3d_params(t, x, y, v_max=0.385, require_ab=False):\n",
    "    \"\"\"Angular velocity omega and, optionally, advection velocities (a, b).\n",
    "\n",
    "    t, x, y must be meshgrids of the same shape (t itself is not used).\n",
    "    At the origin r = 0 the closed forms are 0/0; the continuous limits\n",
    "    omega -> 1/v_max and a = b = 0 are returned there instead of NaN.\n",
    "    Returns (omega, a, b); a and b are None unless require_ab is True.\n",
    "    \"\"\"\n",
    "    r = jnp.sqrt(x**2 + y**2)\n",
    "    v_t = ((1/jnp.cosh(r))**2) * jnp.tanh(r)\n",
    "    nonzero = r > 0\n",
    "    # substitute 1 for r at the origin so no 0/0 is ever evaluated\n",
    "    safe_r = jnp.where(nonzero, r, 1.)\n",
    "    omega = jnp.where(nonzero, (1/safe_r)*(v_t/v_max), 1/v_max)\n",
    "    a, b = None, None\n",
    "    if require_ab:\n",
    "        # v_t = 0 at the origin, so these are 0 there\n",
    "        a = -(v_t/v_max)*(y/safe_r)\n",
    "        b = (v_t/v_max)*(x/safe_r)\n",
    "    return omega, a, b\n",
    "\n",
    "\n",
    "def train_generator_flow_mixing3d(nc, key, v_max=0.385):\n",
    "    \"\"\"Sample separable training data for the flow-mixing problem.\n",
    "\n",
    "    nc random coordinates are drawn per axis (t in [0, 4], x and y in [-4, 4])\n",
    "    and used as a tensor grid.\n",
    "\n",
    "    Returns (tc, xc, yc, ti, xi, yi, ui, tb, xb, yb, ub, a, b):\n",
    "    collocation axes; initial-condition axes (t = 0) and exact values;\n",
    "    lists of axes/values for the boundaries x = -4, x = 4, y = -4, y = 4;\n",
    "    and the advection velocities a, b on the collocation grid.\n",
    "    \"\"\"\n",
    "    keys = jax.random.split(key, 3)\n",
    "    # collocation points\n",
    "    tc = jax.random.uniform(keys[0], (nc, 1), minval=0., maxval=4.)\n",
    "    xc = jax.random.uniform(keys[1], (nc, 1), minval=-4., maxval=4.)\n",
    "    yc = jax.random.uniform(keys[2], (nc, 1), minval=-4., maxval=4.)\n",
    "    tc_mesh, xc_mesh, yc_mesh = jnp.meshgrid(tc.ravel(), xc.ravel(), yc.ravel(), indexing='ij')\n",
    "    _, a, b = flow_mixing3d_params(tc_mesh, xc_mesh, yc_mesh, v_max=v_max, require_ab=True)\n",
    "\n",
    "    # initial condition at t = 0 on the collocation x/y coordinates\n",
    "    ti = jnp.zeros((1, 1))\n",
    "    xi, yi = xc, yc\n",
    "    ti_mesh, xi_mesh, yi_mesh = jnp.meshgrid(ti.ravel(), xi.ravel(), yi.ravel(), indexing='ij')\n",
    "    omega_i, _, _ = flow_mixing3d_params(ti_mesh, xi_mesh, yi_mesh, v_max=v_max)\n",
    "    ui = flow_mixing3d_exact_u(ti_mesh, xi_mesh, yi_mesh, omega_i)\n",
    "\n",
    "    # boundary segments x = -4, x = 4, y = -4, y = 4 (hard-coded)\n",
    "    tb = [tc, tc, tc, tc]\n",
    "    xb = [jnp.array([[-4.]]), jnp.array([[4.]]), xc, xc]\n",
    "    yb = [yc, yc, jnp.array([[-4.]]), jnp.array([[4.]])]\n",
    "    ub = []\n",
    "    for tb_i, xb_i, yb_i in zip(tb, xb, yb):\n",
    "        tb_mesh, xb_mesh, yb_mesh = jnp.meshgrid(tb_i.ravel(), xb_i.ravel(), yb_i.ravel(), indexing='ij')\n",
    "        omega_b, _, _ = flow_mixing3d_params(tb_mesh, xb_mesh, yb_mesh, v_max=v_max)\n",
    "        ub.append(flow_mixing3d_exact_u(tb_mesh, xb_mesh, yb_mesh, omega_b))\n",
    "    return tc, xc, yc, ti, xi, yi, ui, tb, xb, yb, ub, a, b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "24d0dbc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_generator_flow_mixing3d(nc_test):\n",
    "    \"\"\"Uniform evaluation grid and exact solution.\n",
    "\n",
    "    Returns (t, x, y, tm, xm, ym, u_gt): the axes reshaped to (nc_test, 1),\n",
    "    their 'ij' meshgrids, and u_gt of shape (nc_test, nc_test, nc_test)\n",
    "    indexed [t, x, y].\n",
    "    \"\"\"\n",
    "    v_max = 0.385\n",
    "    t = jnp.linspace(0, 4, nc_test)\n",
    "    x = jnp.linspace(-4, 4, nc_test)\n",
    "    y = jnp.linspace(-4, 4, nc_test)\n",
    "    t = jax.lax.stop_gradient(t)\n",
    "    x = jax.lax.stop_gradient(x)\n",
    "    y = jax.lax.stop_gradient(y)\n",
    "    tm, xm, ym = jnp.meshgrid(t, x, y, indexing='ij')\n",
    "\n",
    "    omega, _, _ = flow_mixing3d_params(tm, xm, ym, v_max)\n",
    "    u_gt = flow_mixing3d_exact_u(tm, xm, ym, omega)\n",
    "    t = t.reshape(-1, 1)\n",
    "    x = x.reshape(-1, 1)\n",
    "    y = y.reshape(-1, 1)\n",
    "    return t, x, y, tm, xm, ym, u_gt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "586fb65d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plotter(xm, ym, zm, u):\n",
    "    \"\"\"3-D scatter of field u on a meshgrid; axes are labelled t, x, y.\"\"\"\n",
    "    fig = plt.figure(figsize=(6, 6))\n",
    "    ax = fig.add_subplot(111, projection='3d')\n",
    "    im = ax.scatter(xm, ym, zm, c=u, s=0.5, cmap='magma')\n",
    "    ax.set_title('U(t, x, y)', fontsize=20)\n",
    "    ax.set_xlabel('t', fontsize=18, labelpad=10)\n",
    "    ax.set_ylabel('x', fontsize=18, labelpad=10)\n",
    "    ax.set_zlabel('y', fontsize=18, labelpad=10)\n",
    "    fig.colorbar(im, ax=ax)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "016a8901",
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(mode, NC, NI, NB, NC_TEST, SEED, LR, EPOCHS, N_LAYERS, FEATURES, LOG_ITER):\n",
    "    \"\"\"Train one separable PINN on the flow-mixing problem.\n",
    "\n",
    "    Args:\n",
    "        mode: 'CPPINN', 'TTPINN' or 'TuckerPINN'.\n",
    "        NC: points per axis for collocation, initial and boundary sets.\n",
    "        NI, NB: currently unused (the generator derives both sets from NC).\n",
    "        NC_TEST: points per axis of the uniform evaluation grid.\n",
    "        SEED: PRNG seed for parameter init and data sampling.\n",
    "        LR: Adam learning rate.\n",
    "        EPOCHS: number of optimisation steps.\n",
    "        N_LAYERS: Dense layers per coordinate branch (N_LAYERS - 1 hidden + output).\n",
    "        FEATURES: hidden width, also the decomposition rank.\n",
    "        LOG_ITER: evaluate/log every LOG_ITER steps (and at step 1).\n",
    "\n",
    "    Returns:\n",
    "        List of [epoch, loss, relative L2 error] rows, one per logged step.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if mode is not a known model name.\n",
    "    \"\"\"\n",
    "    models = {'CPPINN': CPPINN, 'TTPINN': TTPINN, 'TuckerPINN': TuckerPINN}\n",
    "    if mode not in models:\n",
    "        raise ValueError(f'Unknown mode {mode!r}; expected one of {sorted(models)}')\n",
    "\n",
    "    key = jax.random.PRNGKey(SEED)\n",
    "    key, subkey = jax.random.split(key, 2)\n",
    "\n",
    "    # every layer has FEATURES units; the last entry is the decomposition rank\n",
    "    feat_sizes = tuple(FEATURES for _ in range(N_LAYERS))\n",
    "    model = models[mode](feat_sizes)\n",
    "    # initial parameter values depend only on input shapes, so dummy inputs suffice\n",
    "    dummy = jnp.ones((NC, 1))\n",
    "    params = model.init(subkey, dummy, dummy, dummy)\n",
    "\n",
    "    optim = optax.adam(LR)\n",
    "    state = optim.init(params)\n",
    "\n",
    "    key, subkey = jax.random.split(key, 2)\n",
    "    train_data = train_generator_flow_mixing3d(NC, subkey)\n",
    "    t, x, y, tm, xm, ym, u_gt = test_generator_flow_mixing3d(NC_TEST)\n",
    "    logger = []\n",
    "\n",
    "    apply_fn = jax.jit(model.apply)\n",
    "    loss_fn = loss_flowmixer(apply_fn, *train_data)\n",
    "\n",
    "    @jax.jit\n",
    "    def train_one_step(params, state):\n",
    "        loss, gradient = value_and_grad(loss_fn)(params)\n",
    "        params, state = update_model(optim, gradient, params, state)\n",
    "        return loss, params, state\n",
    "\n",
    "    # timing includes JIT compilation and the periodic evaluations\n",
    "    start = time.time()\n",
    "    for e in trange(1, EPOCHS + 1):\n",
    "        loss, params, state = train_one_step(params, state)\n",
    "        if e % LOG_ITER == 0 or e == 1:\n",
    "            u = apply_fn(params, t, x, y)\n",
    "            error = relative_l2(u, u_gt)\n",
    "            print(f'Epoch: {e}/{EPOCHS} --> loss: {loss:.8f}, error: {error:.8f}')\n",
    "            logger.append([e, float(loss), float(error)])\n",
    "    # JAX dispatches asynchronously: wait for the last step before stopping the clock\n",
    "    loss.block_until_ready()\n",
    "    end = time.time()\n",
    "    print(f'Runtime: {((end-start)/EPOCHS*1000):.2f} ms/iter.')\n",
    "    return logger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2400611",
   "metadata": {},
   "outputs": [],
   "source": [
    "num = 128\n",
    "for model_name in ['TuckerPINN']:\n",
    "    for rank in [64, 128]:\n",
    "        out_dir = f'Rank{rank}'\n",
    "        os.makedirs(out_dir, exist_ok=True)  # np.savetxt does not create directories\n",
    "        for run in range(10):\n",
    "            logs = np.array(main(mode=model_name, NC=num, NI=num, NB=num, NC_TEST=128,\n",
    "                                 SEED=444 + run, LR=1e-3, EPOCHS=80000, N_LAYERS=4,\n",
    "                                 FEATURES=rank, LOG_ITER=5000))\n",
    "            # keep only runs that reached < 10% relative L2 error at some logged step\n",
    "            if np.min(logs[:, 2]) < 0.1:\n",
    "                np.savetxt(os.path.join(out_dir, f'model_{model_name}_run_{run}'), logs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4d3a31d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "t, x, y, tm, xm, ym, u_gt = test_generator_flow_mixing3d(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc360bbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exact-solution snapshots; u_gt is indexed [t, x, y] on a 100-point grid, t in [0, 4].\n",
    "snapshots = [(0, 't=0'), (50, 't=2'), (75, 't=3'), (-1, 't=4')]\n",
    "# one shared colour scale, since all panels share a single colorbar\n",
    "vmin, vmax = float(u_gt.min()), float(u_gt.max())\n",
    "fig, axes = plt.subplots(2, 2, figsize=(6, 6))\n",
    "for panel, (idx, title) in zip(axes.ravel(), snapshots):\n",
    "    # .T gives rows = y; origin='lower' puts y = -4 at the bottom, matching extent\n",
    "    im = panel.imshow(u_gt[idx, :, :].T, cmap='magma', extent=[-4, 4, -4, 4],\n",
    "                      origin='lower', vmin=vmin, vmax=vmax)\n",
    "    panel.set_title(title, fontsize=8)\n",
    "    panel.tick_params(axis='x', labelsize=8)\n",
    "    panel.tick_params(axis='y', labelsize=8)\n",
    "fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.8);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50482884",
   "metadata": {},
   "outputs": [],
   "source": [
    "from celluloid import Camera"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b370cd89",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure()\n",
    "camera = Camera(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d9d7fe49",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'camera' is not defined",
     "output_type": "error",
     "traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[11], line 9\u001b[0m\n\u001b[1;32m 7\u001b[0m plt\u001b[38;5;241m.\u001b[39mimshow(u_gt[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,:,:]\u001b[38;5;241m.\u001b[39mT,cmap\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmagma\u001b[39m\u001b[38;5;124m'\u001b[39m, extent\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m4\u001b[39m,\u001b[38;5;241m4\u001b[39m,\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m4\u001b[39m,\u001b[38;5;241m4\u001b[39m])\n\u001b[1;32m 8\u001b[0m plt\u001b[38;5;241m.\u001b[39msavefig(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msnaps/snap20\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(i))\n\u001b[0;32m----> 9\u001b[0m camera\u001b[38;5;241m.\u001b[39msnap()\n", "\u001b[0;31mNameError\u001b[0m: name 'camera' is not defined" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAagAAAGiCAYAAACyKVKmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAABEvklEQVR4nO2df6wc1Xn3n9ndu7v359rXFwyWr7Fj0pAUOWnsBtklLS7Uxc1LADW8RW/qEBoiOTEIy1JKDBWYFnTTFvUXbVxoIhu1ISYRsZ2+CsiWik0ky8ImtqBUIPHrvQZj8MX27r17787uzs77h+M75zxz58ye3Z3dM+vvR7rSnnvOzJyZnd2z83yfH5brui4BAAAAhpHo9AQAAACAucACBQAAwEiwQAEAADASLFAAAACMBAsUAAAAI8ECBQAAwEiwQAEAADASLFAAAACMBAsUAAAAI8ECBQAAwEjatkCNjY2RZVm0efPmdh0SAABAjGnLAnXkyBF68sknacWKFe04HAAAgC4g8gVqamqKvvrVr9K//du/0fz586M+HAAAgC4hFfUBNm3aRF/60pfohhtuoEceeUQ51rZtsm17tl2r1ejMmTO0YMECsiwr6qkCAABoMa7r0uTkJC1atIgSCb1nokgXqF27dtGvfvUrOnLkSF3jx8bG6OGHH45ySgAAADrAiRMnaPHixVrbRLZAnThxgu69917at28fZbPZurbZunUrbdmyZbadz+dpyZIl9M67P6Ghob6opgoAACAiCoVpWrb0f9Pg4KD2tlZUBQv37NlDt956KyWTydn/OY5DlmVRIpEg27alvrkoFAqUy+Xo4zP/l4aG+qOYJgAAgAgpFIq0YPh/UT6fp6GhIa1tI3uCuv766+nVV1+V/nfnnXfSVVddRffdd1/o4gQAAODiJrIFanBwkK6++mrpf/39/bRgwQLf/wEAAAAOMkkAAAAwksjdzEUOHDjQzsMBAACIMXiCAgAAYCRYoAAAABgJFigAAABG0lYNqmFqtfN/AAAA4kUT3914ggIAAGAkWKAAAAAYCRYoAAAARoIFCgAAgJFggQIAAGAkWKAAAAAYSTzczN3a+T8AAADxoonvbjxBAQAAMBIsUAAAAIwECxQAAAAjwQIFAADASLBAAQAAMBIsUAAAAIwECxQAAAAjiUccFMptAABAPEG5DQAAAN0GFigAAABGggUKAACAkcRCg7JqNbKgQQEAQOxo5rsbT1AAAACMBAsUAAAAI4mFiQ/UAUyg7SWB33YARA0+ZQAAAIwECxQAAAAjwQIFAADASOKhQdVcaCzALHA/AlAfNbfhTfEEBQAAwEiwQAEAADASLFAAAACMJCYaFMptAABALEGqIwAAAN1GpAvU9u3bacWKFTQ0NERDQ0O0evVqeu6556I8JAAAgC4hUhPf4sWL6Xvf+x5deeWVRET01FNP0c0330zHjh2j3/zN36x/R27t/B8AAIB40cR3d6QL1E033SS1H330Udq+fTsdPnxYb4ECAABw0dE2JwnHceinP/0pFYtFWr169ZxjbNsm27Zn24VCoV3TAwAAYBiRO0m8+uqrNDAwQJlMhjZu3Ei7d++mz3zmM3OOHRsbo1wuN/s3Ojoa9fQAAAAYiuW6buN5KOqgXC7T+Pg4nTt3jp599ln6wQ9+QAcPHpxzkZrrCWp0dJTOvPs0DQ31RTlNAAAAEVAoTNPw0v9D+XyehoaGtLaN3MSXTqdnnSRWrVpFR44coX/8x3+kJ554wjc2k8lQJpOJekoAAABiQNvjoFzXlZ6SAAAAgLmI9Anq/vvvp/Xr19Po6ChNTk7Srl276MCBA/T8889HeVgAAABdQKQL1IcffkgbNmygDz74gHK5HK1YsYKef/55+oM/+AO9HblIdQQAALHE1DioH/7wh1HuHgAAQBeDXHwAAACMBNnMgQyuc30k8NsOgLpANnMAAADdBhYoAAAARoIFCgAAgJFAg4qauM4bqGnmfYV+BS4moEE
BAADoNrBAAQAAMBIsUAAAAIwkHhqUUz3/B7qHWqRVXhojYbXnOE4T+pWF35SA0UQqobbQxHc37nYAAABGggUKAACAkWCBAgAAYCTx0KBqrpmahUFYiLdqnggvoduy2Cen/qHt0tS6HXz3NEcT1w9PUAAAAIwECxQAAAAjiYWJz6o6ZFXhZt4RYDqsjxATXsMm2GZMgxqHbJ0Jsn3ArB0PLJj4AAAAdBtYoAAAABgJFigAAABGEgsNKtblNjimpyXRBS6452nl+yqmM9K575vQkaDnNAiuWzhN+A/gCQoAAICRYIECAABgJFigAAAAGElMNCiHyNFI8dItmGjfNmFOUelerUwNpKMH+cY2mM5It4zHxVS6I+7ab5y13lrj390X0R0KAAAgTmCBAgAAYCTxMPFVq025KnYNnXjMj8qkZ6LJRccSEWoeYztTmvz4WIWpke9HdRlDzYyKEzYhE7opZi0TzNrtIopzhZs5AACAbgMLFAAAACPBAgUAAMBIYqJBOef/LjbapdO0ytbfhP06bql2tMtTqMZzvUe81X1al4a2lQi5pqpt4/V2+GnX/WSKTtYJ6r3GTXx34wkKAACAkWCBAgAAYCRYoAAAABhJpBrU2NgY/exnP6PXX3+dent7ac2aNfTXf/3X9KlPfUpvR04XpTrqgNbSlL6jY2Nv5jg6elun7P6CVhQaJcS1Iw0NSta3QjQn1XWrhfz+TCiuowlxUBwDtNJQTIzva5RWXe8mvrsjfYI6ePAgbdq0iQ4fPkz79++narVK69ato2KxGOVhAQAAdAGRPkE9//zzUnvHjh106aWX0ssvv0y/+7u/G+WhAQAAxJy2upnn83kiIhoeHp6z37Ztsm17tl0oFIiIyKo6ZJmU6sgE11ITzWlhc2rVvky4/mEmMG6KE8eHuKhb4tgwU6Ew1uf6zk14Pnd24Rrz4+jcXk1U8m2Zua0Z01or7ycDwyU6HcJhxcHN3HVd2rJlC1177bV09dVXzzlmbGyMcrnc7N/o6Gi7pgcAAMAw2rZA3X333fTKK6/Qj3/848AxW7dupXw+P/t34sSJdk0PAACAYbTFxHfPPffQz3/+c3rxxRdp8eLFgeMymQxlMpl2TAkAAIDhRLpAua5L99xzD+3evZsOHDhAy5Yta2xHjpDqCG6cAfsyUFdSukCHzEE8bujYiFzhtariWiHtRHAf139Syfr2w/otDb2KiGlWKpfzueYholvJt1HaFPLQtrCMZjBQ6wqkCTfzSBeoTZs20dNPP0179+6lwcFBOnXqFBER5XI56u3tjfLQAAAAYk6kGtT27dspn8/TddddR5dffvns3zPPPBPlYQEAAHQBkZv4AAAAgEaIR7mNSoWokgwf1y5MtDPrzKlV2lDYtr62xrauQoNqRpNqle0+RN9RtlWaExFRStSG2NgkH6vQq/h+WeojS9SdWGkOX0yV2N1M3FMYGu+PllYUlX7VKk3chNi+MBr97FQqDR8SyWIBAAAYCRYoAAAARoIFCgAAgJHEQ4MqV4l6DMrFp6KVcVrtyl2no9motKKw/Ypt1ufyfF1iP3e2CZuT4jjqc20iFihMgxJ1JaYjWSmFJsX7enqC98s1p6pCryKStSQ+J2Upjg7F4DSjLbYq/6QpMVJxioOKQy4+AAAAQAcsUAAAAIwkHia+SrUpV8W20sxjfKdSEOmYvRSu4y7v4ylwVOmLqsFzcnkfN/mxfreqYeITNvXNXwPLZ+JjAwRTnJXmZjtmeksL7bT8EbV62OdA7O9JBfcRETnsOKJZL8WuMTcHiq7x/NZrpdt5q1y8mwlxUNHKkjIiUaZw67QLewUmPgAAAF0GFigAAABGggUKAACAkcRHgyoblOqI0y67c6Mu0iH2eFfllq3j4q3QkYiYlsT1qapirKqPiNwK7xf1K/kwrsPHivNlYzXeHp69iP/0Swje4VZa1qt4O5Hx7nUry+77XvaRzQrtjOyCbmWZS3omLbfTGlqdKqUSfy9
9F0ODduhKYWEKzcxJRVSpmXTohB5VaTxECE9QAAAAjAQLFAAAACPBAgUAAMBI4qNBdbLcRktTlLTGZu2L2VHGPWnoSBrxSURErqg/qOKRWD/vcyts27I752siohoLBfK1y55G4lQtNla+j6pV7zdazWFjHfn3W81VpDpiJJPy+SSE1EE9aTkuJJmWx/b0eSeU7JVPLjnEUhL1ex9hi42lAVlzsvh7WRU0Kh4z5bK2eM+oUiYREZFwfqr0UETNpRWS7lsNXUknDpDvu1OlXlr1HRRlvFUQKLcBAACg28ACBQAAwEjiYeIrV/1mhRbTTJobCa3KtlFWr63fhVg6d196omAXbiKKzmxne6ahqi3/jnLKcrtsy/dGpeLd1hWW0bvM20L6nwoz6VVYBVpHMPG5Iea+JMsGnrS8a5Fm5r9sSnbD7cuUvT5mtssOyGPTQ7Z3jBwzB9osxcwga/dnZl9aDnNJ5/dMWnGv8sznolmvGYtSmFu5lPG+hebAVrmz68wpjHalUIuCMlIdAQAA6DKwQAEAADASLFAAAACMJBYalFuukpvs4Foalf23KZdVhY5EJOtbPtdxxbZhruI8VVDZG18ry0N5CiLHFl8zfccO1pXKZfk2tZmr+ExF1k/sqjd+mmlOJSe4PcM0qCrTmURJrUbMJV0hwxARJS1vQIZpNn1Mk+oXNKlBVl4jl7el9tC5krfdPPkNyEyVpHZqgawFWKI2MJSV+0iDJHc7F7bWLcWho/e4CvdvnVAKnarRrazQrBNGokNUmlOYZh60WRlu5gAAALoMLFAAAACMBAsUAAAAI4mFBkVlhyjZgC+9ASnrlbbk0DgoVWxTyHEUNneVruSLZeJtRfySGLtEpI5f4rFLJVvWkUplrz1TkW/TItOcppmuNCXoTkWmQU2zdEZiu8T6yuwaixoUf+fCbomksOse9rOQV9ToS3rnN5jKSH3DM71Se8G0p0mNFGekvvlT01J7YEbWqNLCCSVC0mFZqhNM83+IJ9TCkjI8rqgqfCfoxDI5TshY1WdJfT5apWuC5lcPjX63dSImCnFQAAAAug0sUAAAAIwkFiY+166Sq+uuyulYRvLgLq2M5KzfZ6bzmRMocKy/Lbz2uZHLu3WYGc+pCCmJSjzlkPyeiWY80YRHpDbjTVblvgIbO1mVj1MUTHVFls28yIp7TgvtMruGZXYtxKaj6XJrCY7bSebD3cN80rMpr93PUnwN9cjtBbZ3LRaWZVvbIltuL7SnpPa8imcSzNRkl/Skz2/eu8ZWWIZy8bA8JVRYtd1WuY5XmVlJ3G9Ihv6WVZxuqUt6E9tq0LKUb+I+S6ioCwAAoMvAAgUAAMBIsEABAAAwkphoUA65VuOuikTUvnT1Id6iShuvQkfybcvdzBW6Et8vL20hjq1VmKt4mVekrT8l0YxPZ/LaRa4jMdfxScE9fJK5iufZHKeYzjQpnF+RXZeZarDOZDPtgetMUvYo9uYkQpIDWWL2HzbW4pqUcI3TTO85m2LttDf2XEV2SZ/kKaGYlrdYSO20oFaU+rIkp1RKivNgE/ZpUmLbV303orIYXHNSuZKzsb7PpK/kjIau1GiZG45uCEq9fTq0qvgu3MwBAAB0G1igAAAAGEmkC9SLL75IN910Ey1atIgsy6I9e/ZEeTgAAABdRKQaVLFYpM9+9rN055130h//8R83vB+35JBLbU51pLGpWldSbOizSau3Fft5vJJfkxK6uDme6UqOoOFwjalSZiXSFaUveNkLX2xTVYhtYn15ppfkhdimAtOcCkxDm2RlPYpC22Z6Qom1q8JFDXvLxVmkWDwP15GS/B8alIU5Vtj7avNUTcJ1mmZa3IwjX+NSrU9q83IiIpekZE0qkxRKy7NYLEqx37mqWCeuSXF09B5RS+IfHsVYl8dBcb1KFSflSxvGP3j1a1ByH9tNM3FRzYyVtmtsM45balyDinSBWr9+Pa1fv77u8bZtk217H4RCoRDFtAAAAMQAozS
osbExyuVys3+jo6OdnhIAAIAOYdQCtXXrVsrn87N/J06c6PSUAAAAdAij4qAymQxlMhnf/91yjVzr1wbRiNLF+/QfFcr8emHbCvn0fDUbQvYlmud5qjGmP9QcUVcK7iOSdaYK04IqLAbJZtqRqDPxMhhTVZ4zz9sX15wKVa4zee08ywc4xYQZf6yTd3EqLo9tYmUkBK0oyeKTMgl5jtmk93suwxLq8TaXZRLCcWrsjecyhi3oizw/YLXGz7Um9LFyIayEfaUmt2uuV7ojYcn7TbEy9AsyXukOKyvHSCXS8nWyRJ0pLG9fWL8Ij4tS5dfjsU5if5jmxNrytmGxWeL0dGKm5K6mNKhmpPcIvl9riIMCAADQbWCBAgAAYCSRmvimpqbozTffnG2/8847dPz4cRoeHqYlS5bUvZ/atEs1/mhdB5oVEdhBNY6jdCUPHus38VmBY4mY2c5n4ksEjq2yPoe1y4LpjZv0StVgkx4R0bTQz8ti8FQ7YlmMc8ydnbuOF4TKvT43cma+KTGTTdkNNikkmQt0VjDjDfbw0hby2Fzau6bzWBXZwZT8ZmWTPBWSB79deCVf6Tox82aeVTSeVLjUc1Mo/z2aEK5FJiFX6u1Lydew9yPvDUr2y2+W1ccm2ePdB5bP1qnxm7iJtEJKV3JucuImPf5dI/SHlrkR+kNNfGJJHIWpcE4UprhWyRW6BB23Zjf+RRzpAnX06FFau3btbHvLli1ERHTHHXfQzp07ozw0AACAmBPpAnXdddeR29RjDAAAgIsVaFAAAACMxCg38yCckt+7s1m0HuxqaldYlc2X97lCehl/H2szrUjKosJ0C64z1QQX46oTrDmdb3u3gc01KIdrTlxXEtzMfWUxEqztzZlrTn5txbs4Ps2JCXAVheaUseT59yfl9vyMN+fL+uRrenm2xtrepBdkZFfrwbR8QpmUXOZadON2mLt32WEu90Kp9o9Y2fZTJdmV/+SMt6/TJVbevqLWpD4Stk0zl/r+lKxJDZ7tn32dzeWlvuQg06R6hXaafcXwNEiq1Ecqt3IiOWQjrNyGqDNxzanCxyr0LFUfkawrhehV0umEhpgoNKew77IGvzu1tCzV4cvhY4LAExQAAAAjwQIFAADASLBAAQAAMJJ4aFA2kTNHaQBXUS5AF6WOFGrjDZ6Hy2Ob3OA+xxcHxVIUCW2H6UpcZ6oJOodPc3KC2yU2dpqNnfJpUGJpdh7bFFwmo6DQnIiIpgRdoFST9ZxKSOmVrOXpNEM9smZzaa88/yX93jw+0S8by0f7pqX2JYNeCYo+FvuTSrMUSsn6NQOuNZZtb46LpuXUXx9NySUzhtNe++2U/HF+ryjvl1/joqCfnC7xGDD5uo0UvePM+3hG6usZlq+FlRPer175vbJYWiSlvqtTbkMRu0REUuxTqOZU5mmSBK2Lx5bxbcU4R/m29WtQYlMjzdmc/UH7DSNEX1dRr0blQIMCAADQbWCBAgAAYCSxMPFV7SRVa3531FbGAHNzms5xapLruPqRWTTT+Ux4brBJj0h2T+bmwCozxYmZq7n5j7uO20L/jBNm4mNmI43Kt3L6omCTHpFs1rOJ20pkMuw2HhRMXQuZSW/5oPxmfmqgNPt62TzZfXp4vmzi6xHMVdxbmrv9h2WQF+HFdzO93vmmM/J16euV3dv7Bff2dGKQ7Vl2UR+fkictmvi4C/qELY/9sOSZGi+fykp9g3l5TskF3vytAfbeVUPczkVCquRK6YG4S7oqfVGYSa/M3dBFd3buZs6m7LiBfdxMJ0VLhKQ548gp0xoPg/GNbSo93NzzqFYafw7CExQAAAAjwQIFAADASLBAAQAAMJJYaFAVO0nlFmtQOi7qobqSqEG5wW7lYWN9buasvyLoQdztnutMogbF3crtGtekvP5ph+sUcnvK4a7LlvBa6qKir0yGKn2RbLAXdacaKxHaQywtT5K5RGe9/qUD8rafHpRdpD+54Ozs69wCuS/Vy6rxCrr
MTF7+6BSLsjv4tC3PqTLH/XuBLEuL1Jfx/HKzvSyFUla+bpdaU4H7LddyUptrj+8XLWGsfJ14FeOJsrft2RlZg7qkOCm109PC+XB3b1/JCYVAolNuQ1UFl/WHak4235dwyEqIBiW0Xa5L8rEaac/8GpUVPJYR9v2lPG6DiN9zvEq3DniCAgAAYCRYoAAAABgJFigAAABGEgsNqlxOUdn1T7UZeymPOVLh15WCt/WZ2Hlsk9DmZRf4+TguK9Uu2JIrtWDNibfLCs2JiGhGioOS58s1p2KVlXSoiq+55iS3pwUdgJfM4LFOVSFoJMV+R2VZCY0cS5+zSCibcWW/HKPziflyrNO8SzzdKcXKa5RZqqDCGa8ExanCgNT3sS1rUDwllKgZ8ruHl1fP9QhlPWZKUt/8AVkny2S967ZgQI7buqIsx0GdrchzLgjxKWdteQ4z7L0TY9zybL92Udbb+m1BwPKlDWKpj3wl4YWr49OcFGXcfSXemY5Urs35msivOfHyEGIclK+PfV4c4TrVeEYuHvfoiBpU/dr1XONV27ZSb68Xcb5lnh5KAzxBAQAAMBIsUAAAAIwkHia+apLsX7sWN5rBXMekF3Yc3+O2mHGF1OZA2c1c3i836VV5qiO3fhNfVRjrdyuX26Jr+bTDTXjBJj3e5mahGWZ2KQkmmbIbbNLjJNnvqD5WFXdBVu5f0ucdZ+mg7AI9f6QotXv6vONWpuX9nDktm8T+X35o9vXJGdmkV6jy905qSumM+K9CngV+UnDL5eZYfn+NJLzzSWfkazrCsrEvZHP+oMcz1U2WWUgDM6+Jpt9JZr60bfn9qJU9O5jLTG1WmOu4os/nOi5V1A3JMi6a+JirODfb1WzWXxHHss8k+3yImel5eitVtQKe1iwsDZrq+ynMPKiimQoRQduW+XujAZ6gAAAAGAkWKAAAAEaCBQoAAICRxEKDmqmkKOn2hA9k6NhTVXZarXIbxG3HwWN5uiI+3ypPdSSW22B93JW8IvRzzclm9m1Rd+Ju5jPMVXamytvunK/PH5e5bbuCBsWq4rpMg7KE304Z5lY+0COfz6Vy5h1aJJSkGJkna07pfuZSLEyj8LG8o/G8XL7inaLXn2flNPgt0sNuJ3HGSUse7bCNRU3wbFm+79MJeY5imqTckOyC3puR0yQNp+X2kKBBfZTg95M8KdELnd9PvGKzpPH44y7ktlKDCsn/I2obPrdydn9VRDdzpjFxzclWuI5XeB/THkUNimvI7Lq5UsiJ+rvAr2VTICrdO2ysDvVqWzY0KAAAAN0GFigAAABGggUKAACAkcRCg7KrKeqpY6oqu6x/rEZsQIiuJI3l9l/WL8Y68f1wXclXUkOKgwrWp4iIykI/15xKNa4zee1ppjGVmAY17XCdyWvbTDOwXXnjiqA7OcTzwMhkhN9O6YR8bkNpuT2SZiUoer34n94hWXdhchaVz3r6yelCv9T3PisrcabsHZffAr1J+T8DSfla9Ch+ClYVMk2ZvVdFVjK9KKQd6i/LAT1JNoeBHlbWQ5gzn1+ZZw5yxdc8vodpK+K2Oh9KNt7l27I0SZIGpYh7Ot8W0hWFaE5VrkEJ73uVxaw5vC3oTLwETo1dJ/H0/GnPwuKghD5f8iwZnZhO9X7qHiphVxt/DsITFAAAACPBAgUAAMBIsEABAAAwkphoUElK/ToXX6Me9dq5+BR2XVVsk69SM7cdC695nJOvxDvXoAQ9wh8jJR9X1J24jsE1qJIz92siomlm2y8zDUqMlykzDari06DEMu7y2AT7rSTm38sm5TibQRYSN8w0qKE+Lw6Kl9DgN9D0pKfhnC4Fa05E8jXuS8nXIZeSdzzYI89JjH3iGg7XE8WYNovdUT5dUoyNYzpGKsHnwHUxb99JpvPx47YMHtvE2yIOz6/HdKWKd37+uKfgWCdHrsCi1JyIiCpl7/7j5ctVOlOF9anK6/A8nCrNiffraOTNaO86iHOY4dqhBniCAgAAYCR
YoAAAABhJLEx8006SLCsZPpCh9zjb+H5Ew0O4m7mQNkXbzdx77XMz96U+El8Hm/R4m/eVfWO5yU8oV+EyEx9zJa8JVznMxJeS3Mzl+fezu3aQuU9nsp5reUIu/uoz70zbnr0wX5F3zE2hIn3MrZyb9PqTcjshmPh4SAA/d5UhuyfBU0IF37lhJRtUpiB2ySlpia/lDXuYO7vkys93lFD/JpZcy7n5j5v4xBIavCpuiZn4hMLEjs1Mb8ykV7bl75qKcF9UWFonlRmvWuN9/PMdHHLilwaC37uw7xzVfjk6Zr16Iwh4yRgd2vIE9f3vf5+WLVtG2WyWVq5cSb/85S/bcVgAAAAxJvIF6plnnqHNmzfTAw88QMeOHaMvfvGLtH79ehofH4/60AAAAGJM5AvU3/3d39E3vvENuuuuu+jTn/40/cM//AONjo7S9u3bfWNt26ZCoSD9AQAAuDiJVIMql8v08ssv03e/+13p/+vWraNDhw75xo+NjdHDDz/s+79dS1KyDjumjgt6mP20dW7mwfv1a1By2+dKLpXqkMdyN/Oy5GYu96k0KF5mgberrC3qTtytnKczqgpu5i5zeeYXzhJqpPM0POmEPLg3KWtQPT3Cvtm2LivRbQv6wrQT7FZOJOswfA5Zpg31Jut3rfXfioKbOdN7skzvSaecwLHcBXqmGqyxOex9TVjydcoIH78+5lLfk2J6m1hrhGtQDF86I9G1XOFWTkTSjc1TG/F0Ro7gSl4tqTWnclm+TmI5kTK7hlxnEjUpv6YcrEmpUqIR6X3ncJTlNkK2rXc/KuyaoRrUxMQEOY5DCxculP6/cOFCOnXqlG/81q1bKZ/Pz/6dOHEiyukBAAAwmLZ48Vns15jrur7/ERFlMhnKZDLtmBIAAADDifQJamRkhJLJpO9p6aOPPvI9VQEAAAAikT5BpdNpWrlyJe3fv59uvfXW2f/v37+fbr755rr3M+MkZkuAN5qARTfrv5YGJW7n6wuOX+A6Eh+r0pn86XLksXIcFOvzxTq5c76eq81LaogalMMs6TX2bvl0J4EEO3exnWRP21yTSjE9yBI0kjke1NkcBVt+SDVyUYPiv+x4fFKKtRNW8A3IZZqUMJaf22CPXD6kl7VFbBbXla/IOaKKVVHTlI+TYifYK8gIgylZ8+vtlct8WFlvv1Yy5Dew78Mk3rgsfZFPZxLKt8yElNAoCTpSiObEr5ukQTEtnMe0iW3/ZzQ4nREfG5YyTSf2Ut6PopPRqlRHXNvVIXIT35YtW2jDhg20atUqWr16NT355JM0Pj5OGzdujPrQAAAAYkzkC9Sf/Mmf0Mcff0x/+Zd/SR988AFdffXV9Itf/IKuuOKKqA8NAAAgxrTFSeLb3/42ffvb3254+5JjkdXA42ZYpUlprKpKrsa2/Jh+k5+Hz8QXkvpINvHJY3lV1kqwpYRsnyu5+Jq5kfvawRnLq8y4ILqVE/lNfvX2NQWzn1ksC7lomkuG3C7iNeeppfh7xU1zYiZxnioorXDf5abDvrRs0hPdzLlb+VlWEfi0LX/cC8Ku+L3Xn5LnNE9w3R/OyPmiMgPy+5zoFY7DbYUc/gERblbuVu6WeFtInVWSuqgyHexKzk16M2XZ9GkzV3LRrGf7QhHqN/Gp2qqUaEThqZBEwkJdVESR6ohfMx2QLBYAAICRYIECAABgJFigAAAAGEksym3M1CwiK3wt5eZsHVSumSodybcfX2oj3h+crigs9ZGkgfBKBAoNiutIPk1K2DHfLy+h4fAKryS6mTONwOJu517bX2JCRtSk+DX06W+sjIFbFffN0/DI22YFTac/xSvQyi7FYkWHIkuZNMXKMPQz7WVQqPqbttRpkJJSpVvmvs5SHVUFfaRQkoPc35vuldoflOQ5FiviceQ5DLHrtDDr6Uzz+2ekvvQ85uafFY7DNShfRV0WiiBqUFxzYiU1RNdyleZERGQL+hvXnGaY+32ZaSZiuQiuOZVZW0xPxkvi+HW
l4O+CsCq5qtAWjkqL19OnGkNVtiYMPEEBAAAwEixQAAAAjAQLFAAAACOJhQZlOzSnFbWZyJlmysGr7LbhGpT32m93lmHVBqTxYeU2xLIY/jRIwSU0fHFOrO24PNbJa/My7r5UR4IexHU8vq3rinPicVvBGgERUaUi9NdYjA7TVgb6vZieYRZjNJBiWoWQPidfke+fbFL+KGV8cVBeu4+VqE+z0hy8bIY0B6afTAntkzOy5vT/puWxE6zcvXgPDfTI53NZVn4/FvV6utO8BdNSXzInX38ro/ha4R8QvyA6+9KdYeVbiuzeKwp6zwzTnErsvROuE9ecfGVI2P0kxvHw+DeuM8lxUFLXHHFQ3mtVSjQi9fdIWOyltF1wl49Wafq8vI8OeIICAABgJFigAAAAGEksTHwlx6ILRr5WJcRpxr1SJy2SP0VJ8Fgdt3Nu9vK7XgePVbV5xVyVWzkRkSu6mVs8tZHKzZyUY0XTIZ/vtHwYmmQmGtG8U6vImbaT/cw0N+Tt7NI+2X36Eua2LR6Hz+G0zTNVy2akkmD6GarIG/ex7OBidianxt3Z5XP9WDTxMTfy0yz9T4ndJL1COqPLeuW+Zf2yPXDRvElvuxGWqX1ATqmkTG/Eq+Ta7J4RLqwzxczLRflalIuC6/hMWuqbtuXrPy1cp+kQk57PxFcTXceD3crP9wf3qaoX+Mz9IaEtet9BjcsZOgTN6fz3d2PgCQoAAICRYIECAABgJFigAAAAGEksNKiyS3Qha45uZdwgWpniQ6Ur+dxFJfdQ9Vhe5VRyM/e5oHNdKbiPV8kV0xn5Uhsp3MrPt0XX8WC3cg7XnByW/kcs48HnW2T6T4FVQJ2c8bSjeUXmEt0vb9sz5O370uFJqe8Km2tQ3sYnmRYxyQrbclf4gpB+qZ+5pGeTwXcYDxEoMnt+QXB3982BXbc+Vk/kUsEr/cp+eeNPzMtL7fmLvOuYWiDrO1Yv+xoRRbQQzcllYl6tKJQPmWL37RQrkzEj6EoKzYmIqCi4lk9raE5E8nvpTwXGKvcKbR4m4i+pIfSFVtQNbut8P4URhQZl6/i2M/AEBQAAwEiwQAEAADASLFAAAACMJBYaVMmZO+1GM3qUThqPMBOqnj1YKCMRUl5DpUn506gwvUoVB8V0JTGdEdeceNwTL6lRU6Y6Ym0r+ErWfMf1jlOqyccsstIWZ8ry76wzQqnzkUlZi+jJsbioXs/2379Q1mGusM9Jbbl0tZxW6FRJngOPkxLTvZxhqWkSVnCcSFhKK7Gf72UozTSnrLyzT/R71+I3mOa0cJGsx2Uu8665xWtx+EpqCPd4mOY0KbedSSHF1ZT8Ps8Uma5U8mKfpmw5DqrIYp3E2KdpR10yg8ftVBQlNCqKzzBPbcQ1KXHTsFI7HNW2HJ2YqWaABgUAAOCiAQsUAAAAI8ECBQAAwEhirUFxak3kh2+VJuXLoaVIm8/nG5ZiX9SZwmzWsq7EY6SCy7j745xCSmgIuhKPe+Kak6u4cg7J+k9ZaJdYyYwpFvd0pixrFR8K8UuX5Pukvt5B+TgJIdQpNShrBvMXy7n5foPOzL7OJOdJff1JWZP6kJUcF+OVePkBriup7sUe9pOyX2gPMmno0ox8oCV9cn69pYLudMmiKakvs4jpZDlv51aPfG6+su0VoWQG15zy8vWv5uVt7by37+lJWVcS49uIZN1pkpXQ8GlQgu5UYhqUP+7JYv3ea64rqTRBnXI6ofGTFEyYBtWq8hthQIMCAABw0YAFCgAAgJHEwsRnV11yE805RDaXwkO9daPlN1wKMfEpXMnD0yIJZjufCzd3O/dMMhVmpuNu5VWqsraY6kht0qspUh9ZlvxbSTyu7crHLDK7yses1MUHJe+2XjAl5zYa/Fg2c6WESrGp+fJ+UvNlc85wyhubycpzmn9mQGqfKsqmxdOCOSpfkU1k3KQkuY4z3/Es+xwMprxrMZKRXegv65PTPF0yLJvxBi/zxveMyHNKMFd
yqUquwqRHJJv1uBs5N+mVzspfQcUpwXWcmfQKPldyb46TlWCTHpHsOj5T427mxNrBZjwdt/8wM7xO2jO1iU/9/dTKkJpG4GEuOuAJCgAAgJFggQIAAGAkWKAAAAAYSSw0qPM2Yr8ds1WlN87vvf6dKTWnsLQj0li1BuVz6ZZcWF3lWDktktp1XExn5NecmCbFymKIruXcrZxrTlyjkqmwlvfbySbmQuzI7TxLdXRqxtNTcim5HPnAuSGpnUp7cxxMyBpOKsc1KW+/gxlZW8nOPyu1Lzkr6z2TU56eIqboISIqMZdoV0haZLH3NcvKw/dnvTkPDsj6Wna+fE175snaSlJ0He9jXwVJ+dxd8ea02fvqS18kaJoF7kbONSemM0177ckyczNnruSTQsqrGVYyY5qlKxI1KL9budT0l9RQaFCqtj91WfDnmX8ywsJVpLHBXT6aCcVp9PuWX18d8AQFAADASLBAAQAAMBIsUAAAAIwkFhqU7bhN2U7ngus/OqhssVwL8h83eKzP7uyLkxLKGIQcV9SdeMkMvwYllNn2aU5MX1CU1PDHPdWf6shXDls4rk2ytlJ0ZS2iUJZv4wlBP+lPydpEX0qOT0p/7J1vwpJLTvQzXUzUpBL98m+7jLxbSg/L122g5O2rJp8O1Sqs3LdwMVh4GCV65PdZTNWU6JMHJ3pZLFOWpShKCcdNsIArHtskCAm1Ii+Rwcq35IWYI5auqFiUNafJUnCs0yTT5qZYmZVi1TtfX9wTu6FsQYPyxz3JbVWsk09zqvHPnfc6LJappohrDC/bE9zXKj29GcTvMsRBAQAA6DoiXaAeffRRWrNmDfX19dG8efOiPBQAAIAuI1ITX7lcpttuu41Wr15NP/zhDxvfT61GrttcEo4w05sKvVQhYWmRgs10YfsSzZI+k54i67h/bLArOTfp+Ux+zOwluparUhlx/GPrT4NUItnUM1mT22kh9VE6Kff1JJjZi7wURTWWqfqyakFqD9pCaqB5zNTGzGtWr3zcVL+3b7cJk4eVYqY4wTRn8cq2YT8/q8K9WGUmvRnuSu69z9VJef7lSebiLZj1eAbySVtu58vcddz7SuKVk7nruNjmVXB5xnjRjKdyIyciqirMdmGu42oTX3BYiU54CkdHrmjGpNfodyivnqBDpAvUww8/TEREO3fujPIwAAAAuhCjnCRs2ybb9hTkQqGgGA0AAKCbMcpJYmxsjHK53Ozf6Ohop6cEAACgQ2g/QW3btm3WdBfEkSNHaNWqVdqT2bp1K23ZsmW2XSgUaHR0lOyaQzWrfn1jLlrpTamyxYbZg1Xb+jQnbrNW6Er+bWuBfbxkhlgVl2tOYVVyRU1Kp9yGyuWc75eTYL+riuw27hG0iyRzY05azNWaPL2EV0stMQ3kMttLXzRvWq62m8nJ1ynZL+/LygpaESuLa3EXb7Gb93GEfrfKrikvi1FlbaHUaa3EtJQi05mKQkXaKdl1fJqXxRCu+ZRCYyIiKrAyGUVHdB2Xz31GoTOFVSmWNSh2rj5Nin+WhLFcc+KXXNxOM5WZdMxQTar+b7RGI2qa0e1FKk34D2gvUHfffTfdfvvtyjFLly5taDKZTIYymUz4QAAAAF2P9gI1MjJCIyMjUcwFAAAAmCVSJ4nx8XE6c+YMjY+Pk+M4dPz4cSIiuvLKK2lgYEC9MQAAgIuaSBeoBx98kJ566qnZ9m/91m8REdELL7xA1113Xd37qbiOTwtpllbZV4nUqUV0NKewfpWupExB5CuDEbytKs5pzn5lqiOuZ9UfMyXKMPyYFV9clNwuuF7bqnANR9ZPXCGmqlKT9RKePkcs97BwRi7jMT8va1L9rPRFut8732SvrAGy0Cw51knDjcmtsjYL2qmV5H7H9o5TmZH1ttI0uxa2157mZTC4ziRcp0lFeiIi/zUWdScWisWrfEjtCvsoldm5i7qSv0SGWldyFCmJVGVudHSk0LLt6l0p59QozaSDE6m6jX93R+r
Ft3PnTnJd1/enszgBAAC4ODHKzRwAAAC4gFGBukHYVCWHuHtwtDT6mBzmPq1zDB2znWrffvfv4Eq3/uzl6qq4rso8qDDjhbuZe/1W2O8oize9eyXBXMfdMk8f5ZmrKixlks3ak1XPrHeGmbVGmMlveFI28Q1kvDRJ2bR8jdOsnUp5555I1p8d32WVYqvMnFapyOdTFly8Z1hG+JmqfH5TFTEFkTyWm+lUWcb9ruNSU3IXL3OTHrOvif3cNbyqMNtxk55vrLLybXCFASLZFKdKVcbRNfer0MlmrkOj34k8rEUHPEEBAAAwEixQAAAAjAQLFAAAACOJhQZVoarkDlwvXC9pFc3sV6Ud6WpS0n4VWhF3FVe5g+ukK+L9qj7ffHXSnzCNSZUGiY/3XReefkbwT67UZPdpnuposuL9njuXlj86p8vy2PksxdJgyptHf0q2yfel5DmmE4JLeoKV9bC4BiKkOvJVgk0o2yUnOedrIqISGzst6EqlWnDZi/P7EstgEOuT277qtoIPOO/jOpPY5iUyVLoS15h8JTM0ytyEpSdTja23r55+FTq6eBRAgwIAANB1YIECAABgJFigAAAAGEksNKiyVSYnpOpAlDSlOWmVQQ+JbVLoVyo7s04sk46OxMfzsVxn0ikJL+9I3i7B9EhlTFjIfVMV8gNVnF6pzy7JmtSM48UGTVXk33Zny/KBJnrkOfanvHZ/Uo4xyrJYp4ygOyWZ5sTbIrxkPU/pU2H9FUFLKjNdyebakDRW7vPpSkJbFbt0fg7B2pG/NDtLu6XQlXw6E4lj1am/WlXmhqPWoPS+Y1qlr4fFU7YCh8rhgwLAExQAAAAjwQIFAADASLBAAQAAMJJYaFAVsltaHqMeVHqPDs3EIKg0mzAbtE5pi0bHEsk6k+62cl/wthbTnBy2H8tlv7OkOCimN7D3tSrkHnSY1lVx5fx65bIX28Tz0fWxOKi+lDynrFBCoy8p6z1p3hY2ZV2+tgr+ifGXmfBe+/We4DaPOfKPFeKT3LCxjetKopZU9emqTNtSlKrxx/qpS9mo+pqJc6z3mLq0ulSRLtCgAAAAdB1YoAAAABhJPEx8VplqCvfaKIgqTZJeOY76H8113MxV24alINIpoaEyLdR0Uh2x/SZYRV1+HPEc/CY+eayYNsmxKoF9RERl12tnanIqo+ma/FHKMhNgJuGZANNJef497GdiyrKEPtmml+ClRXRMfrwarND2m8+ItQXTW0gKIpWZrsred55WSDTbcVMuN5GJZjsdM12YWU7HxKe8x3U+601ICp1OZRQGTHwAAAC6DixQAAAAjAQLFAAAACOJhQZVJTsyTageorLxNpz6h8LnpFPOohO6UjOur9wdnLuhi5qVyzSnhCunGRL1B36ujiWXCahYni297LLSHCS3p1l/j5MSXsvz7bF42/vdmGQiU4LlbtLRoDiqUuY+vYpU7t/BGo6OjkTEQx74WPn9EN+7MN1IRytqNKXY+X21LrWZCtN1J5HQ8jgK8AQFAADASLBAAQAAMBIsUAAAAIwkNhpUIiY2V61S5iFEpVFplQAJOR/Vvlp5LdRzYKmPBA2H61WqUiPcVp60ZL1K7K+wviT7KPVQhvV741NMM0u68rYp1+tPsrFcg+JtESuk1ohOeXJVyQlVDJJP3/HpfMHvD7+3lCmHrOD39fy2rUkb5t+2NbGKujTz3dBuatCgAAAAdBtYoAAAABhJLEx8jlttKJe56a6YrcwyHJWLt46ZTud6u250Jgq/27nYF+x+zLfjpglHMNPxqr4JS/4oVSxb7ifRbCeP5abEhJCdPcF+Q/I5iv0qc58uzVR/Fa9pmAt3qypBh21b73Zh2+ruSxrbJpO3aea/mlsNHxQAnqAAAAAYCRYoAAAARoIFCgAAgJHEQoOqujYlDLOrBtGu6pXN2LOb0eZ0tKOOpacSrg3XcPj7I2o6fn1K3lY8H17yw3JlbahKwRoU16u4niW6yfv62Jz4+ang+1L
RqpQ9Ou7ec42vd9t2pf7SPW79+4nH91u9iJo4NCgAAABdBxYoAAAARoIFCgAAgJFEpkG9++679Fd/9Vf0X//1X3Tq1ClatGgR/emf/ik98MADlE6nw3cgUHMrRA1FQrUGE+OpWhlH1FTa/xbFdjRzjbkuI8LLPfi3De4X45GIiCyhLEbN1xcSryT0c73KPyeFBmUFn6vqOrQTZZqtFmpBWjFILUzv1arjREW74q3qpRkNKrIF6vXXX6darUZPPPEEXXnllfTf//3f9M1vfpOKxSI99thjUR0WAABAl2C5Li9NFh1/+7d/S9u3b6e33367rvGFQoFyuRyNDH7B5/nUTvAEpZqH2U9Qodsqnkh8WRws0eOviSeoEG86PEGdB09QDc7BwCeoicmXKJ/P09DQkNa2bf3Wz+fzNDw8HNhv2zbZtueeWygUiOjCI2Jz62gnK/JGQVQ3YSsX4yjTGUnH0fhSsFj1WvE68i94bh5s1BxIROQIt2/Y4iYfMzi1URhhC2FUmLgYdCoso+Fjtumz0w5qTZxL235yvfXWW/T444/Txo0bA8eMjY1RLpeb/RsdHW3X9AAAABiG9gK1bds2sixL+Xf06FFpm5MnT9KNN95It912G911112B+966dSvl8/nZvxMnTuifEQAAgK5AW4OamJigiYkJ5ZilS5dSNpslovOL09q1a+maa66hnTt3UiJR/5p4QYMaHvg8JazmzBUw8dW53xia+HTgpjepL+T3WqN6VdhxtEx8irFh27YLmPiax8TPTqPUXIfOTP2qPRrUyMgIjYyM1DX2/fffp7Vr19LKlStpx44dWouTSM2tKqt/XsBEZ4Z20Ta9xzABVhfV/FULBZH6C5+nRSLmWqtc3FzFAqXQsvzzM8NJQod2pd1S0a4fr3H/7DRKM+9TZE4SJ0+epOuuu46WLFlCjz32GJ0+fXq277LLLovqsAAAALqEyBaoffv20ZtvvklvvvkmLV68WOpro2c7AACAmBKZTeDrX/86ua475x8AAAAQRizKbdTcijIGpREuJnuwKdqcCddcpQWFzc/SmL/vOGIclMVLy6uOqeEUoeFAYSpxcP6p+5hd5OjQDM1ch/jf0QAAALoSLFAAAACMBAsUAAAAI4mJBlUlq4PlNprBBN2lGeI+f04zcVAqHYNvq7xuPEZKJ7+eao6aHxFVMHEzmKC9mHjfmqIFtxtoUAAAALoOLFAAAACMJB4mvppDlhVPE1+7MNGkETd0rqGWSS9sWx3TD/sYNOVa3kSlU9Potvu/m86nmXPBExQAAAAjwQIFAADASLBAAQAAMJJ4aFBuJZalBC424mg3b1TD0XWdFUt1tPI6dUN6o3YQx3uzW4AGBQAAoOvAAgUAAMBIsEABAAAwklhoULAfNweuXzDNxC9pHafBNDeh6ZcMSCt0MYPPVjjQoAAAAHQdWKAAAAAYCUx8LeBizVJ8saFlDmzRb78o730TXdRN/6wDfWDiAwAA0HVggQIAAGAkWKAAAAAYSSw0qPMVdbGWgvjQjC7ZLm0Iek98iPN7BQ0KAABA14EFCgAAgJFggQIAAGAksdCgEGcELiYatdlDp60PfJ+0l2auN+5oAAAARoIFCgAAgJHEwsRHMXaxBKBdwHQFjARu5gAAALoNLFAAAACMBAsUAAAAI4mFBuWSSwT7OgAAxI7z39+NgScoAAAARhLpAvXlL3+ZlixZQtlsli6//HLasGEDnTx5MspDAgAA6BIiXaDWrl1LP/nJT+iNN96gZ599lt566y36yle+EuUhAQAAdAmW67qNGwg1+fnPf0633HIL2bZNPT09oeMLhQLlcjmyrCGyLKsNMwQAANBKXNcl1y1QPp+noaEhrW3b5iRx5swZ+tGPfkRr1qwJXJxs2ybbtmfbhUKhXdMDAABgGJE7Sdx3333U399PCxYsoPHxcdq7d2/g2LGxMcrlcrN/o6OjUU8PAACAoWgvUNu2bSPLspR/R48enR3/ne98h44
dO0b79u2jZDJJX/va1yjIqrh161bK5/OzfydOnGj8zAAAAMQabQ1qYmKCJiYmlGOWLl1K2WzW9//33nuPRkdH6dChQ7R69erQY3ka1AA0KAAAiCHnNaip9mhQIyMjNDIyorsZEdHsk5OoMwEAAABzEZmTxEsvvUQvvfQSXXvttTR//nx6++236cEHH6Tly5fX9fQEAADg4iayBaq3t5d+9rOf0UMPPUTFYpEuv/xyuvHGG2nXrl2UyWS09nW+wihMfAAAEDeaiWRqaxyULhc0KKI+aFAAABBDzi8x0w1pUMjFBwAAwEiwQAEAADCSWJTbOF9qAyY+ANTg92b7QRmgcFBuAwAAQJeBBQoAAICRYIECAABgJDHRoAAA4UAPAd0FnqAAAAAYCRYoAAAARhITE1+NyIWbOQAAxA+4mQMAAOgysEABAAAwEixQAAAAjCQmGpTbhBUTAABA54AGBQAAoMvAAgUAAMBIsEABAAAwEixQAAAAjAQLFAAAACPBAgUAAMBIYuJmjoq6AAAQT+BmDgAAoMvAAgUAAMBIsEABAAAwEixQAAAAjAQLFAAAACPBAgUAAMBIsEABAAAwkpjEQaHYBgAAxBPEQQEAAOgysEABAAAwEixQAAAAjAQLFAAAACPBAgUAAMBIsEABAAAwEixQAAAAjKQtC5Rt2/S5z32OLMui48ePt+OQAAAAYk5bFqg///M/p0WLFrXjUAAAALqEyDNJPPfcc7Rv3z569tln6bnnnlOOtW2bbNuebefz+V+/QiYJAACIJ+e/v11X/3s80gXqww8/pG9+85u0Z88e6uvrCx0/NjZGDz/88Bw9LmGRAgCA+PLxxx9TLpfT2sZyG1nW6sB1XfqjP/oj+p3f+R36i7/4C3r33Xdp2bJldOzYMfrc5z435zb8CercuXN0xRVX0Pj4uPaJdZJCoUCjo6N04sQJGhoa6vR0tIjr3DHv9oJ5t5+4zj2fz9OSJUvo7NmzNG/ePK1ttZ+gtm3bFvCU43HkyBE6dOgQFQoF2rp1a937zmQylMlkfP/P5XKxekMuMDQ0FMt5E8V37ph3e8G8209c555I6Ls8aC9Qd999N91+++3KMUuXLqVHHnmEDh8+7FtwVq1aRV/96lfpqaee0j00AACAiwjtBWpkZIRGRkZCx/3TP/0TPfLII7PtkydP0h/+4R/SM888Q9dcc43uYQEAAFxkROYksWTJEqk9MDBARETLly+nxYsX17WPTCZDDz300JxmP5OJ67yJ4jt3zLu9YN7tJ65zb2bekTlJcOpxkgAAAAAu0LYFCgAAANABufgAAAAYCRYoAAAARoIFCgAAgJFggQIAAGAksVug4li648tf/jItWbKEstksXX755bRhwwY6efJkp6el5N1336VvfOMbtGzZMurt7aXly5fTQw89ROVyudNTC+XRRx+lNWvWUF9fn3ZqlXby/e9/n5YtW0bZbJZWrlxJv/zlLzs9pVBefPFFuummm2jRokVkWRbt2bOn01Oqi7GxMfrt3/5tGhwcpEsvvZRuueUWeuONNzo9rVC2b99OK1asmM0esXr16tCk2yYyNjZGlmXR5s2btbaL3QIVx9Ida9eupZ/85Cf0xhtv0LPPPktvvfUWfeUrX+n0tJS8/vrrVKvV6IknnqDXXnuN/v7v/57+9V//le6///5OTy2UcrlMt912G33rW9/q9FQCeeaZZ2jz5s30wAMP0LFjx+iLX/wirV+/nsbHxzs9NSXFYpE++9nP0j//8z93eipaHDx4kDZt2kSHDx+m/fv3U7VapXXr1lGxWOz01JQsXryYvve979HRo0fp6NGj9Pu///t0880302uvvdbpqdXNkSNH6Mknn6QVK1bob+zGiF/84hfuVVdd5b722msuEbnHjh3r9JQaYu/eva5lWW65XO70VLT4m7/5G3fZsmWdnkbd7Nixw83lcp2expx84QtfcDdu3Cj976qrrnK/+93vdmhG+hCRu3v37k5PoyE
++ugjl4jcgwcPdnoq2syfP9/9wQ9+0Olp1MXk5KT7yU9+0t2/f7/7e7/3e+69996rtX1snqAulO7493//97pKd5jKmTNn6Ec/+hGtWbOGenp6Oj0dLfL5PA0PD3d6GrGnXC7Tyy+/TOvWrZP+v27dOjp06FCHZnVxcaHWXJzuZ8dxaNeuXVQsFmn16tWdnk5dbNq0ib70pS/RDTfc0ND2sVigXNelr3/967Rx40ZatWpVp6fTEPfddx/19/fTggULaHx8nPbu3dvpKWnx1ltv0eOPP04bN27s9FRiz8TEBDmOQwsXLpT+v3DhQjp16lSHZnXx4Loubdmyha699lq6+uqrOz2dUF599VUaGBigTCZDGzdupN27d9NnPvOZTk8rlF27dtGvfvUrGhsba3gfHV2gtm3bRpZlKf+OHj1Kjz/+uHbpjqipd+4X+M53vkPHjh2jffv2UTKZpK997WsNVZhs97yJzif6vfHGG+m2226ju+66q+1zbnTepmNZltR2Xdf3P9B67r77bnrllVfoxz/+caenUhef+tSn6Pjx43T48GH61re+RXfccQf9z//8T6enpeTEiRN077330n/8x39QNptteD8dTXU0MTFBExMTyjFLly6l22+/nf7zP/9T+vA6jkPJZLJjpTvqnftcb857771Ho6OjdOjQobY/quvO++TJk7R27Vq65ppraOfOnQ3VdGkFjVzvnTt30ubNm+ncuXMRz06PcrlMfX199NOf/pRuvfXW2f/fe++9dPz4cTp48GAHZ1c/lmXR7t276ZZbbun0VOrmnnvuoT179tCLL75Iy5Yt6/R0GuKGG26g5cuX0xNPPNHpqQSyZ88euvXWWymZTM7+z3EcsiyLEokE2bYt9QURacn3MOJcuqPeuc/Fhd8EYvXgdqEz7/fff5/Wrl1LK1eupB07dnRscSJq7nqbRjqdppUrV9L+/fulBWr//v108803d3Bm3YvrunTPPffQ7t276cCBA7FdnIjOn0snvjt0uP766+nVV1+V/nfnnXfSVVddRffdd19dixNRhxeoemlF6Y5O8dJLL9FLL71E1157Lc2fP5/efvttevDBB2n58uVGC50nT56k6667jpYsWUKPPfYYnT59erbvsssu6+DMwhkfH6czZ87Q+Pg4OY4zGy935ZVXzt47nWbLli20YcMGWrVqFa1evZqefPJJGh8fN17jm5qaojfffHO2/c4779Dx48dpeHjY9zk1iU2bNtHTTz9Ne/fupcHBwVmtL5fLUW9vb4dnF8z9999P69evp9HRUZqcnKRdu3bRgQMH6Pnnn+/01JQMDg769L0LGryW7tdSn8I28c4778TGzfyVV15x165d6w4PD7uZTMZdunSpu3HjRve9997r9NSU7NixwyWiOf9M54477phz3i+88EKnpybxL//yL+4VV1zhptNp9/Of/3wsXJ5feOGFOa/tHXfc0empKQm6l3fs2NHpqSn5sz/7s9l75JJLLnGvv/56d9++fZ2eVkM04maOchsAAACMJBZu5gAAAC4+sEABAAAwEixQAAAAjAQLFAAAACPBAgUAAMBIsEABAAAwEixQAAAAjAQLFAAAACPBAgUAAMBIsEABAAAwEixQAAAAjOT/A7FI7ZkT6ddTAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
    "# Save every 5th time slice of u_gt (k = 0, 5, ..., 95) plus the final slice as snaps/snap0 ... snaps/snap20.\n",
    "os.makedirs('snaps', exist_ok=True)  # savefig does not create missing directories\n",
    "frame_ids = list(range(0, 100, 5)) + [-1]\n",
    "fig, ax = plt.subplots()\n",
    "im = ax.imshow(u_gt[frame_ids[0], :, :].T, cmap='magma', extent=[-4, 4, -4, 4])\n",
    "for i, k in enumerate(frame_ids):\n",
    "    # Update one image artist instead of stacking a new imshow on the same axes every frame.\n",
    "    im.set_data(u_gt[k, :, :].T)\n",
    "    im.autoscale()  # per-frame colour limits, matching what a fresh imshow would pick\n",
    "    fig.savefig('snaps/snap{}'.format(i))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01b762ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `camera` (celluloid.Camera) was never created, so `camera.animate()` raised NameError.\n",
    "# Build the animation directly with matplotlib from the same frames saved above.\n",
    "from matplotlib import animation as mpl_animation\n",
    "\n",
    "anim_fig, anim_ax = plt.subplots()\n",
    "anim_im = anim_ax.imshow(u_gt[frame_ids[0], :, :].T, cmap='magma', extent=[-4, 4, -4, 4])\n",
    "\n",
    "def draw_frame(k):\n",
    "    # Show time slice k of u_gt in the shared image artist, rescaling colours per frame.\n",
    "    anim_im.set_data(u_gt[k, :, :].T)\n",
    "    anim_im.autoscale()\n",
    "    return (anim_im,)\n",
    "\n",
    "animation = mpl_animation.FuncAnimation(anim_fig, draw_frame, frames=frame_ids)\n",
    "plt.close(anim_fig)  # avoid rendering an extra static figure under the animation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a04f0374",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import HTML\n",
    "\n",
    "# A bare FuncAnimation only shows its repr; render it as an interactive JS player instead.\n",
    "HTML(animation.to_jshtml())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5605bc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PillowWriter needs only Pillow; the 'imagemagick' writer requires an external ImageMagick install.\n",
    "animation.save('flow.gif', writer=mpl_animation.PillowWriter(fps=30))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jaxdf",
   "language": "python",
   "name": "jaxdf"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}