@shoyer
Created July 7, 2020 21:11
Simple JAX GMRES
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Simple JAX GMRES",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shoyer/cbac2cf8c8675b2f3a45e4837e3bed80/simple-jax-gmres.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q4LEr98cuYQf",
"colab_type": "text"
},
"source": [
"# Simple JAX GMRES\n",
"\n",
"Author: shoyer@google.com\n",
"\n",
"Date: July 7, 2020\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "SR3HPqI2q8Th",
"colab_type": "code",
"colab": {}
},
"source": [
"# Copyright 2020 Google LLC.\n",
"# SPDX-License-Identifier: Apache-2.0\n",
"import functools\n",
"\n",
"import jax\n",
"from jax import lax\n",
"from jax import random\n",
"import jax.numpy as jnp\n",
"import jax.scipy as jsp\n",
"import numpy as np\n",
"import scipy.sparse.linalg\n",
"\n",
"def _identity(x):\n",
"  return x\n",
"\n",
"def _inner(v, q):\n",
"  # One modified Gram-Schmidt step: remove the component of v along q.\n",
"  h_jk = q.conj() @ v\n",
"  v = v - h_jk * q\n",
"  return (v, h_jk)\n",
"\n",
"def _outer(A, M, Q, k):\n",
"  # One Arnoldi step: expand the Krylov basis with A(M(q_k)).\n",
"  q = Q[:, k]\n",
"  v = A(M(q))\n",
"  # Columns of Q beyond k are still zero, so they contribute nothing here.\n",
"  # TODO: maybe better to use a masked dot-product rather than scan?\n",
"  v, h_col = lax.scan(_inner, v, Q.T)\n",
"  v_norm = jnp.linalg.norm(v)\n",
"  Q = Q.at[:, k+1].set(v / v_norm)\n",
"  h_col = h_col.at[k+1].set(v_norm)\n",
"  return Q, h_col\n",
"\n",
"def arnoldi_iteration(A, b, n, M=None):\n",
"  # https://en.wikipedia.org/wiki/Arnoldi_iteration#The_Arnoldi_iteration\n",
"  # Returns Q with shape (m, n+1) and the Hessenberg matrix H with shape (n+1, n).\n",
"  if M is None:\n",
"    M = _identity\n",
"  m = b.shape[0]\n",
"  q = b / jnp.linalg.norm(b)\n",
"  Q = jnp.concatenate([q[:, jnp.newaxis], jnp.zeros((m, n))], axis=1)\n",
"  Q, h = lax.scan(functools.partial(_outer, A, M), Q, np.arange(n))\n",
"  return Q, h.T\n",
"\n",
"@jax.jit\n",
"def lstsq(a, b):\n",
"  # Solve the least-squares problem via the normal equations.\n",
"  return jsp.linalg.solve(a.T @ a, a.T @ b, assume_a='pos')\n",
"\n",
"def _gmres(A, b, x0, n, M):\n",
"  # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf\n",
"  # Build the Krylov basis from the initial residual, which equals b when x0 = 0.\n",
"  r0 = b - A(x0)\n",
"  Q, H = arnoldi_iteration(A, r0, n, M)\n",
"  beta = jnp.linalg.norm(r0)\n",
"  e1 = jnp.concatenate([jnp.ones((1,)), jnp.zeros((n,))])\n",
"  y = lstsq(H, beta * e1)\n",
"  x = x0 + M(Q[:, :-1] @ y)\n",
"  return x\n",
"\n",
"def gmres(A, b, x0=None, n=5, M=None):\n",
"  if x0 is None:\n",
"    x0 = jnp.zeros_like(b)\n",
"  if M is None:\n",
"    M = _identity\n",
"  return _gmres(A, b, x0, n, M)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "j5REsLhwxfsc",
"colab_type": "text"
},
"source": [
"## Tests"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SVgN9XUExrtj",
"colab_type": "text"
},
"source": [
"Verify correctness:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "pcIJkLIKX7cK",
"colab_type": "code",
"colab": {}
},
"source": [
"A = random.normal(random.PRNGKey(0), (100, 100))\n",
"b = random.normal(random.PRNGKey(1), (100,))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3CEV8LejZXXC",
"colab_type": "code",
"colab": {}
},
"source": [
"np.testing.assert_allclose(\n",
" gmres(functools.partial(jnp.dot, A), b, n=20),\n",
" scipy.sparse.linalg.gmres(np.array(A), np.array(b), restart=20, maxiter=1)[0],\n",
" atol=1e-6,\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "IL39G9FZueDG",
"colab_type": "text"
},
"source": [
"Verify we can calculate gradients through a fixed number of loops.\n",
"\n",
"(Note that if you're running GMRES to convergence, there's a better way to calculate gradients via the [adjoint rule](https://dolfin-adjoint-doc.readthedocs.io/en/latest/documentation/maths/3-gradients.html#the-adjoint-approach).)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "5YbP5hrVtjYZ",
"colab_type": "code",
"colab": {}
},
"source": [
"@jax.grad\n",
"def loss(A, b):\n",
" return jnp.sum(gmres(functools.partial(jnp.dot, A), b))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BiDj1N3BuJo4",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"outputId": "2d1313ac-33ba-476b-cffa-7ab1ead60a5b"
},
"source": [
"loss(A, b)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([[-0.00888863, -0.01108986, -0.01395201, ..., -0.01434983,\n",
" -0.00233695, 0.0087676 ],\n",
" [ 0.0068522 , 0.00968967, 0.00116034, ..., -0.0108919 ,\n",
" -0.00220353, 0.01377204],\n",
" [-0.00557137, -0.00477795, -0.01392099, ..., -0.01569235,\n",
" -0.00254974, 0.01301789],\n",
" ...,\n",
" [-0.00446858, -0.00590282, -0.00807489, ..., -0.01217442,\n",
" -0.00532267, 0.01113929],\n",
" [ 0.00431957, 0.00333034, 0.00053749, ..., 0.00552948,\n",
" 0.00076819, -0.0026694 ],\n",
" [ 0.00614916, 0.00756274, 0.00051342, ..., -0.00826862,\n",
" -0.00276195, 0.01154379]], dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2fzwBVr9xhJ_",
"colab_type": "text"
},
"source": [
"## Performance"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9XGa2EY3u73b",
"colab_type": "text"
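},
"source": [
"Despite our naive implementation, out-of-the-box performance on CPU beats SciPy by about 2x (649 µs vs. 1.46 ms in the runs below). A single solve of this size is much slower on GPU (29.9 ms):"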
]
},
{
"cell_type": "code",
"metadata": {
"id": "4u1SmaDKFbHz",
"colab_type": "code",
"colab": {}
},
"source": [
"@functools.partial(jax.jit, static_argnums=(2,))\n",
"def explicit_gmres(A, b, n):\n",
" return gmres(functools.partial(jnp.dot, A), b, n=n)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XEuNmYR4r5-H",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"outputId": "2461be19-56dc-4972-ff99-7a9235b42163"
},
"source": [
"# CPU\n",
"%timeit explicit_gmres(A, b, 30).block_until_ready()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"The slowest run took 1690.02 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1 loop, best of 3: 649 µs per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "VnrL74Wv1VXG",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "6f68428e-55f0-460d-d0de-9cded4430ae3"
},
"source": [
"# GPU\n",
"%timeit explicit_gmres(A, b, 30).block_until_ready()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"10 loops, best of 3: 29.9 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "DeVyIRQKr9dX",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "d6c86676-3677-419b-fc6b-e6875cda4970"
},
"source": [
"b2 = np.asarray(b)\n",
"A2 = np.asarray(A)\n",
"%timeit scipy.sparse.linalg.gmres(A2, b2, restart=30, maxiter=1)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"1000 loops, best of 3: 1.46 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GbyGUXO9xeps",
"colab_type": "text"
},
"source": [
"We can also `vmap` it! Batching 1000 solves gives us a big speed-up on GPUs (31.2 ms vs. 821 ms on CPU in the runs below):"
]
},
{
"cell_type": "code",
"metadata": {
"id": "-rYZBcm9uM3-",
"colab_type": "code",
"colab": {}
},
"source": [
"A_stack = random.normal(random.PRNGKey(0), (1000, 100, 100))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "EUZPgiciHABY",
"colab_type": "code",
"colab": {}
},
"source": [
"stacked_explicit_gmres = jax.jit(jax.vmap(explicit_gmres, in_axes=(0, None, None)), static_argnums=(2,))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JA1teSvJwbIA",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "0d0600d6-3db0-4dc2-dea0-bd3aeb726ad6"
},
"source": [
"# CPU\n",
"%timeit stacked_explicit_gmres(A_stack, b, 30).block_until_ready()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"1 loop, best of 3: 821 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "DIQxwSQj1R9I",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "cd9ff89f-9f11-4eb9-82b6-bea17d6af125"
},
"source": [
"# GPU\n",
"%timeit stacked_explicit_gmres(A_stack, b, 30).block_until_ready()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"10 loops, best of 3: 31.2 ms per loop\n"
],
"name": "stdout"
}
]
}
]
}
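
As a cross-check on the `arnoldi_iteration` cell in the notebook, here is a minimal sketch of the same recurrence written as a plain NumPy loop. It assumes a dense real matrix; the name `arnoldi_reference` and the 50x50 test problem are illustrative and not part of the notebook. It checks the Arnoldi relation A Q_n = Q_{n+1} H and the orthonormality of the columns of Q, the two properties the GMRES least-squares step relies on.

```python
import numpy as np


def arnoldi_reference(matvec, b, n):
    """Loop-based Arnoldi with modified Gram-Schmidt (hypothetical helper)."""
    m = b.shape[0]
    Q = np.zeros((m, n + 1))
    H = np.zeros((n + 1, n))
    Q[:, 0] = b / np.linalg.norm(b)
    for k in range(n):
        v = matvec(Q[:, k])
        # Orthogonalize against the basis vectors built so far.
        for j in range(k + 1):
            H[j, k] = Q[:, j] @ v
            v = v - H[j, k] * Q[:, j]
        H[k + 1, k] = np.linalg.norm(v)
        Q[:, k + 1] = v / H[k + 1, k]
    return Q, H


rng = np.random.default_rng(0)
A = rng.standard_normal((50, 50))
b = rng.standard_normal(50)
n = 10

Q, H = arnoldi_reference(lambda x: A @ x, b, n)

# Arnoldi relation: A @ Q[:, :n] should equal Q @ H up to round-off.
relation_error = np.max(np.abs(A @ Q[:, :n] - Q @ H))
# The n + 1 columns of Q should be orthonormal.
orthogonality_error = np.max(np.abs(Q.T @ Q - np.eye(n + 1)))
print(relation_error, orthogonality_error)
```

The same two checks could be applied to the `Q, H` pair returned by the notebook's `arnoldi_iteration`, with a looser tolerance since JAX defaults to float32.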