Skip to content

Instantly share code, notes, and snippets.

@shoyer
Last active September 16, 2020 19:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shoyer/2f86513877fcbf044cea964060dd8de3 to your computer and use it in GitHub Desktop.
Save shoyer/2f86513877fcbf044cea964060dd8de3 to your computer and use it in GitHub Desktop.
Copy of benchmark_gmres.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Copy of benchmark_gmres.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shoyer/2f86513877fcbf044cea964060dd8de3/copy-of-benchmark_gmres.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Zn3e7byjR2dT",
"colab_type": "code",
"colab": {}
},
"source": [
"import jax\n",
"import jax.config\n",
"import scipy as sp\n",
"import scipy.sparse.linalg\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"import matplotlib.pyplot as plt\n",
"import time\n",
"import functools\n",
"plt.rcParams['figure.figsize'] = [10,10]\n",
"plt.rcParams['font.size'] = 20\n"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "CzKy-whoFunP",
"colab_type": "code",
"colab": {}
},
"source": [
"# Copyright 2020 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\n",
"from functools import partial\n",
"import operator\n",
"\n",
"import numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import scipy as jsp\n",
"from jax import lax, device_put\n",
"from jax.tree_util import (tree_leaves, tree_map, tree_multimap, tree_structure,\n",
" tree_reduce, Partial)\n",
"from jax.util import safe_map as map\n",
"\n",
"\n",
"_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)\n",
"_vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)\n",
"\n",
"# aliases for working with pytrees\n",
"\n",
"def _vdot_real_part(x, y):\n",
" \"\"\"Vector dot-product guaranteed to have a real valued result despite\n",
" possibly complex input. Thus neglects the real-imaginary cross-terms.\n",
" The result is a real float.\n",
" \"\"\"\n",
" # all our uses of vdot() in CG are for computing an operator of the form\n",
" # z^T M z\n",
" # where M is positive definite and Hermitian, so the result is\n",
" # real valued:\n",
" # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices\n",
" vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)\n",
" result = vdot(x.real, y.real)\n",
" if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):\n",
" result += vdot(x.imag, y.imag)\n",
" return result.real\n",
"\n",
"def _vdot_real_tree(x, y):\n",
" return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))\n",
"\n",
"def _norm_tree(x):\n",
" return jnp.sqrt(_vdot_real_tree(x, x))\n",
"\n",
"def _vdot_tree(x, y):\n",
" return sum(tree_leaves(tree_multimap(_vdot, x, y)))\n",
"\n",
"\n",
"def _mul(scalar, tree):\n",
" return tree_map(partial(operator.mul, scalar), tree)\n",
"\n",
"\n",
"def _div(tree, scalar):\n",
" return tree_map(partial(lambda v: v / scalar), tree)\n",
"\n",
"\n",
"_add = partial(tree_multimap, operator.add)\n",
"_sub = partial(tree_multimap, operator.sub)\n",
"_dot_tree = partial(tree_multimap, _dot)\n",
"\n",
"\n",
"@Partial\n",
"def _identity(x):\n",
" return x\n",
"\n",
"\n",
"def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):\n",
"\n",
" # tolerance handling uses the \"non-legacy\" behavior of scipy.sparse.linalg.cg\n",
" bs = _vdot_real_tree(b, b)\n",
" atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))\n",
"\n",
" # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method\n",
"\n",
" def cond_fun(value):\n",
" x, r, gamma, p, k = value\n",
" rs = gamma if M is _identity else _vdot_real_tree(r, r)\n",
" return (rs > atol2) & (k < maxiter)\n",
"\n",
" def body_fun(value):\n",
" x, r, gamma, p, k = value\n",
" Ap = A(p)\n",
" alpha = gamma / _vdot_real_tree(p, Ap)\n",
" x_ = _add(x, _mul(alpha, p))\n",
" r_ = _sub(r, _mul(alpha, Ap))\n",
" z_ = M(r_)\n",
" gamma_ = _vdot_real_tree(r_, z_)\n",
" beta_ = gamma_ / gamma\n",
" p_ = _add(z_, _mul(beta_, p))\n",
" return x_, r_, gamma_, p_, k + 1\n",
"\n",
" r0 = _sub(b, A(x0))\n",
" p0 = z0 = M(r0)\n",
" gamma0 = _vdot_real_tree(r0, z0)\n",
" initial_value = (x0, r0, gamma0, p0, 0)\n",
"\n",
" x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)\n",
"\n",
" return x_final\n",
"\n",
"\n",
"def _shapes(pytree):\n",
" return map(jnp.shape, tree_leaves(pytree))\n",
"\n",
"\n",
"def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):\n",
" \"\"\"Use Conjugate Gradient iteration to solve ``Ax = b``.\n",
"\n",
" The numerics of JAX's ``cg`` should exactly match SciPy's ``cg`` (up to\n",
" numerical precision), but note that the interface is slightly different: you\n",
" need to supply the linear operator ``A`` as a function instead of a sparse\n",
" matrix or ``LinearOperator``.\n",
"\n",
" Derivatives of ``cg`` are implemented via implicit differentiation with\n",
" another ``cg`` solve, rather than by differentiating *through* the solver.\n",
" They will be accurate only if both solves converge.\n",
"\n",
" Parameters\n",
" ----------\n",
" A : function\n",
" Function that calculates the matrix-vector product ``Ax`` when called\n",
" like ``A(x)``. ``A`` must represent a hermitian, positive definite\n",
" matrix, and must return array(s) with the same structure and shape as its\n",
" argument.\n",
" b : array or tree of arrays\n",
" Right hand side of the linear system representing a single vector. Can be\n",
" stored as an array or Python container of array(s) with any shape.\n",
"\n",
" Returns\n",
" -------\n",
" x : array or tree of arrays\n",
" The converged solution. Has the same structure as ``b``.\n",
" info : None\n",
" Placeholder for convergence information. In the future, JAX will report\n",
" the number of iterations when convergence is not achieved, like SciPy.\n",
"\n",
" Other Parameters\n",
" ----------------\n",
" x0 : array\n",
" Starting guess for the solution. Must have the same structure as ``b``.\n",
" tol, atol : float, optional\n",
" Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.\n",
" We do not implement SciPy's \"legacy\" behavior, so JAX's tolerance will\n",
" differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.\n",
" maxiter : integer\n",
" Maximum number of iterations. Iteration will stop after maxiter\n",
" steps even if the specified tolerance has not been achieved.\n",
" M : function\n",
" Preconditioner for A. The preconditioner should approximate the\n",
" inverse of A. Effective preconditioning dramatically improves the\n",
" rate of convergence, which implies that fewer iterations are needed\n",
" to reach a given error tolerance.\n",
"\n",
" See also\n",
" --------\n",
" scipy.sparse.linalg.cg\n",
" jax.lax.custom_linear_solve\n",
" \"\"\"\n",
" if x0 is None:\n",
" x0 = tree_map(jnp.zeros_like, b)\n",
"\n",
" b, x0 = device_put((b, x0))\n",
"\n",
" if maxiter is None:\n",
" size = sum(bi.size for bi in tree_leaves(b))\n",
" maxiter = 10 * size # copied from scipy\n",
"\n",
" if M is None:\n",
" M = _identity\n",
"\n",
" if tree_structure(x0) != tree_structure(b):\n",
" raise ValueError(\n",
" 'x0 and b must have matching tree structure: '\n",
" f'{tree_structure(x0)} vs {tree_structure(b)}')\n",
"\n",
" if _shapes(x0) != _shapes(b):\n",
" raise ValueError(\n",
" 'arrays in x0 and b must have matching shapes: '\n",
" f'{_shapes(x0)} vs {_shapes(b)}')\n",
"\n",
" cg_solve = partial(\n",
" _cg_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)\n",
"\n",
" # real-valued positive-definite linear operators are symmetric\n",
" def real_valued(x):\n",
" return not issubclass(x.dtype.type, np.complexfloating)\n",
" symmetric = all(map(real_valued, tree_leaves(b)))\n",
" x = lax.custom_linear_solve(\n",
" A, b, solve=cg_solve, transpose_solve=cg_solve, symmetric=symmetric)\n",
" info = None # TODO(shoyer): return the real iteration count here\n",
" return x, info\n",
"\n",
"\n",
"\n",
"def _safe_normalize(x, return_norm=False, thresh=None):\n",
" \"\"\"\n",
" Returns the L2-normalized vector (which can be a pytree) x, and optionally\n",
" the computed norm. If the computed norm is less than the threshold `thresh`,\n",
" which by default is the machine precision of x's dtype, it will be\n",
" taken to be 0, and the normalized x to be the zero vector.\n",
" \"\"\"\n",
" norm = _norm_tree(x)\n",
" dtype = jnp.result_type(*tree_leaves(x))\n",
" if thresh is None:\n",
" thresh = jnp.finfo(norm.dtype).eps\n",
" thresh = thresh.astype(dtype).real\n",
"\n",
" normalized_x, norm = lax.cond(\n",
" norm > thresh,\n",
" lambda y: (_div(y, norm), norm),\n",
" lambda y: (tree_map(jnp.zeros_like, y), jnp.zeros((), dtype=thresh.dtype)),\n",
" x,\n",
" )\n",
" if return_norm:\n",
" return normalized_x, norm\n",
" else:\n",
" return normalized_x\n",
"\n",
"\n",
"def _project_on_columns(A, v):\n",
" \"\"\"\n",
" Returns A.T.conj() @ v.\n",
" \"\"\"\n",
" v_proj = tree_multimap(\n",
" lambda X, y: jnp.einsum(\"...n,...->n\", X.conj(), y),\n",
" A,\n",
" v,\n",
" )\n",
" return tree_reduce(operator.add, v_proj)\n",
"\n",
"\n",
"def _iterative_classical_gram_schmidt(Q, x, iterations=2):\n",
" \"\"\"Orthogonalize x against the columns of Q.\"\"\"\n",
" # \"twice is enough\"\n",
" # http://slepc.upv.es/documentation/reports/str1.pdf\n",
"\n",
" # This assumes that Q's leaves all have the same dimension in the last\n",
" # axis.\n",
" r = jnp.zeros((tree_leaves(Q)[0].shape[-1]))\n",
" q = x\n",
"\n",
" for _ in range(iterations):\n",
" h = _project_on_columns(Q, q)\n",
" Qh = tree_map(lambda X: _dot_tree(X, h), Q)\n",
" q = _sub(q, Qh)\n",
" r = _add(r, h)\n",
" return q, r\n",
"\n",
"\n",
"def kth_arnoldi_iteration(k, A, M, V, H, tol):\n",
" \"\"\"\n",
" Performs a single (the k'th) step of the Arnoldi process. Thus, it\n",
" adds a new orthonormalized Krylov vector A(M(V[:, k])) to V[:, k+1],\n",
" and that vector's overlaps with the existing Krylov vectors to\n",
" H[k, :]. The tolerance 'tol' sets the threshold at which an invariant\n",
" subspace is declared to have been found, in which case the new\n",
" vector is taken to be the zero vector.\n",
" \"\"\"\n",
"\n",
" v = tree_map(lambda x: x[..., k], V) # Gets V[:, k]\n",
" v = A(M(v))\n",
" v, h = _iterative_classical_gram_schmidt(V, v, iterations=1)\n",
" unit_v, v_norm = _safe_normalize(v, return_norm=True, thresh=tol)\n",
" V = tree_multimap(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)\n",
" h = h.at[k + 1].set(v_norm)\n",
"\n",
" def set_column_to_identity(args):\n",
" H, k, _ = args\n",
" col = jnp.zeros(H.shape[1], dtype=H.dtype)\n",
" col = col.at[k].set(1.)\n",
" H = H.at[k, :].set(col)\n",
" return H\n",
"\n",
" def set_column_to_vector(args):\n",
" H, k, h = args\n",
" H = H.at[k, :].set(h)\n",
" return H\n",
"\n",
" H = lax.cond(v_norm == 0.,\n",
" set_column_to_identity,\n",
" set_column_to_vector,\n",
" (H, k, h))\n",
" #H = H.at[k, :].set(h)\n",
" return V, H\n",
"\n",
"\n",
"def apply_givens_rotations(H_row, givens, k):\n",
" \"\"\"\n",
" Applies the first k Givens rotations stored in the rows of `givens` (each\n",
" row holding a (cs, sn) pair) to the vector H_row. Then constructs and\n",
" applies a new Givens rotation that zeroes H_row[k + 1].\n",
" \"\"\"\n",
" # This call successively applies each of the\n",
" # Givens rotations stored in givens[:k, :] to H_row.\n",
"\n",
" def apply_ith_rotation(i, H_row):\n",
" cs, sn = givens[i, :]\n",
" H_i = cs * H_row[i] - sn * H_row[i + 1]\n",
" H_ip1 = sn * H_row[i] + cs * H_row[i + 1]\n",
" H_row = H_row.at[i].set(H_i)\n",
" H_row = H_row.at[i + 1].set(H_ip1)\n",
" return H_row\n",
"\n",
" R_row = lax.fori_loop(0, k, apply_ith_rotation, H_row)\n",
"\n",
" def givens_rotation(v1, v2):\n",
" t = jnp.sqrt(v1**2 + v2**2)\n",
" cs = v1 / t\n",
" sn = -v2 / t\n",
" return cs, sn\n",
" givens_factors = givens_rotation(R_row[k], R_row[k + 1])\n",
" givens = givens.at[k, :].set(givens_factors)\n",
" cs_k, sn_k = givens_factors\n",
"\n",
" R_row = R_row.at[k].set(cs_k * R_row[k] - sn_k * R_row[k + 1])\n",
" R_row = R_row.at[k + 1].set(0.)\n",
" return R_row, givens\n",
"\n",
"\n",
"def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):\n",
" \"\"\"\n",
" Implements a single restart of GMRES. The restart-dimensional Krylov subspace\n",
" K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is built, and the\n",
" projection of the true solution into this subspace is returned.\n",
" \"\"\"\n",
" # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf\n",
" # residual = _sub(b, A(x0))\n",
" # unit_residual, beta = _safe_normalize(residual, return_norm=True)\n",
"\n",
" V = tree_map(\n",
" lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),\n",
" unit_residual,\n",
" )\n",
" dtype = jnp.result_type(*tree_leaves(b))\n",
" R = jnp.eye(restart, restart + 1, dtype=dtype) # eye to avoid constructing\n",
" # a singular matrix in case\n",
" # of early termination.\n",
" b_norm = _norm_tree(b)\n",
"\n",
" givens = jnp.zeros((restart, 2), dtype=dtype)\n",
" beta_vec = jnp.zeros((restart + 1), dtype=dtype)\n",
" beta_vec = beta_vec.at[0].set(residual_norm)\n",
"\n",
" def loop_cond(carry):\n",
" k, err, _, _, _, _ = carry\n",
" return lax.cond(k < restart,\n",
" lambda x: x[0] > x[1],\n",
" lambda x: False,\n",
" (err, inner_tol))\n",
" # return k < restart and err > tol\n",
"\n",
" def arnoldi_qr_step(carry):\n",
" k, residual_norm, V, R, beta_vec, givens = carry\n",
" V, H = kth_arnoldi_iteration(k, A, M, V, R, inner_tol)\n",
" R_row, givens = apply_givens_rotations(H[k, :], givens, k)\n",
" R = R.at[k, :].set(R_row[:])\n",
" cs, sn = givens[k, :] * beta_vec[k]\n",
" beta_vec = beta_vec.at[k].set(cs)\n",
" beta_vec = beta_vec.at[k + 1].set(sn)\n",
" err = jnp.abs(sn) / b_norm\n",
" return k + 1, err, V, R, beta_vec, givens\n",
"\n",
" carry = (0, residual_norm, V, R, beta_vec, givens)\n",
" carry = lax.while_loop(loop_cond, arnoldi_qr_step, carry)\n",
" k, residual_norm, V, R, beta_vec, _ = carry\n",
"\n",
" y = jsp.linalg.solve_triangular(R[:, :-1].T, beta_vec[:-1])\n",
" Vy = tree_map(lambda X: _dot(X[..., :-1], y), V)\n",
" dx = M(Vy)\n",
"\n",
" x = _add(x0, dx)\n",
" residual = _sub(b, A(x))\n",
" unit_residual, residual_norm = _safe_normalize(residual, return_norm=True)\n",
" return x, unit_residual, residual_norm\n",
"\n",
"\n",
"def _gmres_fixed(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):\n",
" \"\"\"\n",
" Implements a single restart of GMRES. The restart-dimensional Krylov subspace\n",
" K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is built, and the\n",
" projection of the true solution into this subspace is returned.\n",
" \"\"\"\n",
" # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf\n",
" V = tree_map(\n",
" lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),\n",
" unit_residual,\n",
" )\n",
" dtype = jnp.result_type(*tree_leaves(b))\n",
" H = jnp.eye(restart, restart + 1, dtype=dtype)\n",
" def arnoldi_for_scan(carry, k):\n",
" V, H = carry\n",
" V, H = kth_arnoldi_iteration(k, A, M, V, H, inner_tol)\n",
" return (V, H), None\n",
" (V, H), _ = lax.scan(arnoldi_for_scan, (V, H), jnp.arange(restart))\n",
"\n",
" beta_vec = jnp.zeros((restart,), dtype=dtype)\n",
" beta_vec = beta_vec.at[0].set(residual_norm) # it really is the original value\n",
" y = jsp.linalg.solve(H[:, :-1].T, beta_vec)\n",
" Vy = tree_map(lambda X: _dot(X[..., :-1], y), V)\n",
" dx = M(Vy)\n",
" x = _add(x0, dx)\n",
"\n",
" residual = _sub(b, A(x))\n",
" unit_residual, residual_norm = _safe_normalize(residual, return_norm=True)\n",
" return x, unit_residual, residual_norm\n",
"\n",
"\n",
"def _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,\n",
" gmres_func):\n",
" \"\"\"\n",
" The main function call wrapped by custom_linear_solve. Repeatedly calls GMRES\n",
" to find the projected solution within the order-``restart``\n",
" Krylov space K(A, x0, restart), using the result of the previous projection\n",
" in place of x0 each time.\n",
" \"\"\"\n",
" residual = _sub(b, A(x0))\n",
" unit_residual, residual_norm = _safe_normalize(residual, return_norm=True)\n",
"\n",
" def cond_fun(value):\n",
" _, k, _, residual_norm = value\n",
" return lax.cond(k < maxiter,\n",
" lambda x: x[0] > x[1],\n",
" lambda x: False,\n",
" (residual_norm, outer_tol))\n",
"\n",
" def body_fun(value):\n",
" x, k, unit_residual, residual_norm = value\n",
" x, unit_residual, residual_norm = gmres_func(A, b, x, unit_residual,\n",
" residual_norm, inner_tol,\n",
" restart, M)\n",
" return x, k + 1, unit_residual, residual_norm\n",
"\n",
" initialization = (x0, 0, unit_residual, residual_norm)\n",
" x_final, k, _, err = lax.while_loop(cond_fun, body_fun, initialization)\n",
" # info = lax.cond(converged, lambda y: 0, lambda y: k, 0)\n",
" return x_final # , info\n",
"\n",
"\n",
"def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,\n",
" M=None, fixed_iterations=False):\n",
" \"\"\"\n",
" GMRES solves the linear system A x = b for x, given A and b. A is specified\n",
" as a function performing A(vi) -> vf = A @ vi, and in principle need not have\n",
" any particular special properties, such as symmetry. However, convergence\n",
" is often slow for nearly symmetric operators.\n",
"\n",
" Parameters\n",
" ----------\n",
" A: function\n",
" Function that calculates the linear map (e.g. matrix-vector product)\n",
" ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with the same\n",
" structure and shape as its argument.\n",
" b : array or tree of arrays\n",
" Right hand side of the linear system representing a single vector. Can be\n",
" stored as an array or Python container of array(s) with any shape.\n",
"\n",
" Returns\n",
" -------\n",
" x : array or tree of arrays\n",
" The converged solution. Has the same structure as ``b``.\n",
" info : None\n",
" Placeholder for convergence information. In the future, JAX will report\n",
" the number of iterations when convergence is not achieved, like SciPy.\n",
"\n",
" Other Parameters\n",
" ----------------\n",
" x0 : array, optional\n",
" Starting guess for the solution. Must have the same structure as ``b``.\n",
" If this is unspecified, a (logical) vector of zeroes is used.\n",
" tol, atol : float, optional\n",
" Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.\n",
" We do not implement SciPy's \"legacy\" behavior, so JAX's tolerance will\n",
" differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``gmres``.\n",
" restart : integer, optional\n",
" Size of the Krylov subspace (\"number of iterations\") built between\n",
" restarts. GMRES works by approximating the true solution x as its\n",
" projection into a Krylov space of this dimension - this parameter\n",
" therefore bounds the maximum accuracy achievable from any guess\n",
" solution. Larger values increase both number of iterations and iteration\n",
" cost, but may be necessary for convergence. If fixed_iterations is\n",
" False, the algorithm terminates\n",
" early if convergence is achieved before the full subspace is built.\n",
" Default is 20.\n",
" maxiter : integer\n",
" Maximum number of iterations. If convergence has not been achieved\n",
" after projecting into the size-``restart`` Krylov space, GMRES will\n",
" try again, using the previous result as the new guess, up to this\n",
" many times. If the optimal solution within a Krylov space of the\n",
" given dimension is not converged up to the requested tolerance, these\n",
" restarts will not improve the accuracy, so care should be taken when\n",
" increasing this parameter.\n",
" M : function\n",
" Preconditioner for A. The preconditioner should approximate the\n",
" inverse of A. Effective preconditioning dramatically improves the\n",
" rate of convergence, which implies that fewer iterations are needed\n",
" to reach a given error tolerance.\n",
" fixed_iterations : bool\n",
" If False (the default), the algorithm builds an internal Krylov subspace\n",
" using a QR based algorithm, permitting early termination of the inner\n",
" `restart` loop if convergence is reached. Apart from permitting early\n",
" termination, this reduces overhead and may improve stability. However, it\n",
" may degrade performance significantly on GPUs or TPUs, in which case this\n",
" flag should be set True so that a fixed number of iterations is run.\n",
"\n",
" See also\n",
" --------\n",
" scipy.sparse.linalg.gmres\n",
" jax.lax.custom_linear_solve\n",
" \"\"\"\n",
"\n",
" if x0 is None:\n",
" x0 = tree_map(jnp.zeros_like, b)\n",
" if M is None:\n",
" M = _identity\n",
"\n",
" try:\n",
" size = sum(bi.size for bi in tree_leaves(b))\n",
" except AttributeError:\n",
" size = len(tree_leaves(b))\n",
"\n",
" if maxiter is None:\n",
" maxiter = 10 * size # copied from scipy\n",
" restart = min(restart, size)\n",
"\n",
" if tree_structure(x0) != tree_structure(b):\n",
" raise ValueError(\n",
" 'x0 and b must have matching tree structure: '\n",
" f'{tree_structure(x0)} vs {tree_structure(b)}')\n",
"\n",
" b, x0 = device_put((b, x0))\n",
" b_norm = _norm_tree(b)\n",
" # if b_norm == 0:\n",
" # return b, 0\n",
" outer_tol = jnp.minimum(tol * b_norm, atol)\n",
"\n",
" Mb = M(b)\n",
" Mb_norm = _norm_tree(Mb)\n",
" inner_tol = Mb_norm * jnp.minimum(1.0, outer_tol / b_norm)\n",
"\n",
" if fixed_iterations:\n",
" def _solve(A, b):\n",
" return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,\n",
" _gmres_fixed)\n",
" else:\n",
" def _solve(A, b):\n",
" return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,\n",
" _gmres_qr)\n",
"\n",
" x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)\n",
"\n",
" failed = jnp.isnan(_norm_tree(x))\n",
" info = lax.cond(failed, lambda x: -1, lambda x: 0, 0)\n",
" return x, info\n"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "udnlqvIuP5Sa",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "19d9b7e5-6d95-468f-d67b-a43dd0dd41f4"
},
"source": [
"restart = 2 * 32\n",
"maxiter = 1\n",
"N = 1024\n",
"\n",
"@jax.jit\n",
"def gmres_solve_fixed(A, b):\n",
" @jax.tree_util.Partial\n",
" def A_mv(x):\n",
" return A @ b\n",
" return gmres(A_mv, b, tol=0.0, atol=0.0, restart=restart, maxiter=maxiter,\n",
" fixed_iterations=True)\n",
"\n",
"@jax.jit\n",
"def gmres_solve_nonfixed(A, b):\n",
" @jax.tree_util.Partial\n",
" def A_mv(x):\n",
" return A @ b\n",
" return gmres(A_mv, b, tol=0.0, atol=0.0, restart=restart, maxiter=maxiter,\n",
" fixed_iterations=False)\n",
" \n",
"dtype=np.float32\n",
"logical_size = N * N * jnp.finfo(dtype).bits * 1E-9 # Size in Gb (gigabits) \n",
"print(logical_size)\n",
"A = jnp.array(np.random.rand(N, N).astype(dtype))\n",
"b = jnp.array(np.random.rand(N).astype(dtype))\n"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"0.033554432\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "v3SpsM25RI_e",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"outputId": "9d757a96-971c-4515-dc5c-3d390e21cba1"
},
"source": [
"%timeit gmres_solve_nonfixed(A, b)[0].block_until_ready()"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"The slowest run took 8.32 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1 loop, best of 3: 153 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ca7GvkAhRsEO",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"outputId": "2765ccce-71af-44a2-be21-7beff57e5d56"
},
"source": [
"%timeit gmres_solve_fixed(A, b)[0].block_until_ready()"
],
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"text": [
"The slowest run took 88.39 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1 loop, best of 3: 13.2 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "wlceFWcXcKh8",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment