Skip to content

Instantly share code, notes, and snippets.

@romanodev
Created July 17, 2020 04:52
Show Gist options
  • Save romanodev/be02bd4b7e90c5ebb3dc84ebebf4e76f to your computer and use it in GitHub Desktop.
Save romanodev/be02bd4b7e90c5ebb3dc84ebebf4e76f to your computer and use it in GitHub Desktop.
JAX GMRES while_loop
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX GMRES while_loop",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/romanodev/be02bd4b7e90c5ebb3dc84ebebf4e76f/jax-gmres-loops.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q4LEr98cuYQf",
"colab_type": "text"
},
"source": [
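"# Simple JAX GMRES\n",
"\n",
"A minimal GMRES solver built on `lax.while_loop`, checked against `scipy.sparse.linalg.gmres`.\n",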
"\n",
"Adapted from: https://gist.github.com/shoyer/dc33a5850337b6a87d48ed97b4727d29 (author: shoyer@google.com)\n",
"\n",
"Date: July 17, 2020\n",
"\n",
"Modifications by: romanog@mit.edu (July 12, 2020)\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "SR3HPqI2q8Th",
"colab_type": "code",
"colab": {}
},
"source": [
"# Copyright 2020 Google LLC.\n",
"# SPDX-License-Identifier: Apache-2.0\n",
"import numpy as np\n",
"import functools\n",
"from jax import random\n",
"from jax import lax\n",
"import jax.numpy as jnp\n",
"import jax.ops\n",
"import jax.scipy as jsp\n",
"from jax.tree_util import Partial\n",
"import scipy.sparse.linalg\n",
"from jax.experimental import loops\n",
"from jax.experimental import host_callback as hcb\n",
"\n",
"def _identity(x):\n",
"  \"\"\"Identity map; the default (no-op) preconditioner `M`.\"\"\"\n",
"  return x\n",
"\n",
"def _dot(*args, **kwargs):\n",
"  \"\"\"`jnp.dot` with `precision` defaulting to `lax.Precision.HIGHEST`.\n",
"\n",
"  A caller may still pass `precision` explicitly to override the default.\n",
"  Full precision is requested because accelerator matmuls may otherwise run\n",
"  at reduced precision, which hurts Gram-Schmidt orthogonalization.\n",
"  \"\"\"\n",
"  kwargs.setdefault('precision', lax.Precision.HIGHEST)\n",
"  return jnp.dot(*args, **kwargs)\n",
"\n",
"\n",
"@jax.jit\n",
"def lstsq(a, b):\n",
"  \"\"\"Jit-compiled `jnp.linalg.lstsq(a, b)`.\n",
"\n",
"  Returns the same 4-tuple: (solution, residuals, rank, singular_values).\n",
"  \"\"\"\n",
"  return jnp.linalg.lstsq(a, b)\n",
"\n",
"\n",
"def _gmres3(A, b, x0, n, M=_identity, tol=1e-5, atol=1e-5):\n",
"  \"\"\"Run at most `n` right-preconditioned GMRES iterations for A x = b.\n",
"\n",
"  Args:\n",
"    A: callable computing the matrix-vector product A @ v.\n",
"    b: right-hand side vector (1-D).\n",
"    x0: initial guess, same shape as `b`.\n",
"    n: maximum Krylov dimension, i.e. maximum number of Arnoldi steps.\n",
"    M: callable applying the preconditioner on the right (the Krylov space\n",
"      is built from A(M(q))); identity by default.\n",
"    tol, atol: iteration stops once ||b - A x|| <= max(tol * ||b||, atol),\n",
"      following the non-legacy tolerance handling of scipy.sparse.linalg.cg.\n",
"\n",
"  Returns:\n",
"    The approximate solution x = x0 + M(Q_n y), same shape as `b`.\n",
"  \"\"\"\n",
"  r0 = b - A(x0)\n",
"  beta = jnp.linalg.norm(r0)\n",
"  # The Krylov basis starts from the normalized initial residual so that\n",
"  # beta * e1 is exactly r0 in that basis (starting from b is only correct\n",
"  # when x0 == 0). The where() avoids 0/0 when r0 == 0.\n",
"  q0 = r0 / jnp.where(beta > 0, beta, 1)\n",
"  m = b.shape[0]\n",
"  Q = jnp.zeros((m, n + 1), dtype=q0.dtype).at[:, 0].set(q0)\n",
"  H = jnp.zeros((n, n + 1), dtype=q0.dtype)\n",
"  beta_e1 = beta * jnp.zeros((n + 1,), dtype=q0.dtype).at[0].set(1)\n",
"  threshold = jnp.maximum(tol * jnp.linalg.norm(b), atol)\n",
"\n",
"  def cond_fun(carry):\n",
"    _, _, _, residual, k = carry\n",
"    return (residual > threshold) & (k < n)\n",
"\n",
"  def body_fun(carry):\n",
"    _, Q, H, _, k = carry\n",
"    # Arnoldi step: new (right-preconditioned) Krylov vector, orthogonalized\n",
"    # against the current basis with one classical Gram-Schmidt pass.\n",
"    v = A(M(Q[:, k]))\n",
"    h = _dot(Q.T.conj(), v)\n",
"    v = v - _dot(Q, h)\n",
"    v_norm = jnp.linalg.norm(v)\n",
"    # v_norm == 0 means (lucky) breakdown; avoid 0/0 so NaNs cannot enter Q.\n",
"    Q = Q.at[:, k + 1].set(v / jnp.where(v_norm > 0, v_norm, 1))\n",
"    H = H.at[k, :].set(h.at[k + 1].set(v_norm))\n",
"    # H.T is the (n+1) x n Hessenberg matrix; columns for future steps are\n",
"    # still zero, and the minimum-norm least-squares solution keeps their\n",
"    # coefficients at zero.\n",
"    y = lstsq(H.T, beta_e1)[0]\n",
"    # In exact arithmetic ||beta*e1 - H.T y|| equals ||b - A x||. It is a\n",
"    # norm, unlike lstsq's residual output (a sum of squares), so it can be\n",
"    # compared directly with the tolerance.\n",
"    residual = jnp.linalg.norm(beta_e1 - _dot(H.T, y))\n",
"    x = x0 + M(_dot(Q[:, :-1], y))\n",
"    return x, Q, H, residual, k + 1\n",
"\n",
"  # Start from x0 with the true initial residual norm: if x0 already meets\n",
"  # the tolerance, no iteration runs and x0 is returned.\n",
"  x, *_ = lax.while_loop(cond_fun, body_fun, (x0, Q, H, beta, 0))\n",
"  return x\n",
"\n",
"\n",
"\n",
"def _gmres2(A, b, x0, n, M):\n",
"  \"\"\"Run exactly `n` right-preconditioned GMRES iterations (no early exit).\n",
"\n",
"  Fixed-iteration counterpart of `_gmres3`, written with `lax.fori_loop` so\n",
"  that the basis `Q` and Hessenberg matrix `H` are carried as loop state.\n",
"  A traced loop index cannot be used to slice `H[0:k+1]`, so the full\n",
"  zero-padded `H` is passed to the least-squares solve instead.\n",
"\n",
"  Args:\n",
"    A: callable computing A @ v.\n",
"    b: right-hand side vector (1-D).\n",
"    x0: initial guess, same shape as `b`.\n",
"    n: number of Arnoldi steps (Krylov dimension).\n",
"    M: callable applying the right preconditioner.\n",
"\n",
"  Returns:\n",
"    x0 + M(Q_n y), the GMRES iterate after `n` steps.\n",
"  \"\"\"\n",
"  r0 = b - A(x0)\n",
"  beta = jnp.linalg.norm(r0)\n",
"  # Start the Krylov basis from the initial residual so it agrees with\n",
"  # beta * e1; the where() avoids 0/0 when r0 == 0.\n",
"  q0 = r0 / jnp.where(beta > 0, beta, 1)\n",
"  m = b.shape[0]\n",
"  Q = jnp.zeros((m, n + 1), dtype=q0.dtype).at[:, 0].set(q0)\n",
"  H = jnp.zeros((n, n + 1), dtype=q0.dtype)\n",
"  beta_e1 = beta * jnp.zeros((n + 1,), dtype=q0.dtype).at[0].set(1)\n",
"\n",
"  def body_fun(k, carry):\n",
"    _, Q, H = carry\n",
"    v = A(M(Q[:, k]))\n",
"    h = _dot(Q.T.conj(), v)\n",
"    v = v - _dot(Q, h)\n",
"    v_norm = jnp.linalg.norm(v)\n",
"    # Avoid 0/0 on breakdown (v_norm == 0) so NaNs cannot enter Q.\n",
"    Q = Q.at[:, k + 1].set(v / jnp.where(v_norm > 0, v_norm, 1))\n",
"    H = H.at[k, :].set(h.at[k + 1].set(v_norm))\n",
"    # Unfilled Hessenberg columns are zero; the minimum-norm least-squares\n",
"    # solution keeps their coefficients at zero.\n",
"    y = lstsq(H.T, beta_e1)[0]\n",
"    x = x0 + M(_dot(Q[:, :-1], y))\n",
"    return x, Q, H\n",
"\n",
"  x, _, _ = lax.fori_loop(0, n, body_fun, (x0, Q, H))\n",
"  return x\n",
" \n",
"\n",
"def gmres(A, b, x0=None, n=5, M=None, tol=1e-5, atol=1e-5):\n",
"  \"\"\"Solve A x = b with at most `n` steps of unrestarted GMRES.\n",
"\n",
"  Args:\n",
"    A: callable computing A @ v, e.g. functools.partial(jnp.dot, A_matrix).\n",
"    b: right-hand side vector.\n",
"    x0: initial guess; defaults to zeros shaped like `b`.\n",
"    n: maximum number of iterations (Krylov subspace dimension).\n",
"    M: callable applying the right preconditioner; defaults to the identity.\n",
"    tol, atol: relative and absolute residual tolerances, forwarded to\n",
"      `_gmres3`.\n",
"\n",
"  Returns:\n",
"    The approximate solution, same shape as `b`.\n",
"  \"\"\"\n",
"  if x0 is None:\n",
"    x0 = jnp.zeros_like(b)\n",
"  if M is None:\n",
"    M = _identity\n",
"  return _gmres3(A, b, x0, n, M, tol=tol, atol=atol)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "j5REsLhwxfsc",
"colab_type": "text"
},
"source": [
"## Tests"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SVgN9XUExrtj",
"colab_type": "text"
},
"source": [
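"Verify correctness by comparing against `scipy.sparse.linalg.gmres` run with the same Krylov dimension and a single restart cycle:"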
]
},
{
"cell_type": "code",
"metadata": {
"id": "pcIJkLIKX7cK",
"colab_type": "code",
"colab": {}
},
"source": [
"# Random dense test system A x = b; fixed PRNG seeds keep the check reproducible.\n",
"A = random.normal(key=random.PRNGKey(0), shape=(100, 100))\n",
"b = random.normal(key=random.PRNGKey(1), shape=(100,))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "dtJZvMbtq9nH",
"colab_type": "code",
"colab": {}
},
"source": [
"# SciPy documents its gmres as left-preconditioned, while `gmres` above applies\n",
"# M on the right, so only M = I is expected to match SciPy exactly. A Jacobi\n",
"# preconditioner would be M = jnp.diag(1 / jnp.diag(A)).\n",
"M = jnp.eye(A.shape[0])  # identity: no preconditioning\n",
"RESTART = 30  # Krylov dimension used by both solvers\n",
"\n",
"np.testing.assert_allclose(\n",
"    gmres(functools.partial(jnp.dot, A), b, n=RESTART, M=functools.partial(jnp.dot, M)),\n",
"    # maxiter=1: a single restart cycle of RESTART inner iterations.\n",
"    scipy.sparse.linalg.gmres(np.array(A), np.array(b), restart=RESTART, maxiter=1, M=np.asarray(M))[0],\n",
"    atol=1e-5, rtol=1e-4,\n",
")"
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment