Created
November 3, 2022 14:05
-
-
Save PhilipVinc/34e15d4fc93ee58753c24b29a8fba56e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from jax._src.scipy.sparse.linalg import (_vdot_real_tree, _identity, _normalize_matvec, _shapes, _sub, _add, _mul, _vdot_tree) | |
from functools import partial | |
import operator | |
import numpy as np | |
import jax.numpy as jnp | |
from jax import device_put | |
from jax import lax | |
from jax import scipy as jsp | |
from jax.tree_util import (tree_leaves, tree_map, tree_structure, | |
tree_reduce, Partial,) | |
from jax._src import dtypes | |
from jax._src.lax import lax as lax_internal | |
from jax._src.util import safe_map as map | |
def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0, | |
maxiter=None, M=None, check_symmetric=False, has_aux=False): | |
if x0 is None: | |
x0 = tree_map(jnp.zeros_like, b) | |
b, x0 = device_put((b, x0)) | |
if maxiter is None: | |
size = sum(bi.size for bi in tree_leaves(b)) | |
maxiter = 10 * size # copied from scipy | |
if M is None: | |
M = _identity | |
A = _normalize_matvec(A) | |
M = _normalize_matvec(M) | |
if tree_structure(x0) != tree_structure(b): | |
raise ValueError( | |
'x0 and b must have matching tree structure: ' | |
f'{tree_structure(x0)} vs {tree_structure(b)}') | |
if _shapes(x0) != _shapes(b): | |
raise ValueError( | |
'arrays in x0 and b must have matching shapes: ' | |
f'{_shapes(x0)} vs {_shapes(b)}') | |
isolve_solve = partial( | |
_isolve_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M) | |
# real-valued positive-definite linear operators are symmetric | |
def real_valued(x): | |
return not issubclass(x.dtype.type, np.complexfloating) | |
symmetric = all(map(real_valued, tree_leaves(b))) \ | |
if check_symmetric else False | |
return lax.custom_linear_solve( | |
A, b, solve=isolve_solve, transpose_solve=isolve_solve, | |
symmetric=symmetric, has_aux=has_aux) | |
def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity): | |
# tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg | |
bs = _vdot_real_tree(b, b) | |
atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol)) | |
# https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method | |
def cond_fun(value): | |
_, r, gamma, _, k = value | |
rs = gamma.real if M is _identity else _vdot_real_tree(r, r) | |
return (rs > atol2) & (k < maxiter) | |
def body_fun(value): | |
x, r, gamma, p, k = value | |
Ap = A(p) | |
alpha = gamma / _vdot_real_tree(p, Ap).astype(dtype) | |
x_ = _add(x, _mul(alpha, p)) | |
r_ = _sub(r, _mul(alpha, Ap)) | |
z_ = M(r_) | |
gamma_ = _vdot_real_tree(r_, z_).astype(dtype) | |
beta_ = gamma_ / gamma | |
p_ = _add(z_, _mul(beta_, p)) | |
return x_, r_, gamma_, p_, k + 1 | |
r0 = _sub(b, A(x0)) | |
p0 = z0 = M(r0) | |
dtype = jnp.result_type(*tree_leaves(p0)) | |
gamma0 = _vdot_real_tree(r0, z0).astype(dtype) | |
initial_value = (x0, r0, gamma0, p0, 0) | |
x_final, r, gamma, _, k = lax.while_loop(cond_fun, body_fun, initial_value) | |
# compute the final error and whever it has converged. | |
rs = gamma if M is _identity else _vdot_real_tree(r, r) | |
converged = rs <= atol2 | |
# additional info output structure | |
info = {'error': rs, 'converged':converged, 'niter': k} | |
return x_final, info | |
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
  """Use Conjugate Gradient iteration to solve ``Ax = b``.

  The numerics of this ``cg`` should exactly match SciPy's ``cg`` (up to
  numerical precision), but note that the interface is slightly different: you
  need to supply the linear operator ``A`` as a function instead of a sparse
  matrix or ``LinearOperator``.

  Derivatives of ``cg`` are implemented via implicit differentiation with
  another ``cg`` solve, rather than by differentiating *through* the solver.
  They will be accurate only if both solves converge.

  Parameters
  ----------
  A: ndarray, function, or matmul-compatible object
      2D array or function that calculates the linear map (matrix-vector
      product) ``Ax`` when called like ``A(x)`` or ``A @ x``. ``A`` must represent
      a hermitian, positive definite matrix, and must return array(s) with the
      same structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : dict
      Convergence information with keys ``'error'`` (final squared residual
      norm), ``'converged'`` (True if the tolerance was reached) and
      ``'niter'`` (number of iterations performed).

  Other Parameters
  ----------------
  x0 : array or tree of arrays
      Starting guess for the solution. Must have the same structure as ``b``.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.
  maxiter : integer
      Maximum number of iterations. Iteration will stop after maxiter
      steps even if the specified tolerance has not been achieved.
  M : ndarray, function, or matmul-compatible object
      Preconditioner for A. The preconditioner should approximate the
      inverse of A. Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.

  See also
  --------
  scipy.sparse.linalg.cg
  jax.lax.custom_linear_solve
  """
  return _isolve(_cg_solve,
                 A=A, b=b, x0=x0, tol=tol, atol=atol,
                 maxiter=maxiter, M=M, check_symmetric=True, has_aux=True)
def _bicgstab_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
  """Preconditioned BiCGSTAB iteration on ``A x = b``.

  Returns ``(x, info)`` where ``info`` holds the final squared residual norm
  (``'error'``), a convergence flag (``'converged'``) and the iteration count
  (``'niter'``, negative when the iteration broke down).
  """
  # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.bicgstab
  bs = _vdot_real_tree(b, b)
  atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))

  # https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method#Preconditioned_BiCGSTAB

  def cond_fun(value):
    x, r, *_, k = value
    rs = _vdot_real_tree(r, r)
    # the last condition checks breakdown (body_fun sets k negative on breakdown)
    return (rs > atol2) & (k < maxiter) & (k >= 0)

  def body_fun(value):
    x, r, rhat, alpha, omega, rho, p, q, k = value
    rho_ = _vdot_tree(rhat, r)
    beta = rho_ / rho * alpha / omega
    p_ = _add(r, _mul(beta, _sub(p, _mul(omega, q))))
    phat = M(p_)
    q_ = A(phat)
    alpha_ = rho_ / _vdot_tree(rhat, q_)
    s = _sub(r, _mul(alpha_, q_))
    # If the half-step residual s is already small enough, take only the
    # half-step update below (jnp.where picks the first branch).
    exit_early = _vdot_real_tree(s, s) < atol2
    shat = M(s)
    t = A(shat)
    omega_ = _vdot_tree(t, s) / _vdot_tree(t, t)  # make cases?
    x_ = tree_map(partial(jnp.where, exit_early),
                  _add(x, _mul(alpha_, phat)),
                  _add(x, _add(_mul(alpha_, phat), _mul(omega_, shat)))
                  )
    r_ = tree_map(partial(jnp.where, exit_early),
                  s, _sub(s, _mul(omega_, t)))
    # Breakdown flags: -11 for omega/alpha breakdown, -10 for rho breakdown;
    # a negative k makes cond_fun terminate the loop.
    k_ = jnp.where((omega_ == 0) | (alpha_ == 0), -11, k + 1)
    k_ = jnp.where((rho_ == 0), -10, k_)
    return x_, r_, rhat, alpha_, omega_, rho_, p_, q_, k_

  r0 = _sub(b, A(x0))
  # Scalars rho/alpha/omega start at 1, in the result dtype of b's leaves.
  rho0 = alpha0 = omega0 = lax_internal._convert_element_type(
      1, *dtypes._lattice_result_type(*tree_leaves(b)))
  initial_value = (x0, r0, r0, alpha0, omega0, rho0, r0, r0, 0)

  x_final, r, *_, k = lax.while_loop(cond_fun, body_fun, initial_value)

  # compute the final error and whether it has converged.
  rs = _vdot_real_tree(r, r)
  converged = rs <= atol2

  # additional info output structure
  info = {'error': rs, 'converged':converged, 'niter': k}
  return x_final, info
def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
  """Use Bi-Conjugate Gradient Stable iteration to solve ``Ax = b``.

  The numerics of this ``bicgstab`` should exactly match SciPy's
  ``bicgstab`` (up to numerical precision), but note that the interface
  is slightly different: you need to supply the linear operator ``A`` as
  a function instead of a sparse matrix or ``LinearOperator``.

  As with ``cg``, derivatives of ``bicgstab`` are implemented via implicit
  differentiation with another ``bicgstab`` solve, rather than by
  differentiating *through* the solver. They will be accurate only if
  both solves converge.

  Parameters
  ----------
  A: ndarray, function, or matmul-compatible object
      2D array or function that calculates the linear map (matrix-vector
      product) ``Ax`` when called like ``A(x)`` or ``A @ x``. ``A`` can represent
      any general (nonsymmetric) linear operator, and function must return array(s)
      with the same structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : dict
      Convergence information with keys ``'error'`` (final squared residual
      norm), ``'converged'`` (True if the tolerance was reached) and
      ``'niter'`` (number of iterations performed; negative on breakdown).

  Other Parameters
  ----------------
  x0 : array or tree of arrays
      Starting guess for the solution. Must have the same structure as ``b``.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's
      ``bicgstab``.
  maxiter : integer
      Maximum number of iterations. Iteration will stop after maxiter
      steps even if the specified tolerance has not been achieved.
  M : ndarray, function, or matmul-compatible object
      Preconditioner for A. The preconditioner should approximate the
      inverse of A. Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.

  See also
  --------
  scipy.sparse.linalg.bicgstab
  jax.lax.custom_linear_solve
  """
  return _isolve(_bicgstab_solve,
                 A=A, b=b, x0=x0, tol=tol, atol=atol,
                 maxiter=maxiter, M=M, has_aux=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment