Created
July 7, 2020 21:11
-
-
Save shoyer/cbac2cf8c8675b2f3a45e4837e3bed80 to your computer and use it in GitHub Desktop.
Simple JAX GMRES
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Simple JAX GMRES", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/shoyer/cbac2cf8c8675b2f3a45e4837e3bed80/simple-jax-gmres.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "q4LEr98cuYQf", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Simple JAX GMRES\n", | |
"\n", | |
"Author: shoyer@google.com\n", | |
"\n", | |
"Date: July 7, 2020\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "SR3HPqI2q8Th", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Copyright 2020 Google LLC.\n", | |
"# SPDX-License-Identifier: Apache-2.0\n", | |
"import numpy as np\n", | |
"import functools\n", | |
"from jax import random\n", | |
"from jax import lax\n", | |
"import jax.numpy as jnp\n", | |
"import jax.ops\n", | |
"import jax.scipy as jsp\n", | |
"from jax.tree_util import Partial\n", | |
"import scipy.sparse.linalg\n", | |
"\n", | |
"def _identity(x):\n", | |
" return x\n", | |
"\n", | |
"def _inner(v, q):\n", | |
" h_jk = q.conj() @ v\n", | |
" v = v - h_jk * q\n", | |
" return (v, h_jk)\n", | |
"\n", | |
"def _outer(A, M, Q, k):\n", | |
" q = Q[:, k]\n", | |
" v = A(M(q))\n", | |
" # TODO: maybe better to use a masked dot-product rather than scan?\n", | |
" v, h_col = lax.scan(_inner, v, Q.T)\n", | |
" v_norm = jnp.linalg.norm(v)\n", | |
" Q = Q.at[:, k+1].set(v / v_norm)\n", | |
" h_col = h_col.at[k+1].set(v_norm)\n", | |
" return Q, h_col\n", | |
"\n", | |
"def arnoldi_iteration(A, b, n, M=None):\n", | |
" # https://en.wikipedia.org/wiki/Arnoldi_iteration#The_Arnoldi_iteration\n", | |
" if M is None:\n", | |
" M = _identity\n", | |
" m = b.shape[0]\n", | |
" q = b / jnp.linalg.norm(b)\n", | |
" Q = jnp.concatenate([q[:, jnp.newaxis], jnp.zeros((m, n))], axis=1)\n", | |
" Q, h = lax.scan(functools.partial(_outer, A, M), Q, np.arange(n))\n", | |
" return Q, h.T\n", | |
"\n", | |
"@jax.jit\n", | |
"def lstsq(a, b):\n", | |
" return jsp.linalg.solve(a.T @ a, a.T @ b, sym_pos=True)\n", | |
"\n", | |
"def _gmres(A, b, x0, n, M):\n", | |
" # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf\n", | |
" Q, H = arnoldi_iteration(A, b, n, M)\n", | |
" beta = jnp.linalg.norm(b - A(x0))\n", | |
" e1 = jnp.concatenate([jnp.ones((1,)), jnp.zeros((n,))])\n", | |
" y = lstsq(H, beta * e1)\n", | |
" x = x0 + M(Q[:, :-1] @ y)\n", | |
" return x\n", | |
"\n", | |
"def gmres(A, b, x0=None, n=5, M=None):\n", | |
" if x0 is None:\n", | |
" x0 = jnp.zeros_like(b)\n", | |
" if M is None:\n", | |
" M = _identity\n", | |
" return _gmres(A, b, x0, n, M)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "j5REsLhwxfsc", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Tests" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "SVgN9XUExrtj", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Verify correctness against SciPy's GMRES, run for a single cycle of 20 iterations:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pcIJkLIKX7cK", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"A = random.normal(random.PRNGKey(0), (100, 100))\n", | |
"b = random.normal(random.PRNGKey(1), (100,))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3CEV8LejZXXC", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"np.testing.assert_allclose(\n", | |
" gmres(functools.partial(jnp.dot, A), b, n=20),\n", | |
" scipy.sparse.linalg.gmres(np.array(A), np.array(b), restart=20, maxiter=1)[0],\n", | |
" atol=1e-6,\n", | |
")" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "IL39G9FZueDG", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Verify that we can calculate gradients through a fixed number of GMRES iterations.\n", | |
"\n", | |
"(Note that if you're running GMRES to convergence, there's a better way to calculate gradients via the [adjoint rule](https://dolfin-adjoint-doc.readthedocs.io/en/latest/documentation/maths/3-gradients.html#the-adjoint-approach).)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "5YbP5hrVtjYZ", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"@jax.grad\n", | |
"def loss(A, b):\n", | |
" return jnp.sum(gmres(functools.partial(jnp.dot, A), b))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "BiDj1N3BuJo4", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 238 | |
}, | |
"outputId": "2d1313ac-33ba-476b-cffa-7ab1ead60a5b" | |
}, | |
"source": [ | |
"# Gradient of sum(gmres(A, b)) with respect to A; it has the same shape as A.\n", | |
"loss(A, b)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([[-0.00888863, -0.01108986, -0.01395201, ..., -0.01434983,\n", | |
" -0.00233695, 0.0087676 ],\n", | |
" [ 0.0068522 , 0.00968967, 0.00116034, ..., -0.0108919 ,\n", | |
" -0.00220353, 0.01377204],\n", | |
" [-0.00557137, -0.00477795, -0.01392099, ..., -0.01569235,\n", | |
" -0.00254974, 0.01301789],\n", | |
" ...,\n", | |
" [-0.00446858, -0.00590282, -0.00807489, ..., -0.01217442,\n", | |
" -0.00532267, 0.01113929],\n", | |
" [ 0.00431957, 0.00333034, 0.00053749, ..., 0.00552948,\n", | |
" 0.00076819, -0.0026694 ],\n", | |
" [ 0.00614916, 0.00756274, 0.00051342, ..., -0.00826862,\n", | |
" -0.00276195, 0.01154379]], dtype=float32)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 13 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "2fzwBVr9xhJ_", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Performance" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "9XGa2EY3u73b", | |
"colab_type": "text" | |
}, | |
"source": [ | |
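"Despite our naive implementation, out-of-the-box performance on CPU beats SciPy by about 2x (649 µs vs. 1.46 ms per solve). On GPU, a single small solve is actually slower (about 30 ms):" | |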
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4u1SmaDKFbHz", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"@functools.partial(jax.jit, static_argnums=(2,))\n", | |
"def explicit_gmres(A, b, n):\n", | |
" return gmres(functools.partial(jnp.dot, A), b, n=n)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "XEuNmYR4r5-H", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 51 | |
}, | |
"outputId": "2461be19-56dc-4972-ff99-7a9235b42163" | |
}, | |
"source": [ | |
"# CPU. The next cell runs identical code; which device is used depends on the\n", | |
"# runtime the notebook is attached to. The first timed run presumably includes\n", | |
"# JIT compilation, hence the 'slowest run' warning in the output.\n", | |
"%timeit explicit_gmres(A, b, 30).block_until_ready()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"The slowest run took 1690.02 times longer than the fastest. This could mean that an intermediate result is being cached.\n", | |
"1 loop, best of 3: 649 µs per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "VnrL74Wv1VXG", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "6f68428e-55f0-460d-d0de-9cded4430ae3" | |
}, | |
"source": [ | |
"# GPU. Same code as the previous cell; run it on a GPU runtime to reproduce.\n", | |
"%timeit explicit_gmres(A, b, 30).block_until_ready()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"10 loops, best of 3: 29.9 ms per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "DeVyIRQKr9dX", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "d6c86676-3677-419b-fc6b-e6875cda4970" | |
}, | |
"source": [ | |
"# SciPy baseline on NumPy copies of the same problem: one GMRES cycle of 30\n", | |
"# iterations (restart=30, maxiter=1), comparable to explicit_gmres(A, b, 30).\n", | |
"b2 = np.asarray(b)\n", | |
"A2 = np.asarray(A)\n", | |
"%timeit scipy.sparse.linalg.gmres(A2, b2, restart=30, maxiter=1)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"1000 loops, best of 3: 1.46 ms per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "GbyGUXO9xeps", | |
"colab_type": "text" | |
}, | |
"source": [ | |
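"We can also `vmap` it! This gives us a big speed-up on GPUs: solving a batch of 1000 systems takes about as long as solving a single one (about 31 ms vs. 30 ms):" | |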
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-rYZBcm9uM3-", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"A_stack = random.normal(random.PRNGKey(0), (1000, 100, 100))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "EUZPgiciHABY", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"stacked_explicit_gmres = jax.jit(jax.vmap(explicit_gmres, in_axes=(0, None, None)), static_argnums=(2,))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "JA1teSvJwbIA", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "0d0600d6-3db0-4dc2-dea0-bd3aeb726ad6" | |
}, | |
"source": [ | |
"# CPU. Same code as the next cell; the device depends on the attached runtime.\n", | |
"%timeit stacked_explicit_gmres(A_stack, b, 30).block_until_ready()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"1 loop, best of 3: 821 ms per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "DIQxwSQj1R9I", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "cd9ff89f-9f11-4eb9-82b6-bea17d6af125" | |
}, | |
"source": [ | |
"# GPU. Same code as the previous cell; run it on a GPU runtime to reproduce.\n", | |
"%timeit stacked_explicit_gmres(A_stack, b, 30).block_until_ready()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"10 loops, best of 3: 31.2 ms per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9Hvg7mMzIaTY", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment