Skip to content

Instantly share code, notes, and snippets.

@jackd
Last active January 24, 2021 14:43
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jackd/99e012090a56637b8dd8bb037374900e to your computer and use it in GitHub Desktop.
Save jackd/99e012090a56637b8dd8bb037374900e to your computer and use it in GitHub Desktop.
Generalized eigenvalue jvp implementation in jax

After installing jax and scipy, run with:

git clone https://gist.github.com/jackd/99e012090a56637b8dd8bb037374900e
cd 99e012090a56637b8dd8bb037374900e
python dirty_test.py
from eigh_impl import symmetrize, eigh, standardize_angle
import jax.numpy as jnp
import numpy as np
import jax.test_util as jtu
import scipy.linalg
# Print arrays with 3 digits of precision so the eigenvector dumps below are readable.
jnp.set_printoptions(3)
# Seeded generator so the random test matrices are reproducible between runs.
rng = np.random.default_rng(0)
# Side length of the square test matrices.
n = 5
# Keep the test problem real: the jvp rule of `eigh` raises NotImplementedError
# for complex inputs, and `check_grads` below exercises that rule.
is_complex = False
def make_spd(x):
    """Return the symmetrized `x` with `n` added to its diagonal.

    `n` is the leading dimension of `x`. The diagonal shift is intended to
    make the symmetrized matrix positive definite for the small random test
    matrices used in this script.
    """
    size = x.shape[0]
    diagonal_shift = size * jnp.eye(size)
    return symmetrize(x) + diagonal_shift
def get_random_square(rng, size, is_complex=True):
    """Draw a float32 (or complex64) array of uniform samples of shape `size`.

    Args:
        rng: a numpy random `Generator` (anything exposing `uniform(size=...)`).
        size: shape of the array to draw.
        is_complex: when true, a second, independent uniform draw is used as
            the imaginary part.

    Returns:
        The sampled array; real parts are always drawn before imaginary parts.
    """
    real_part = rng.uniform(size=size).astype(np.float32)
    if not is_complex:
        return real_part
    imag_part = rng.uniform(size=size).astype(np.float32)
    return real_part + imag_part * 1j
# Build two random test matrices (both go through make_spd) and solve the
# generalized problem a @ w = b @ w @ diag(v).
a = make_spd(get_random_square(rng, (n, n), is_complex=is_complex))
b = make_spd(get_random_square(rng, (n, n), is_complex=is_complex))
vals, vecs = eigh(a, b)
# ensure the solution satisfies the generalized eigenvalue equation
np.testing.assert_allclose(a @ vecs, b @ vecs @ jnp.diag(vals), atol=1e-5)
# ensure eigenvectors are orthonormal w.r.t. b, i.e. vecs.H @ b @ vecs == I
np.testing.assert_allclose(vecs.T.conj() @ b @ vecs, jnp.eye(n), atol=1e-5, rtol=1e-5)
# ensure eigenvalues are strictly ascending
np.testing.assert_array_less(vals[:-1], vals[1:])
# check the custom jvp rule against numerical derivatives, up to 2nd order,
# forward mode only
jtu.check_grads(eigh, (a, b), 2, modes=["fwd"])
# ensure values are consistent with scipy's generalized solver
vals_sp, vecs_sp = scipy.linalg.eigh(a, b)
# print both eigenvector matrices for visual comparison
print("scipy")
print(vecs_sp)
print("this work")
print(vecs)
np.testing.assert_allclose(vals, vals_sp, rtol=1e-4, atol=1e-5)
# eigenvectors are only defined up to sign/phase, so bring scipy's result into
# the same convention before comparing
np.testing.assert_allclose(vecs, standardize_angle(vecs_sp, b), rtol=1e-4, atol=1e-5)
print("success")
"""Versions based on 4.60 and 4.63 of https://arxiv.org/pdf/1701.00392.pdf ."""
import jax
import jax.numpy as jnp
import numpy as np
def _T(x):
    """Transpose the last two axes of `x` (batched matrix transpose)."""
    return jnp.swapaxes(x, -2, -1)


def _H(x):
    """Conjugate transpose of `x` over its last two axes."""
    transposed = _T(x)
    return jnp.conj(transposed)


def symmetrize(x):
    """Return the Hermitian part of `x`, i.e. the average of `x` and its conjugate transpose."""
    doubled = x + _H(x)
    return doubled / 2
def _sign_or_one(x):
    """Return `jnp.sign(x)` with zeros replaced by one."""
    sign = jnp.sign(x)
    return jnp.where(sign == 0, jnp.ones_like(sign), sign)


def standardize_angle(w, b):
    """
    Remove the sign / phase ambiguity of each eigenvector (column) of `w`.

    For real `w`, every column is multiplied by the sign of its first entry,
    so that entry becomes non-negative.

    For complex `w`, every column is first divided by the unit-modulus phase
    of the corresponding entry of `b[0] @ w`, which makes that entry real and
    non-negative. It is then multiplied by the sign of the real part of its
    first entry. This is meant to mirror the convention `scipy.linalg.eigh`
    returns for the generalized problem.

    A column whose reference entry is exactly zero keeps its sign: multiplying
    by `sign(0) == 0` would otherwise wipe out the entire column.

    Args:
        w: [n, n] matrix whose columns are eigenvectors.
        b: [n, n] right-hand-side matrix of the generalized problem. It is
            only used, and must be complex, when `w` is complex.

    Returns:
        `w` with each column rescaled by a unit-modulus factor.
    """
    if jnp.isrealobj(w):
        return w * _sign_or_one(w[0, :])

    assert not jnp.isrealobj(b)
    # rotate each column so that (b[0] @ w) is real and non-negative
    bw = b[0] @ w
    factor = bw / jnp.abs(bw)
    w = w / factor[None, :]
    # then fix the remaining +/- freedom from the real part of the first row
    return w * _sign_or_one(w.real[0])
@jax.custom_jvp  # jax.scipy.linalg.eigh doesn't support general problem i.e. b not None
def eigh(a, b):
    """
    Solve the symmetrized generalized eigenvalue problem

        a_s @ w = b_s @ w @ np.diag(v)

    where a_s = (a + a.H) / 2 and b_s = (b + b.H) / 2 are the Hermitian parts
    of the inputs (H denotes the conjugate transpose).

    For self-adjoint inputs the result is intended to agree with
    `scipy.linalg.eigh`, i.e.

    ```python
    v, w = eigh(a, b)
    v_sp, w_sp = scipy.linalg.eigh(a, b)
    np.testing.assert_allclose(v, v_sp)
    np.testing.assert_allclose(w, standardize_angle(w_sp, b))
    ```

    Implementation note: `inv(b_s) @ a_s` is formed with a Cholesky solve and
    handed to the general (non-symmetric) `eig`, which is jitted for the CPU
    backend. This is an inefficient way to solve the problem; future
    implementations should wrap dedicated primitives. It is provided primarily
    as a means to test `eigh_jvp_rule`.

    Args:
        a: [n, n] float self-adjoint matrix (i.e. conj(transpose(a)) == a)
        b: [n, n] float self-adjoint matrix (i.e. conj(transpose(b)) == b),
            factorized with Cholesky.

    Returns:
        v: eigenvalues of the generalized problem in ascending order.
        w: eigenvectors (as columns) of the generalized problem, normalized
            such that w.H @ b_s @ w = I and passed through `standardize_angle`.
    """
    a_s = symmetrize(a)
    b_s = symmetrize(b)

    # reduce to the standard (non-symmetric) problem  inv(b_s) @ a_s @ w = w @ diag(v)
    b_factor = jax.scipy.linalg.cho_factor(b_s)
    reduced = jax.scipy.linalg.cho_solve(b_factor, a_s)
    cpu_eig = jax.jit(jnp.linalg.eig, backend="cpu")
    v, w = cpu_eig(reduced)

    # `eig` always returns complex arrays; keep only the real part of the
    # eigenvalues, and of the eigenvectors when both inputs are real
    v = v.real
    inputs_are_real = jnp.isrealobj(a_s) and jnp.isrealobj(b_s)
    if inputs_are_real:
        w = w.real

    # sort eigenpairs so the eigenvalues are ascending
    ascending = jnp.argsort(v)
    v = jnp.take(v, ascending, axis=0)
    w = jnp.take(w, ascending, axis=1)

    # rescale every column so that w.H @ b_s @ w has a unit diagonal
    def squared_b_norm(column):
        return (column.conj() @ b_s @ column).real

    w = w / jnp.sqrt(jax.vmap(squared_b_norm, in_axes=1)(w))
    return v, standardize_angle(w, b_s)
@eigh.defjvp
def eigh_jvp_rule(primals, tangents):
    """
    Forward-mode derivative (jvp) rule for the generalized `eigh`.

    Derivation based on Boeddeker et al.
    https://arxiv.org/pdf/1701.00392.pdf

    Note diagonal entries of Winv dW/dt != 0 as they claim.

    Args:
        primals: `(a, b)`, the inputs `eigh` was called with.
        tangents: `(da, db)`, the input perturbations. They are symmetrized
            before use, matching the symmetrization `eigh` applies to `a`, `b`.

    Returns:
        `((v, w), (dv, dw))`: the primal outputs of `eigh(a, b)` and their
        tangents.

    Raises:
        NotImplementedError: if any primal or tangent is complex.

    The off-diagonal terms divide by differences of eigenvalues, so the
    result is not finite when eigenvalues are repeated.
    """
    a, b = primals
    da, db = tangents
    if not all(jnp.isrealobj(x) for x in (a, b, da, db)):
        raise NotImplementedError("jvp only implemented for real inputs.")
    da = symmetrize(da)
    db = symmetrize(db)
    v, w = eigh(a, b)

    # eigenvalue tangents: dv_i = w_i.H @ (da - v_i * db) @ w_i
    dv = jax.vmap(
        lambda vi, wi: -wi.conj() @ db @ wi * vi + wi.conj() @ da @ wi,
        in_axes=(0, 1),
    )(v, w)
    dv = dv.real

    # builtin `bool` rather than `np.bool`, which was removed in NumPy 1.24
    on_diagonal = jnp.eye(a.shape[0], dtype=bool)

    # E[i, j] = v[j] - v[i]. Its diagonal is exactly zero and is never used,
    # so replace it by one before inverting; this keeps inf/NaN out of the
    # computation altogether instead of relying on the `where` below to
    # discard them.
    E = v[jnp.newaxis, :] - v[:, jnp.newaxis]
    E = jnp.where(on_diagonal, jnp.ones_like(E), E)

    # diagonal entries: compute as column then put into diagonals
    diags = jnp.diag(-0.5 * jax.vmap(lambda wi: wi.conj() @ db @ wi, in_axes=1)(w))
    # off-diagonal entries (the diagonal of this product is discarded)
    off_diags = jnp.reciprocal(E) * (_H(w) @ (da @ w - db @ w * v[jnp.newaxis, :]))
    dw = w @ jnp.where(on_diagonal, diags, off_diags)
    return (v, w), (dv, dw)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment