Skip to content

Instantly share code, notes, and snippets.

@romanodev
Created July 12, 2020 19:19
Show Gist options
  • Save romanodev/e3f6bd23c499cd8a5f26b26c140abcac to your computer and use it in GitHub Desktop.
Save romanodev/e3f6bd23c499cd8a5f26b26c140abcac to your computer and use it in GitHub Desktop.
JAX GMRES Loops
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX GMRES Loops",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/romanodev/e3f6bd23c499cd8a5f26b26c140abcac/jax-gmres-loops.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q4LEr98cuYQf",
"colab_type": "text"
},
"source": [
"# Simple JAX GMRES\n",
"\n",
"Readapted from: https://gist.github.com/shoyer/dc33a5850337b6a87d48ed97b4727d29 (Author: shoyer@google.com)\n",
"\n",
"Date: July 12, 2020\n",
"\n",
"Modificationas by: romanog@mit.edu July 12, 2020\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "SR3HPqI2q8Th",
"colab_type": "code",
"colab": {}
},
"source": [
"# Copyright 2020 Google LLC.\n",
"# SPDX-License-Identifier: Apache-2.0\n",
"import numpy as np\n",
"import functools\n",
"from jax import random\n",
"from jax import lax\n",
"import jax.numpy as jnp\n",
"import jax.ops\n",
"import jax.scipy as jsp\n",
"from jax.tree_util import Partial\n",
"import scipy.sparse.linalg\n",
"from jax.experimental import loops\n",
"\n",
"\n",
"def _identity(x):\n",
" return x\n",
"\n",
"_dot = functools.partial(jnp.dot, precision=lax.Precision.HIGHEST)\n",
"\n",
"\n",
"@jax.jit\n",
"def lstsq(a, b):\n",
" \n",
"\n",
" return jax.numpy.linalg.lstsq(a,b)\n",
" \n",
"\n",
"def _gmres2(A, b, x0, n, M):\n",
" \n",
" beta_e1 = jnp.linalg.norm(b - A(x0))*jnp.concatenate([jnp.ones((1,)), jnp.zeros((n,))])\n",
"\n",
" m = b.shape[0]\n",
" q = b / jnp.linalg.norm(b)\n",
" Q = jnp.concatenate([q[:, jnp.newaxis], jnp.zeros((m, n))], axis=1)\n",
"\n",
" x0 = jnp.zeros_like(b)\n",
" H = jnp.zeros((n,n+1))\n",
" \n",
" with loops.Scope() as s:\n",
"\n",
" s.x = jnp.zeros_like(b)\n",
"\n",
" for k in range(n):\n",
"\n",
" q = Q[:, k]\n",
" v = A(M(q))\n",
" h = _dot(Q.T.conj(), v)\n",
"\n",
" v = v - _dot(Q, h)\n",
" v_norm = jnp.linalg.norm(v)\n",
" Q = Q.at[:, k+1].set(v / v_norm)\n",
" h = h.at[k+1].set(v_norm)\n",
" \n",
" H = jax.ops.index_update(H,jax.ops.index[k,:], h)\n",
" \n",
" y,r,_,_ = lstsq(H[:k+1,:].T, beta_e1)\n",
" #y = lstsq(H[:k+1,:].T, beta_e1)\n",
" \n",
" print('Residual: ', r)\n",
" \n",
" s.x = x0 + M(_dot(Q[:,:k+1], y))\n",
" \n",
" return s.x\n",
" \n",
"\n",
"def gmres(A, b, x0=None, n=5, M=None):\n",
" if x0 is None:\n",
" x0 = jnp.zeros_like(b)\n",
" if M is None:\n",
"\n",
" M = _identity\n",
" return _gmres2(A, b, x0, n, M)"
],
"execution_count": 20,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "j5REsLhwxfsc",
"colab_type": "text"
},
"source": [
"## Tests"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SVgN9XUExrtj",
"colab_type": "text"
},
"source": [
"Verify correctness:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "pcIJkLIKX7cK",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"outputId": "15e90b8e-de7f-45ee-d885-8627583977c5"
},
"source": [
"A = random.normal(random.PRNGKey(0), (100, 100))\n",
"b = random.normal(random.PRNGKey(1), (100,))"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:127: UserWarning: No GPU/TPU found, falling back to CPU.\n",
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "dtJZvMbtq9nH",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 550
},
"outputId": "fbce2b54-d026-484e-f48f-0a60100cbb40"
},
"source": [
"#M = jnp.diag(1/jnp.diag(A))\n",
"M = jnp.eye(100) #no preconditioning\n",
"np.testing.assert_allclose(\n",
" gmres(functools.partial(jnp.dot, A), b, n=30,M=functools.partial(jnp.dot,M)),\n",
" scipy.sparse.linalg.gmres(np.array(A), np.array(b), restart=30, maxiter=1,M=np.asarray(M))[0],\n",
" atol=1e-5,rtol=1e-4,\n",
")"
],
"execution_count": 21,
"outputs": [
{
"output_type": "stream",
"text": [
"Residual: [95.708786]\n",
"Residual: [95.251]\n",
"Residual: [95.07619]\n",
"Residual: [94.94508]\n",
"Residual: [91.979004]\n",
"Residual: [91.72117]\n",
"Residual: [88.91462]\n",
"Residual: [88.69267]\n",
"Residual: [88.68685]\n",
"Residual: [87.163666]\n",
"Residual: [86.24794]\n",
"Residual: [81.21187]\n",
"Residual: [80.53536]\n",
"Residual: [80.3178]\n",
"Residual: [77.78765]\n",
"Residual: [77.60629]\n",
"Residual: [77.5413]\n",
"Residual: [76.535385]\n",
"Residual: [76.41232]\n",
"Residual: [75.28232]\n",
"Residual: [75.28172]\n",
"Residual: [74.414116]\n",
"Residual: [74.37475]\n",
"Residual: [74.05194]\n",
"Residual: [73.5671]\n",
"Residual: [73.31386]\n",
"Residual: [72.15972]\n",
"Residual: [70.36666]\n",
"Residual: [70.325455]\n",
"Residual: [69.448524]\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment