JAX enters an infinite loop for trust-ncg minimization
# %%
from functools import partial
from typing import Callable, NamedTuple, Optional, Tuple, Union

import jax
from jax import lax
from jax import numpy as jnp
# NOTE(review): N_RESET is not referenced anywhere in the visible code; it
# looks like a leftover (possibly a CG restart interval) — confirm or remove.
N_RESET = 20
class CGResults(NamedTuple):
    """Bundle of outputs from a conjugate-gradient style solve.

    NOTE(review): not constructed anywhere in the visible code; the field
    meanings below are inferred from their names — confirm with the caller.
    """

    x: jnp.ndarray  # presumably the solution estimate
    nit: Union[int, jnp.ndarray]  # presumably the iteration count
    nfev: Union[int, jnp.ndarray]  # number of matrix-evaluations
    info: Union[int, jnp.ndarray]  # status code; semantics not shown here
    success: Union[bool, jnp.ndarray]  # presumably whether the solve converged
# The following is code adapted from Nicholas Mancuso to work with pytrees
class _QuadSubproblemResult(NamedTuple):
step: jnp.ndarray
hits_boundary: Union[bool, jnp.ndarray]
pred_f: Union[float, jnp.ndarray]
nit: Union[int, jnp.ndarray]
nfev: Union[int, jnp.ndarray]
njev: Union[int, jnp.ndarray]
nhev: Union[int, jnp.ndarray]
success: Union[bool, jnp.ndarray]
class _CGSteihaugState(NamedTuple):
z: jnp.ndarray
r: jnp.ndarray
d: jnp.ndarray
step: jnp.ndarray
energy: Union[None, float, jnp.ndarray]
hits_boundary: Union[bool, jnp.ndarray]
done: Union[bool, jnp.ndarray]
nit: Union[int, jnp.ndarray]
nhev: Union[int, jnp.ndarray]
def second_order_approx(
    p: jnp.ndarray,
    cur_val: Union[float, jnp.ndarray],
    g: jnp.ndarray,
    hessp_at_xk: Callable[[jnp.ndarray], jnp.ndarray],
) -> Union[float, jnp.ndarray]:
    """Evaluate the second-order Taylor model ``cur_val + <g, p> + 0.5 <p, H p>``.

    Parameters
    ----------
    p:
        Displacement from the expansion point.
    cur_val:
        Objective value at the expansion point.
    g:
        Gradient at the expansion point.
    hessp_at_xk:
        Callable mapping a vector ``v`` to the Hessian-vector product ``H v``
        at the expansion point. It is called exactly once per evaluation.

    Returns
    -------
    The scalar value of the quadratic model at ``p``.
    """
    return cur_val + jnp.vdot(g, p) + 0.5 * jnp.vdot(p, hessp_at_xk(p))
def get_boundaries_intersections(
    z: jnp.ndarray, d: jnp.ndarray, trust_radius: Union[float, jnp.ndarray]
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the two ``t`` for which ``||z + t d||_2 == trust_radius``.

    Solves ``a t**2 + b t + c = 0`` with ``a = <d, d>``, ``b = 2 <z, d>`` and
    ``c = <z, z> - trust_radius**2``. The roots are computed as
    ``-aux / (2a)`` and ``-2c / aux`` with ``aux = b + sign(b) sqrt(disc)``.
    This avoids subtracting two nearly equal numbers.

    Returns
    -------
    ``(ra, rb)`` with ``ra <= rb``.

    NOTE(review): for ``d == 0`` both ``a`` and ``aux`` vanish and the result
    is NaN/inf. Callers must not pass a zero direction.
    """
    a = jnp.vdot(d, d)
    b = 2 * jnp.vdot(z, d)
    c = jnp.vdot(z, z) - trust_radius**2
    sqrt_discriminant = jnp.sqrt(b * b - 4 * a * c)
    # Add the square root with the sign of `b` so no cancellation occurs.
    aux = b + jnp.copysign(sqrt_discriminant, b)
    ta = -aux / (2 * a)
    tb = -2 * c / aux
    ra = jnp.where(ta < tb, ta, tb)
    rb = jnp.where(ta < tb, tb, ta)
    return (ra, rb)
def _cg_steihaug_subproblem(
    cur_val: Union[float, jnp.ndarray],
    g: jnp.ndarray,
    hessp_at_xk: Callable[[jnp.ndarray], jnp.ndarray],
    *,
    trust_radius: Union[float, jnp.ndarray],
    tr_norm_ord: Union[None, int, float, jnp.ndarray] = None,
    resnorm: Optional[float],
    absdelta: Optional[float] = None,
    norm_ord: Union[None, int, float, jnp.ndarray] = None,
    miniter: Union[None, int] = None,
    maxiter: Union[None, int] = None,
) -> _QuadSubproblemResult:
    """Approximately solve the trust-region subproblem with CG-Steihaug.

    Minimizes the quadratic model ``cur_val + <g, p> + 0.5 <p, H p>`` over
    steps ``p`` inside the trust region. It runs conjugate gradients and stops
    on the first of:

    * non-positive curvature along the search direction: the step is moved to
      whichever boundary point has the lower model value;
    * the next CG iterate leaving the trust region: the step is moved to the
      boundary along the current direction;
    * an acceptance criterion for the next CG iterate: ``maxiter`` is reached,
      the residual vanishes, or one of the ``resnorm``/``absdelta`` tests
      passes.

    Parameters
    ----------
    cur_val:
        Objective value at the current iterate.
    g:
        Gradient at the current iterate.
    hessp_at_xk:
        Callable returning the Hessian-vector product at the current iterate.
    trust_radius:
        Radius of the trust region.
    tr_norm_ord:
        Norm order used to test whether a CG iterate left the trust region.
        Defaults to ``jnp.inf``.
    resnorm:
        If not None, accept the next CG iterate once the ``norm_ord``-norm of
        its residual is strictly below this value.
    absdelta:
        If not None, accept the next CG iterate once the model decrease is
        non-negative (up to round-off) and below ``absdelta``. This test only
        applies from iteration ``miniter`` onwards.
    norm_ord:
        Norm order of the residual test. Defaults to 2.
    miniter, maxiter:
        Minimum/maximum number of CG iterations. Default limits are derived
        from ``g.size``.

    Returns
    -------
    _QuadSubproblemResult
        ``nfev`` and ``njev`` are 0 because only ``hessp_at_xk`` is evaluated
        here. ``nhev`` includes the final model evaluation at the step.

    NOTE(review): the trust-region test uses ``tr_norm_ord``, which defaults
    to the inf-norm. ``get_boundaries_intersections`` instead solves for the
    2-norm sphere. So for ``tr_norm_ord != 2``, boundary steps need not lie
    exactly on the ``tr_norm_ord`` ball — confirm this is intended.
    """
    tr_norm_ord = jnp.inf if tr_norm_ord is None else tr_norm_ord  # taken from JAX
    norm_ord = 2 if norm_ord is None else norm_ord  # TODO: change to 1

    maxiter_fallback = 20 * g.size  # taken from SciPy's NewtonCG minimizer
    # `miniter` must be derived from the *user-supplied* `maxiter`, so it is
    # resolved first.
    if miniter is None:
        miniter = jnp.minimum(
            6, maxiter if maxiter is not None else maxiter_fallback
        )
    if maxiter is None:
        maxiter = jnp.maximum(jnp.minimum(200, maxiter_fallback), miniter)

    common_dtp = g.dtype
    eps = 6. * jnp.finfo(common_dtp).eps

    # Second-order Taylor approximation at the current value, gradient and
    # Hessian. Each call costs one Hessian-vector product.
    soa = partial(
        second_order_approx, cur_val=cur_val, g=g, hessp_at_xk=hessp_at_xk
    )

    # Branches of the `lax.switch` below. All take `(state, z_next)`.
    def _continue(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # No stopping condition triggered; keep iterating.
        iterp, _ = param
        return iterp

    def _negative_curvature(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # Non-positive curvature: go to the better of the two boundary points.
        iterp, _ = param
        z, d, nhev = iterp.z, iterp.d, iterp.nhev
        ta, tb = get_boundaries_intersections(z, d, trust_radius)
        pa = z + ta * d
        pb = z + tb * d
        p_boundary = jnp.where(soa(pa) < soa(pb), pa, pb)
        # Two `soa` calls -> two Hessian-vector products.
        return iterp._replace(
            step=p_boundary, nhev=nhev + 2, hits_boundary=True, done=True
        )

    def _exceeds_trust_region(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # Next iterate leaves the trust region: stop on the boundary along `d`.
        iterp, _ = param
        z, d = iterp.z, iterp.d
        _, tb = get_boundaries_intersections(z, d, trust_radius)
        p_boundary = z + tb * d
        return iterp._replace(step=p_boundary, hits_boundary=True, done=True)

    def _accept(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # Converged (or out of iterations): accept the next CG iterate.
        iterp, z_next = param
        return iterp._replace(step=z_next, hits_boundary=False, done=True)

    p_origin = jnp.zeros_like(g)
    init_param = _CGSteihaugState(
        z=p_origin,
        r=g,
        d=-g,
        step=p_origin,
        energy=jnp.zeros((), dtype=common_dtp),
        hits_boundary=False,
        # With a zero gradient the origin already minimizes the model, and the
        # first CG step would compute 0/0, so stop right away.
        done=jnp.logical_or(maxiter == 0, jnp.all(g == 0)),
        nit=0,
        nhev=0,
    )

    def body_f(iterp: _CGSteihaugState) -> _CGSteihaugState:
        """Perform one CG iteration and decide whether to stop."""
        z, r, d = iterp.z, iterp.r, iterp.d
        energy, nit = iterp.energy, iterp.nit
        nit += 1

        Bd = hessp_at_xk(d)
        dBd = jnp.vdot(d, Bd)
        r_squared = jnp.vdot(r, r)
        alpha = r_squared / dBd
        z_next = z + alpha * d
        r_next = r + alpha * Bd
        r_next_squared = jnp.vdot(r_next, r_next)
        beta_next = r_next_squared / r_squared
        d_next = -r_next + beta_next * d

        accept_z_next = nit >= maxiter
        # An exactly vanishing residual means `z_next` solves the model. The
        # next iteration would otherwise divide 0 by 0 and produce NaNs.
        accept_z_next |= r_next_squared == 0
        if resnorm is not None:
            if norm_ord == 2:
                r_next_norm = jnp.sqrt(r_next_squared)
            else:
                r_next_norm = jnp.linalg.norm(r_next, ord=norm_ord)
            accept_z_next |= r_next_norm < resnorm

        # Relative to a plain CG, `z_next` is negative.
        # With r = g + B z, this equals <g, z> + 0.5 <z, B z>.
        energy_next = jnp.vdot((r_next + g) / 2, z_next)
        energy_diff = energy - energy_next
        if absdelta is not None:
            neg_energy_eps = -eps * jnp.abs(energy)
            accept_z_next |= (
                (energy_diff >= neg_energy_eps)
                & (energy_diff < absdelta)
                & (nit >= miniter)
            )

        z_next_norm = jnp.linalg.norm(z_next, ord=tr_norm_ord)
        # `argmax` picks the first True condition. The leading False is a
        # sentinel selecting `_continue` when no condition holds. The list must
        # be wrapped in an array: `jnp.argmax` rejects plain Python lists.
        index = jnp.argmax(
            jnp.array(
                [False, dBd <= 0, z_next_norm >= trust_radius, accept_z_next]
            )
        )
        iterp = lax.switch(
            index,
            [_continue, _negative_curvature, _exceeds_trust_region, _accept],
            (iterp, z_next),
        )
        # Advance the whole carry. If `nit` (and z/r/d/energy) were not written
        # back, `nit >= maxiter` could never fire and the loop would not end.
        return iterp._replace(
            z=z_next,
            r=r_next,
            d=d_next,
            energy=energy_next,
            nhev=iterp.nhev + 1,
            nit=nit,
        )

    def cond_f(iterp: _CGSteihaugState) -> bool:
        return jnp.logical_not(iterp.done)

    # Inner optimization: solve the constrained quadratic subproblem with CG.
    result = lax.while_loop(cond_f, body_f, init_param)
    pred_f = soa(result.step)
    return _QuadSubproblemResult(
        step=result.step,
        hits_boundary=result.hits_boundary,
        pred_f=pred_f,
        nit=result.nit,
        nfev=0,
        njev=0,
        # +1 for the `soa` call that produced `pred_f`.
        nhev=result.nhev + 1,
        # The loop only exits with `done` set; no failure mode is reported.
        success=True,
    )
def rosenbrock(np):
    """Return the N-dimensional Rosenbrock function using the array module ``np``.

    ``f(x) = sum_i 100 (x[i+1] - x[i]**2)**2 + (1 - x[i])**2``. Its global
    minimum is 0, at ``x = (1, ..., 1)``.

    Parameters
    ----------
    np:
        Array namespace providing ``sum`` (e.g. ``numpy`` or ``jax.numpy``).
        Passing the module in lets the same test function run under either
        backend.
    """
    def func(x):
        # The previous `np.diff(x)**2` computed (x[i+1] - x[i])**2, which is
        # not the Rosenbrock coupling term.
        return np.sum(100. * (x[1:] - x[:-1]**2)**2 + (1. - x[:-1])**2)
    return func
def himmelblau(np):
    """Return Himmelblau's two-dimensional test function.

    ``f(x, y) = (x**2 + y - 11)**2 + (x + y**2 - 7)**2``. The ``np`` argument
    is accepted for signature parity with the other factories but is unused.
    """
    def func(p):
        x, y = p
        first = x**2 + y - 11.
        second = x + y**2 - 7.
        return first**2 + second**2
    return func
def matyas(np):
    """Return the two-dimensional Matyas test function.

    ``f(x, y) = 0.26 (x**2 + y**2) - 0.48 x y``. The ``np`` argument is accepted
    for signature parity with the other factories but is unused.
    """
    def func(p):
        x, y = p
        sum_of_squares = x**2 + y**2
        return 0.26 * sum_of_squares - 0.48 * x * y
    return func
def eggholder(np):
    """Return the two-dimensional Eggholder test function.

    ``f(x, y) = -(y + 47) sin(sqrt|x/2 + y + 47|) - x sin(sqrt|x - (y + 47)|)``.

    Parameters
    ----------
    np:
        Array namespace providing ``sin``, ``sqrt`` and ``abs`` (e.g. ``numpy``
        or ``jax.numpy``). The original hard-coded ``jnp`` and ignored this
        argument.
    """
    def func(p):
        x, y = p
        first = -(y + 47.) * np.sin(np.sqrt(np.abs(x / 2. + y + 47.)))
        second = x * np.sin(np.sqrt(np.abs(x - (y + 47.))))
        return first - second
    return func
def hessp(primals, tangents):
    """Return the Hessian-vector product of the module-level ``fun``.

    Forward-mode differentiates ``jax.grad(fun)`` at ``primals`` in the
    direction ``tangents``. ``fun`` is looked up as a module global at call
    time, so it must be assigned before the first call.
    """
    gradient_of_fun = jax.grad(fun)
    _, hvp = jax.jvp(gradient_of_fun, (primals,), (tangents,))
    return hvp
# Objective for the reproduction. `hessp` reads this module-level name.
fun = eggholder(jnp)
x0 = jnp.ones(2) * 100.
f0, g0 = jax.value_and_grad(fun)(x0)
# resnorm=0 and absdelta=0 disable both early-acceptance tests, so the
# subproblem ends only on curvature, trust-region or `maxiter` conditions.
kwargs = {
    "absdelta": 0.,
    "resnorm": 0.,
    "trust_radius": 1.,
    "norm_ord": 1,
}
# Left as a bare expression so a notebook cell (`# %%`) displays the result.
_cg_steihaug_subproblem(f0, g0, partial(hessp, x0), **kwargs)
