Last active September 16, 2020 19:03
Copy of benchmark_gmres.ipynb
"source": [
"import jax\n",
"import jax.config\n",
"import scipy as sp\n",
"import scipy.sparse.linalg\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"import matplotlib.pyplot as plt\n",
"import time\n",
"import functools\n",
"plt.rcParams['figure.figsize'] = [10,10]\n",
"plt.rcParams['font.size'] = 20\n"
"source": [
"# Copyright 2020 Google LLC\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"from functools import partial\n",
"import operator\n",
"import numpy as np\n",
"import jax.numpy as jnp\n",
"from jax import scipy as jsp\n",
"from jax import lax, device_put\n",
"from jax.tree_util import (tree_leaves, tree_map, tree_multimap, tree_structure,\n",
" tree_reduce, Partial)\n",
"from jax.util import safe_map as map\n",
"_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)\n",
"_vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)\n",
"# aliases for working with pytrees\n",
"def _vdot_real_part(x, y):\n",
" \"\"\"Vector dot-product guaranteed to have a real valued result despite\n",
" possibly complex input. Thus neglects the real-imaginary cross-terms.\n",
" The result is a real float.\n",
" \"\"\"\n",
" # all our uses of vdot() in CG are for computing an operator of the form\n",
" # z^T M z\n",
" # where M is positive definite and Hermitian, so the result is\n",
" # real valued:\n",
" #\n",
" vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)\n",
" result = vdot(x.real, y.real)\n",
" if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):\n",
" result += vdot(x.imag, y.imag)\n",
" return result.real\n",
"def _vdot_real_tree(x, y):\n",
" return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))\n",
"def _norm_tree(x):\n",
" return jnp.sqrt(_vdot_real_tree(x, x))\n",
"def _vdot_tree(x, y):\n",
" return sum(tree_leaves(tree_multimap(_vdot, x, y)))\n",
"def _mul(scalar, tree):\n",
" return tree_map(partial(operator.mul, scalar), tree)\n",
"def _div(tree, scalar):\n",
" return tree_map(partial(lambda v: v / scalar), tree)\n",
"_add = partial(tree_multimap, operator.add)\n",
"_sub = partial(tree_multimap, operator.sub)\n",
"_dot_tree = partial(tree_multimap, _dot)\n",
"def _identity(x):\n",
" return x\n",
"def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):\n",
" # tolerance handling uses the \"non-legacy\" behavior of scipy.sparse.linalg.cg\n",
" bs = _vdot_real_tree(b, b)\n",
" atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))\n",
" #\n",
" def cond_fun(value):\n",
" x, r, gamma, p, k = value\n",
" rs = gamma if M is _identity else _vdot_real_tree(r, r)\n",
" return (rs > atol2) & (k < maxiter)\n",
" def body_fun(value):\n",
" x, r, gamma, p, k = value\n",
" Ap = A(p)\n",
" alpha = gamma / _vdot_real_tree(p, Ap)\n",
" x_ = _add(x, _mul(alpha, p))\n",
" r_ = _sub(r, _mul(alpha, Ap))\n",
" z_ = M(r_)\n",
" gamma_ = _vdot_real_tree(r_, z_)\n",
" beta_ = gamma_ / gamma\n",
" p_ = _add(z_, _mul(beta_, p))\n",
" return x_, r_, gamma_, p_, k + 1\n",
" r0 = _sub(b, A(x0))\n",
" p0 = z0 = M(r0)\n",
" gamma0 = _vdot_real_tree(r0, z0)\n",
" initial_value = (x0, r0, gamma0, p0, 0)\n",
" x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)\n",
" return x_final\n",
"def _shapes(pytree):\n",
" return map(jnp.shape, tree_leaves(pytree))\n",
"def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):\n",
" \"\"\"Use Conjugate Gradient iteration to solve ``Ax = b``.\n",
" The numerics of JAX's ``cg`` should exact match SciPy's ``cg`` (up to\n",
" numerical precision), but note that the interface is slightly different: you\n",
" need to supply the linear operator ``A`` as a function instead of a sparse\n",
" matrix or ``LinearOperator``.\n",
" Derivatives of ``cg`` are implemented via implicit differentiation with\n",
" another ``cg`` solve, rather than by differentiating *through* the solver.\n",
" They will be accurate only if both solves converge.\n",
" Parameters\n",
" ----------\n",
" A : function\n",
" Function that calculates the matrix-vector product ``Ax`` when called\n",
" like ``A(x)``. ``A`` must represent a hermitian, positive definite\n",
" matrix, and must return array(s) with the same structure and shape as its\n",
" argument.\n",
" b : array or tree of arrays\n",
" Right hand side of the linear system representing a single vector. Can be\n",
" stored as an array or Python container of array(s) with any shape.\n",
" Returns\n",
" -------\n",
" x : array or tree of arrays\n",
" The converged solution. Has the same structure as ``b``.\n",
" info : None\n",
" Placeholder for convergence information. In the future, JAX will report\n",
" the number of iterations when convergence is not achieved, like SciPy.\n",
" Other Parameters\n",
" ----------------\n",
" x0 : array\n",
" Starting guess for the solution. Must have the same structure as ``b``.\n",
" tol, atol : float, optional\n",
" Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.\n",
" We do not implement SciPy's \"legacy\" behavior, so JAX's tolerance will\n",
" differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.\n",
" maxiter : integer\n",
" Maximum number of iterations. Iteration will stop after maxiter\n",
" steps even if the specified tolerance has not been achieved.\n",
" M : function\n",
" Preconditioner for A. The preconditioner should approximate the\n",
" inverse of A. Effective preconditioning dramatically improves the\n",
" rate of convergence, which implies that fewer iterations are needed\n",
" to reach a given error tolerance.\n",
" See also\n",
" --------\n",
" jax.lax.custom_linear_solve\n",
" \"\"\"\n",
" if x0 is None:\n",
" x0 = tree_map(jnp.zeros_like, b)\n",
" b, x0 = device_put((b, x0))\n",
" if maxiter is None:\n",
" size = sum(bi.size for bi in tree_leaves(b))\n",
" maxiter = 10 * size # copied from scipy\n",
" if M is None:\n",
" M = _identity\n",
" if tree_structure(x0) != tree_structure(b):\n",
" raise ValueError(\n",
" 'x0 and b must have matching tree structure: '\n",
" f'{tree_structure(x0)} vs {tree_structure(b)}')\n",
" if _shapes(x0) != _shapes(b):\n",
" raise ValueError(\n",
" 'arrays in x0 and b must have matching shapes: '\n",
" f'{_shapes(x0)} vs {_shapes(b)}')\n",
" cg_solve = partial(\n",
" _cg_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)\n",
" # real-valued positive-definite linear operators are symmetric\n",
" def real_valued(x):\n",
" return not issubclass(x.dtype.type, np.complexfloating)\n",
" symmetric = all(map(real_valued, tree_leaves(b)))\n",
" x = lax.custom_linear_solve(\n",
" A, b, solve=cg_solve, transpose_solve=cg_solve, symmetric=symmetric)\n",
" info = None # TODO(shoyer): return the real iteration count here\n",
" return x, info\n",
"def _safe_normalize(x, return_norm=False, thresh=None):\n",
" \"\"\"\n",
" Returns the L2-normalized vector (which can be a pytree) x, and optionally\n",
" the computed norm. If the computed norm is less than the threshold `thresh`,\n",
" which by default is the machine precision of x's dtype, it will be\n",
" taken to be 0, and the normalized x to be the zero vector.\n",
" \"\"\"\n",
" norm = _norm_tree(x)\n",
" dtype = jnp.result_type(*tree_leaves(x))\n",
" if thresh is None:\n",
" thresh = jnp.finfo(norm.dtype).eps\n",
" thresh = thresh.astype(dtype).real\n",
" normalized_x, norm = lax.cond(\n",
" norm > thresh,\n",
" lambda y: (_div(y, norm), norm),\n",
" lambda y: (tree_map(jnp.zeros_like, y), jnp.zeros((), dtype=thresh.dtype)),\n",
" x,\n",
" )\n",
" if return_norm:\n",
" return normalized_x, norm\n",
" else:\n",
" return normalized_x\n",
"def _project_on_columns(A, v):\n",
" \"\"\"\n",
" Returns A.T.conj() @ v.\n",
" \"\"\"\n",
" v_proj = tree_multimap(\n",
" lambda X, y: jnp.einsum(\"...n,...->n\", X.conj(), y),\n",
" A,\n",
" v,\n",
" )\n",
" return tree_reduce(operator.add, v_proj)\n",
"def _iterative_classical_gram_schmidt(Q, x, iterations=2):\n",
" \"\"\"Orthogonalize x against the columns of Q.\"\"\"\n",
" # \"twice is enough\"\n",
" #\n",
" # This assumes that Q's leaves all have the same dimension in the last\n",
" # axis.\n",
" r = jnp.zeros((tree_leaves(Q)[0].shape[-1]))\n",
" q = x\n",
" for _ in range(iterations):\n",
" h = _project_on_columns(Q, q)\n",
" Qh = tree_map(lambda X: _dot_tree(X, h), Q)\n",
" q = _sub(q, Qh)\n",
" r = _add(r, h)\n",
" return q, r\n",
"def kth_arnoldi_iteration(k, A, M, V, H, tol):\n",
" \"\"\"\n",
" Performs a single (the k'th) step of the Arnoldi process. Thus,\n",
" adds a new orthonormalized Krylov vector A(M(V[:, k])) to V[:, k+1],\n",
" and stores that vector's overlaps with the existing Krylov vectors in\n",
" H[k, :]. The tolerance 'tol' sets the threshold at which an invariant\n",
" subspace is declared to have been found, in which case the new\n",
" vector is taken to be the zero vector and H[k, :] is set to the k'th\n",
" unit vector instead.\n",
" \"\"\"\n",
" v = tree_map(lambda x: x[..., k], V) # Gets V[:, k]\n",
" v = A(M(v))\n",
" v, h = _iterative_classical_gram_schmidt(V, v, iterations=1)\n",
" unit_v, v_norm = _safe_normalize(v, return_norm=True, thresh=tol)\n",
" V = tree_multimap(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)\n",
" h = h.at[k + 1].set(v_norm)\n",
" def set_column_to_identity(args):\n",
" H, k, _ = args\n",
" col = jnp.zeros(H.shape[1], dtype=H.dtype)\n",
" col = col.at[k].set(1.)\n",
" H = H.at[k, :].set(col)\n",
" return H\n",
" def set_column_to_vector(args):\n",
" H, k, h = args\n",
" H = H.at[k, :].set(h)\n",
" return H\n",
" H = lax.cond(v_norm == 0.,\n",
" set_column_to_identity,\n",
" set_column_to_vector,\n",
" (H, k, h))\n",
" #H = H.at[k, :].set(h)\n",
" return V, H\n",
"def apply_givens_rotations(H_row, givens, k):\n",
" \"\"\"\n",
" Applies the Givens rotations stored as (cs, sn) pairs in the rows\n",
" givens[:k, :] to the vector H_row. Then constructs and applies a new Givens\n",
" rotation that eliminates H_row's (k+1)'th element, storing it in givens[k, :].\n",
" \"\"\"\n",
" # This call successively applies each of the\n",
" # Givens rotations stored in givens[:k, :] to H_row.\n",
" def apply_ith_rotation(i, H_row):\n",
" cs, sn = givens[i, :]\n",
" H_i = cs * H_row[i] - sn * H_row[i + 1]\n",
" H_ip1 = sn * H_row[i] + cs * H_row[i + 1]\n",
" H_row = H_row.at[i].set(H_i)\n",
" H_row = H_row.at[i + 1].set(H_ip1)\n",
" return H_row\n",
" R_row = lax.fori_loop(0, k, apply_ith_rotation, H_row)\n",
" def givens_rotation(v1, v2):\n",
" t = jnp.sqrt(v1**2 + v2**2)\n",
" cs = v1 / t\n",
" sn = -v2 / t\n",
" return cs, sn\n",
" givens_factors = givens_rotation(R_row[k], R_row[k + 1])\n",
" givens = givens.at[k, :].set(givens_factors)\n",
" cs_k, sn_k = givens_factors\n",
" R_row = R_row.at[k].set(cs_k * R_row[k] - sn_k * R_row[k + 1])\n",
" R_row = R_row.at[k + 1].set(0.)\n",
" return R_row, givens\n",
"def _gmres_qr(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):\n",
" \"\"\"\n",
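" Implements a single restart of GMRES. Starting from the unit residual r of\n",
" x0, the Krylov subspace span(r, (A M) r, ..., (A M)^(restart-1) r) is built\n",
" incrementally, stopping early once the estimated residual drops below\n",
" inner_tol, and the updated solution x0 + M(V y) is returned together with\n",
" its new unit residual and residual norm.\n",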
" \"\"\"\n",
" #\n",
" # residual = _sub(b, A(x0))\n",
" # unit_residual, beta = _safe_normalize(residual, return_norm=True)\n",
" V = tree_map(\n",
" lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),\n",
" unit_residual,\n",
" )\n",
" dtype = jnp.result_type(*tree_leaves(b))\n",
" R = jnp.eye(restart, restart + 1, dtype=dtype) # eye to avoid constructing\n",
" # a singular matrix in case\n",
" # of early termination.\n",
" b_norm = _norm_tree(b)\n",
" givens = jnp.zeros((restart, 2), dtype=dtype)\n",
" beta_vec = jnp.zeros((restart + 1), dtype=dtype)\n",
" beta_vec = beta_vec.at[0].set(residual_norm)\n",
" def loop_cond(carry):\n",
" k, err, _, _, _, _ = carry\n",
" return lax.cond(k < restart,\n",
" lambda x: x[0] > x[1],\n",
" lambda x: False,\n",
" (err, inner_tol))\n",
" # return k < restart and err > tol\n",
" def arnoldi_qr_step(carry):\n",
" k, residual_norm, V, R, beta_vec, givens = carry\n",
" V, H = kth_arnoldi_iteration(k, A, M, V, R, inner_tol)\n",
" R_row, givens = apply_givens_rotations(H[k, :], givens, k)\n",
" R = R.at[k, :].set(R_row[:])\n",
" cs, sn = givens[k, :] * beta_vec[k]\n",
" beta_vec = beta_vec.at[k].set(cs)\n",
" beta_vec = beta_vec.at[k + 1].set(sn)\n",
" err = jnp.abs(sn) / b_norm\n",
" return k + 1, err, V, R, beta_vec, givens\n",
" carry = (0, residual_norm, V, R, beta_vec, givens)\n",
" carry = lax.while_loop(loop_cond, arnoldi_qr_step, carry)\n",
" k, residual_norm, V, R, beta_vec, _ = carry\n",
" y = jsp.linalg.solve_triangular(R[:, :-1].T, beta_vec[:-1])\n",
" Vy = tree_map(lambda X: _dot(X[..., :-1], y), V)\n",
" dx = M(Vy)\n",
" x = _add(x0, dx)\n",
" residual = _sub(b, A(x))\n",
" unit_residual, residual_norm = _safe_normalize(residual, return_norm=True)\n",
" return x, unit_residual, residual_norm\n",
"def _gmres_fixed(A, b, x0, unit_residual, residual_norm, inner_tol, restart, M):\n",
" \"\"\"\n",
" Implements a single restart of GMRES. Starting from the unit residual r of\n",
" x0, the full Krylov subspace span(r, (A M) r, ..., (A M)^(restart-1) r) is\n",
" always built (via lax.scan, with no early termination), and the updated\n",
" solution x0 + M(V y) is returned together with its new unit residual and\n",
" residual norm.\n",
" \"\"\"\n",
" #\n",
" V = tree_map(\n",
" lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),\n",
" unit_residual,\n",
" )\n",
" dtype = jnp.result_type(*tree_leaves(b))\n",
" H = jnp.eye(restart, restart + 1, dtype=dtype)\n",
" def arnoldi_for_scan(carry, k):\n",
" V, H = carry\n",
" V, H = kth_arnoldi_iteration(k, A, M, V, H, inner_tol)\n",
" return (V, H), None\n",
" (V, H), _ = lax.scan(arnoldi_for_scan, (V, H), jnp.arange(restart))\n",
" beta_vec = jnp.zeros((restart,), dtype=dtype)\n",
" beta_vec = beta_vec.at[0].set(residual_norm) # it really is the original value\n",
" y = jsp.linalg.solve(H[:, :-1].T, beta_vec)\n",
" Vy = tree_map(lambda X: _dot(X[..., :-1], y), V)\n",
" dx = M(Vy)\n",
" x = _add(x0, dx)\n",
" residual = _sub(b, A(x))\n",
" unit_residual, residual_norm = _safe_normalize(residual, return_norm=True)\n",
" return x, unit_residual, residual_norm\n",
"def _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,\n",
" gmres_func):\n",
" \"\"\"\n",
" The main function call wrapped by custom_linear_solve. Repeatedly calls GMRES\n",
" to find the projected solution within the order-``restart``\n",
" Krylov space K(A, x0, restart), using the result of the previous projection\n",
" in place of x0 each time.\n",
" \"\"\"\n",
" residual = _sub(b, A(x0))\n",
" unit_residual, residual_norm = _safe_normalize(residual, return_norm=True)\n",
" def cond_fun(value):\n",
" _, k, _, residual_norm = value\n",
" return lax.cond(k < maxiter,\n",
" lambda x: x[0] > x[1],\n",
" lambda x: False,\n",
" (residual_norm, outer_tol))\n",
" def body_fun(value):\n",
" x, k, unit_residual, residual_norm = value\n",
" x, unit_residual, residual_norm = gmres_func(A, b, x, unit_residual,\n",
" residual_norm, inner_tol,\n",
" restart, M)\n",
" return x, k + 1, unit_residual, residual_norm\n",
" initialization = (x0, 0, unit_residual, residual_norm)\n",
" x_final, k, _, err = lax.while_loop(cond_fun, body_fun, initialization)\n",
" # info = lax.cond(converged, lambda y: 0, lambda y: k, 0)\n",
" return x_final # , info\n",
"def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,\n",
" M=None, fixed_iterations=False):\n",
" \"\"\"\n",
" GMRES solves the linear system A x = b for x, given A and b. A is specified\n",
" as a function performing A(vi) -> vf = A @ vi, and in principle need not have\n",
" any particular special properties, such as symmetry. However, convergence\n",
" is often slow for nearly symmetric operators.\n",
" Parameters\n",
" ----------\n",
" A: function\n",
" Function that calculates the linear map (e.g. matrix-vector product)\n",
" ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with the same\n",
" structure and shape as its argument.\n",
" b : array or tree of arrays\n",
" Right hand side of the linear system representing a single vector. Can be\n",
" stored as an array or Python container of array(s) with any shape.\n",
" Returns\n",
" -------\n",
" x : array or tree of arrays\n",
" The converged solution. Has the same structure as ``b``.\n",
" info : int\n",
" 0 if the returned solution contains no NaNs, -1 otherwise. This does not\n",
" indicate convergence; in the future, JAX will report the number of\n",
" iterations when convergence is not achieved, like SciPy.\n",
" Other Parameters\n",
" ----------------\n",
" x0 : array, optional\n",
" Starting guess for the solution. Must have the same structure as ``b``.\n",
" If this is unspecified, a (logical) vector of zeroes is used.\n",
" tol, atol : float, optional\n",
" Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.\n",
" We do not implement SciPy's \"legacy\" behavior, so JAX's tolerance will\n",
" differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``gmres``.\n",
" restart : integer, optional\n",
" Size of the Krylov subspace (\"number of iterations\") built between\n",
" restarts. GMRES works by approximating the true solution x as its\n",
" projection into a Krylov space of this dimension - this parameter\n",
" therefore bounds the maximum accuracy achievable from any guess\n",
" solution. Larger values increase both number of iterations and iteration\n",
" cost, but may be necessary for convergence. If fixed_iterations is\n",
" False, the algorithm terminates early if convergence is achieved before\n",
" the full subspace is built.\n",
" Default is 20.\n",
" maxiter : integer\n",
" Maximum number of iterations. If convergence has not been achieved\n",
" after projecting into the size-``restart`` Krylov space, GMRES will\n",
" try again, using the previous result as the new guess, up to this\n",
" many times. If the optimal solution within a Krylov space of the\n",
" given dimension is not converged up to the requested tolerance, these\n",
" restarts will not improve the accuracy, so care should be taken when\n",
" increasing this parameter.\n",
" M : function\n",
" Preconditioner for A. The preconditioner should approximate the\n",
" inverse of A. Effective preconditioning dramatically improves the\n",
" rate of convergence, which implies that fewer iterations are needed\n",
" to reach a given error tolerance.\n",
" fixed_iterations : bool\n",
" If False (the default), the algorithm builds the internal Krylov subspace\n",
" incrementally with a QR (Givens rotation) based algorithm, permitting early\n",
" termination of the inner `restart` loop if convergence is reached. Apart\n",
" from permitting early termination, this reduces overhead and may improve\n",
" stability. However, it may degrade performance significantly on GPUs or\n",
" TPUs, in which case this flag should be set True so that the full\n",
" subspace is always built with `lax.scan` and solved with a dense solve.\n",
" See also\n",
" --------\n",
" scipy.sparse.linalg.gmres\n",
" jax.lax.custom_linear_solve\n",
" \"\"\"\n",
" if x0 is None:\n",
" x0 = tree_map(jnp.zeros_like, b)\n",
" if M is None:\n",
" M = _identity\n",
" try:\n",
" size = sum(bi.size for bi in tree_leaves(b))\n",
" except AttributeError:\n",
" size = len(tree_leaves(b))\n",
" if maxiter is None:\n",
" maxiter = 10 * size # copied from scipy\n",
" restart = min(restart, size)\n",
" if tree_structure(x0) != tree_structure(b):\n",
" raise ValueError(\n",
" 'x0 and b must have matching tree structure: '\n",
" f'{tree_structure(x0)} vs {tree_structure(b)}')\n",
" b, x0 = device_put((b, x0))\n",
" b_norm = _norm_tree(b)\n",
" # if b_norm == 0:\n",
" # return b, 0\n",
" outer_tol = jnp.maximum(tol * b_norm, atol)\n",
" Mb = M(b)\n",
" Mb_norm = _norm_tree(Mb)\n",
" inner_tol = Mb_norm * jnp.minimum(1.0, outer_tol / b_norm)\n",
" if fixed_iterations:\n",
" def _solve(A, b):\n",
" return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,\n",
" _gmres_fixed)\n",
" else:\n",
" def _solve(A, b):\n",
" return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart, maxiter, M,\n",
" _gmres_qr)\n",
" x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)\n",
" failed = jnp.isnan(_norm_tree(x))\n",
" info = lax.cond(failed, lambda x: -1, lambda x: 0, 0)\n",
" return x, info\n"
"source": [
"restart = 2 * 32\n",
"maxiter = 1\n",
"N = 1024\n",
"def gmres_solve_fixed(A, b):\n",
" @jax.tree_util.Partial\n",
" def A_mv(x):\n",
" return A @ x\n",
" return gmres(A_mv, b, tol=0.0, atol=0.0, restart=restart, maxiter=maxiter,\n",
" fixed_iterations=True)\n",
"def gmres_solve_nonfixed(A, b):\n",
" @jax.tree_util.Partial\n",
" def A_mv(x):\n",
" return A @ x\n",
" return gmres(A_mv, b, tol=0.0, atol=0.0, restart=restart, maxiter=maxiter,\n",
" fixed_iterations=False)\n",
" \n",
"dtype = np.float64 # x64 is enabled above; use np.float32 for single precision\n",
"logical_size = N * N * jnp.finfo(dtype).bits * 1E-9 # Size in Gb (gigabits) \n",
"A = jnp.array(np.random.rand(N, N).astype(dtype))\n",
"b = jnp.array(np.random.rand(N).astype(dtype))\n"
"source": [
"%timeit gmres_solve_nonfixed(A, b)[0].block_until_ready()"
"source": [
"%timeit gmres_solve_fixed(A, b)[0].block_until_ready()"
