|
"""Versions based on 4.60 and 4.63 of https://arxiv.org/pdf/1701.00392.pdf .""" |
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
|
|
|
|
def _T(x): |
|
return jnp.swapaxes(x, -1, -2) |
|
|
|
|
|
def _H(x): |
|
return jnp.conj(_T(x)) |
|
|
|
|
|
def symmetrize(x): |
|
return (x + _H(x)) / 2 |
|
|
|
|
|
def standardize_angle(w, b): |
|
if jnp.isrealobj(w): |
|
return w * jnp.sign(w[0, :]) |
|
else: |
|
# scipy does this: makes imag(b[0] @ w) = 1 |
|
assert not jnp.isrealobj(b) |
|
bw = b[0] @ w |
|
factor = bw / jnp.abs(bw) |
|
w = w / factor[None, :] |
|
sign = jnp.sign(w.real[0]) |
|
w = w * sign |
|
return w |
|
|
|
|
|
@jax.custom_jvp # jax.scipy.linalg.eigh doesn't support general problem i.e. b not None |
|
def eigh(a, b): |
|
""" |
|
Compute the solution to the symmetrized generalized eigenvalue problem. |
|
|
|
a_s @ w = b_s @ w @ np.diag(v) |
|
|
|
where a_s = (a + a.H) / 2, b_s = (b + b.H) / 2 are the symmetrized versions of the |
|
inputs and H is the Hermitian (conjugate transpose) operator. |
|
|
|
For self-adjoint inputs the solution should be consistent with `scipy.linalg.eigh` |
|
i.e. |
|
|
|
```python |
|
v, w = eigh(a, b) |
|
v_sp, w_sp = scipy.linalg.eigh(a, b) |
|
np.testing.assert_allclose(v, v_sp) |
|
np.testing.assert_allclose(w, standardize_angle(w_sp)) |
|
``` |
|
|
|
Note this currently uses `jax.linalg.eig(jax.linalg.solve(b, a))`, which will be |
|
slow because there is no GPU implementation of `eig` and it's just a generally |
|
inefficient way of doing it. Future implementations should wrap cuda primitives. |
|
This implementation is provided primarily as a means to test `eigh_jvp_rule`. |
|
|
|
Args: |
|
a: [n, n] float self-adjoint matrix (i.e. conj(transpose(a)) == a) |
|
b: [n, n] float self-adjoint matrix (i.e. conj(transpose(b)) == b) |
|
|
|
Returns: |
|
v: eigenvalues of the generalized problem in ascending order. |
|
w: eigenvectors of the generalized problem, normalized such that |
|
w.H @ b @ w = I. |
|
""" |
|
a = symmetrize(a) |
|
b = symmetrize(b) |
|
b_inv_a = jax.scipy.linalg.cho_solve(jax.scipy.linalg.cho_factor(b), a) |
|
v, w = jax.jit(jax.numpy.linalg.eig, backend="cpu")(b_inv_a) |
|
v = v.real |
|
# with loops.Scope() as s: |
|
# for _ in s.cond_range(jnp.isrealobj) |
|
if jnp.isrealobj(a) and jnp.isrealobj(b): |
|
w = w.real |
|
# reorder as ascending in w |
|
order = jnp.argsort(v) |
|
v = v.take(order, axis=0) |
|
w = w.take(order, axis=1) |
|
# renormalize so v.H @ b @ H == 1 |
|
norm2 = jax.vmap(lambda wi: (wi.conj() @ b @ wi).real, in_axes=1)(w) |
|
norm = jnp.sqrt(norm2) |
|
w = w / norm |
|
w = standardize_angle(w, b) |
|
return v, w |
|
|
|
|
|
@eigh.defjvp |
|
def eigh_jvp_rule(primals, tangents): |
|
""" |
|
Derivation based on Boedekker et al. |
|
|
|
https://arxiv.org/pdf/1701.00392.pdf |
|
|
|
Note diagonal entries of Winv dW/dt != 0 as they claim. |
|
""" |
|
a, b = primals |
|
da, db = tangents |
|
if not all(jnp.isrealobj(x) for x in (a, b, da, db)): |
|
raise NotImplementedError("jvp only implemented for real inputs.") |
|
da = symmetrize(da) |
|
db = symmetrize(db) |
|
|
|
v, w = eigh(a, b) |
|
|
|
# compute only the diagonal entries |
|
dv = jax.vmap( |
|
lambda vi, wi: -wi.conj() @ db @ wi * vi + wi.conj() @ da @ wi, in_axes=(0, 1), |
|
)(v, w) |
|
|
|
dv = dv.real |
|
|
|
E = v[jnp.newaxis, :] - v[:, jnp.newaxis] |
|
|
|
# diagonal entries: compute as column then put into diagonals |
|
diags = jnp.diag(-0.5 * jax.vmap(lambda wi: wi.conj() @ db @ wi, in_axes=1)(w)) |
|
# off-diagonals: there will be NANs on the diagonal, but these aren't used |
|
off_diags = jnp.reciprocal(E) * (_H(w) @ (da @ w - db @ w * v[jnp.newaxis, :])) |
|
|
|
dw = w @ jnp.where(jnp.eye(a.shape[0], dtype=np.bool), diags, off_diags) |
|
|
|
return (v, w), (dv, dw) |