@packquickly
Created October 11, 2023 14:06
The damped trust region `search`
import functools as ft
from collections.abc import Callable
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import lineax as lx
import optimistix as optx
from equinox.internal import ω
from jaxtyping import Array, ArrayLike, Bool, Inexact, PyTree, Scalar, ScalarLike
from optimistix import FunctionInfo


def _default_floating_dtype():
    if jax.config.jax_enable_x64:  # pyright: ignore
        return jnp.float64
    else:
        return jnp.float32


def _tree_dot(tree1: PyTree[Array], tree2: PyTree[Array]) -> Inexact[Array, ""]:
    """Compute the dot product of two PyTrees with the same PyTree structure."""
    leaves1, treedef1 = jtu.tree_flatten(tree1)
    leaves2, treedef2 = jtu.tree_flatten(tree2)
    if treedef1 != treedef2:
        raise ValueError("Trees must have the same PyTree structure.")
    assert len(leaves1) == len(leaves2)
    dots = []
    for leaf1, leaf2 in zip(leaves1, leaves2):
        dots.append(
            jnp.dot(
                jnp.reshape(leaf1, -1),
                jnp.conj(leaf2).reshape(-1),
                precision=jax.lax.Precision.HIGHEST,  # pyright: ignore
            )
        )
    if len(dots) == 0:
        return jnp.array(0, _default_floating_dtype())
    else:
        return ft.reduce(jnp.add, dots)


def _sum_squares(x: PyTree[ArrayLike]) -> Scalar:
    return _tree_dot(x, x).real
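

# For example (hypothetical inputs): `_tree_dot` contracts each pair of leaves and
# sums the results, so
#     _tree_dot({"a": jnp.ones(3), "b": jnp.ones((2, 2))},
#               {"a": jnp.ones(3), "b": jnp.ones((2, 2))})
# evaluates to 3.0 + 4.0 = 7.0.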


class _DampedTrustRegionState(eqx.Module):
    # The identity is in the state only to avoid reinstantiating it during
    # optimisation.
    step_size: Scalar
    identity: lx.IdentityLinearOperator


# WARNING: this search only makes theoretical sense when used with the damped Newton
# descent. If I (packquickly) were implementing it for something more serious, I would
# make it "private" (add a leading underscore) to prevent the user from mistakenly
# using it in an unsupported way.
# NOTE: this is taken and only slightly modified from `_AbstractTrustRegion` in
# Optimistix.
class DampedTrustRegion(optx.AbstractSearch):
    """The damped trust-region update algorithm.

    Trust-region searches compute the ratio
    `true_reduction/predicted_reduction`, where `true_reduction` is the decrease in
    `fn` between `y` and `new_y`, and `predicted_reduction` is how much we expected
    the function to decrease using an approximation to `fn`.

    The trust-region ratio determines whether to accept or reject a step and the
    next choice of step-size. Specifically:

    - reject the step and decrease the step-size if the ratio is smaller than a
      cutoff `low_cutoff`;
    - accept the step and increase the step-size if the ratio is greater than
      another cutoff `high_cutoff`, with `low_cutoff < high_cutoff`;
    - else, accept the step and make no change to the step-size.

    This differs from a classical trust-region update in that `predicted_reduction`
    is computed using a damped model function.
    """
    high_cutoff: ScalarLike = 0.99
    low_cutoff: ScalarLike = 0.01
    high_constant: ScalarLike = 3.5
    low_constant: ScalarLike = 0.25

    def __post_init__(self):
        # You would not expect `self.low_cutoff` or `self.high_cutoff` to
        # be below zero, but this is technically not incorrect so we don't
        # require it.
        self.low_cutoff, self.high_cutoff = eqx.error_if(  # pyright: ignore
            (self.low_cutoff, self.high_cutoff),
            self.low_cutoff > self.high_cutoff,  # pyright: ignore
            "`low_cutoff` must be below `high_cutoff` in `DampedTrustRegion`",
        )
        self.low_constant = eqx.error_if(  # pyright: ignore
            self.low_constant,
            self.low_constant < 0,  # pyright: ignore
            "`low_constant` must be greater than `0` in `DampedTrustRegion`",
        )
        self.high_constant = eqx.error_if(  # pyright: ignore
            self.high_constant,
            self.high_constant < 0,  # pyright: ignore
            "`high_constant` must be greater than `0` in `DampedTrustRegion`",
        )

    # NOTE: this is taken and only slightly modified from the `ClassicalTrustRegion`
    # implementation in Optimistix.
    def predict_reduction(
        self,
        y_diff: PyTree[Array],
        f_info: optx.FunctionInfo,
        state: _DampedTrustRegionState,
    ) -> Scalar:
        """Compute the expected decrease in loss from taking the step `y_diff`.

        The true reduction is
        ```
        fn(y0 + y_diff) - fn(y0)
        ```
        so if `B` is the approximation to the Hessian coming from the quasi-Newton
        method at `y`, `g` is the gradient of `fn` at `y`, and `δ` is the step-size,
        then the predicted reduction is
        ```
        g^T y_diff + 1/2 y_diff^T (B + (1/δ) I) y_diff
        ```
        where `I` is the identity.

        **Arguments**:

        - `y_diff`: the step proposed by the descent method.
        - `f_info`: the derivative information (on the gradient and Hessian)
            provided by the outer loop.
        - `state`: the state of the search.

        **Returns**:

        The expected decrease in loss from moving from `y0` to `y0 + y_diff`.
        """
        pred = state.step_size > jnp.finfo(state.step_size.dtype).eps
        safe_step_size = jnp.where(pred, state.step_size, 1)
        lm_param = jnp.where(
            pred, 1 / safe_step_size, jnp.finfo(state.step_size.dtype).max
        )
        if isinstance(f_info, FunctionInfo.EvalGradHessian):
            # Minimisation algorithm. Directly compute the quadratic approximation.
            damped_hessian = f_info.hessian + lm_param * state.identity
            return _tree_dot(
                y_diff,
                (f_info.grad**ω + 0.5 * damped_hessian.mv(y_diff) ** ω).ω,
            )
        elif isinstance(f_info, FunctionInfo.ResidualJac):
            # Least-squares algorithm. So instead of considering `fn` (which returns
            # the residuals), consider `0.5*fn(y)^2`, and then apply the same logic
            # as for minimisation.
            # We get that `g = J^T f0` and `B = J^T J + dJ/dy^T f0`.
            # (Here, `f0 = fn(y0)` are the residuals, and `J = dfn/dy(y0)` is the
            # Jacobian of the residuals with respect to `y`.)
            # Then neglect the second term in `B` (the usual Gauss--Newton
            # approximation) and complete the square.
            # We find that the (undamped) predicted reduction is
            # `0.5 * ((J y_diff + f0)^T (J y_diff + f0) - f0^T f0)`
            # and this, plus the damping term, is what is written below.
            #
            # The reason we go through this hassle is that this now involves only
            # a single Jacobian-vector product, rather than the three we would have
            # to make by naively substituting `B = J^T J` and `g = J^T f0` into the
            # general algorithm used for minimisation.
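            # (Concretely, writing `d = y_diff` and expanding the square:
            #  `0.5 (J d + f0)^T (J d + f0) - 0.5 f0^T f0
            #   = (J^T f0)^T d + 0.5 d^T (J^T J) d = g^T d + 0.5 d^T B d`
            #  under the Gauss--Newton approximation `B = J^T J`.)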
            rtr = _sum_squares(f_info.residual)
            jacobian_term = _sum_squares(
                (f_info.jac.mv(y_diff) ** ω + f_info.residual**ω).ω
            )
            # The factor of 0.5 also multiplies the damping term, matching the
            # quadratic model `g^T d + 0.5 d^T (B + (1/δ) I) d` used above.
            return 0.5 * (jacobian_term - rtr + lm_param * _sum_squares(y_diff))
        else:
            raise ValueError(
                "Cannot use `DampedTrustRegion` with this solver. This is because "
                "`DampedTrustRegion` requires (an approximation to) the Hessian of "
                "the target function, but this solver does not make any estimate of "
                "that information."
            )

    def init(self, y: PyTree[Array], f_info_struct) -> _DampedTrustRegionState:
        del f_info_struct
        y_struct = jtu.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), y)
        return _DampedTrustRegionState(
            jnp.array(1.0), lx.IdentityLinearOperator(y_struct, y_struct)
        )

    def step(
        self,
        first_step: Bool[Array, ""],
        y: PyTree[Array],
        y_eval: PyTree[Array],
        f_info: optx.FunctionInfo,
        f_eval_info: optx.FunctionInfo,
        state: _DampedTrustRegionState,
    ) -> tuple[Scalar, Bool[Array, ""], optx.RESULTS, _DampedTrustRegionState]:
        y_diff = (y_eval**ω - y**ω).ω
        predicted_reduction = self.predict_reduction(y_diff, f_info, state)
        # We never actually compute the ratio
        # `true_reduction/predicted_reduction`. Instead, we rewrite the conditions
        # as `true_reduction < const * predicted_reduction`, where the inequality
        # flips because we assume `predicted_reduction` is negative.
        # This avoids an expensive division.
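        # For example, `ratio >= low_cutoff` becomes
        # `f_min_diff <= low_cutoff * predicted_reduction` below: multiplying both
        # sides by the (negative) `predicted_reduction` flips the inequality.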
        f_min = f_info.as_min()
        f_min_eval = f_eval_info.as_min()
        f_min_diff = f_min_eval - f_min  # This number is probably negative
        accept = f_min_diff <= self.low_cutoff * predicted_reduction
        good = f_min_diff < self.high_cutoff * predicted_reduction
        good = good & (predicted_reduction < 0)
        good = good & jnp.invert(first_step)
        accept = accept | first_step
        mul = jnp.where(good, self.high_constant, 1)
        mul = jnp.where(accept, mul, self.low_constant)
        new_step_size = mul * state.step_size
        new_state = _DampedTrustRegionState(
            step_size=new_step_size, identity=state.identity
        )
        return new_step_size, accept, optx.RESULTS.successful, new_state


class DampedRatioBFGS(optx.AbstractBFGS):
    rtol: float
    atol: float
    norm: Callable[[PyTree], Scalar]
    use_inverse: bool
    descent: optx.AbstractDescent
    search: optx.AbstractSearch

    def __init__(
        self,
        rtol: float,
        atol: float,
        norm: Callable[[PyTree], Scalar] = optx.two_norm,
        linear_solver: lx.AbstractLinearSolver = lx.QR(),
    ):
        self.atol = atol
        self.rtol = rtol
        self.norm = norm
        self.use_inverse = False
        self.descent = optx.DampedNewtonDescent(linear_solver)
        self.search = DampedTrustRegion()


class DampedRatioLM(optx.AbstractGaussNewton):
    rtol: float
    atol: float
    norm: Callable[[PyTree], Scalar]
    use_inverse: bool
    descent: optx.AbstractDescent
    search: optx.AbstractSearch
    verbose: frozenset[str]

    def __init__(
        self,
        rtol: float,
        atol: float,
        norm: Callable[[PyTree], Scalar] = optx.two_norm,
        linear_solver: lx.AbstractLinearSolver = lx.QR(),
        verbose: frozenset[str] = frozenset(),
    ):
        self.atol = atol
        self.rtol = rtol
        self.norm = norm
        self.use_inverse = False
        self.descent = optx.DampedNewtonDescent(linear_solver)
        self.search = DampedTrustRegion()
        self.verbose = verbose
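

# A minimal usage sketch, assuming the public Optimistix entry points
# `optx.least_squares` and `optx.minimise`; the toy problems, starting points, and
# tolerances below are illustrative only.
if __name__ == "__main__":
    def residuals(y, args):
        # Simple nonlinear least-squares residuals: solve y**2 = [1, 4, 9].
        del args
        return y**2 - jnp.array([1.0, 4.0, 9.0])

    lm_solver = DampedRatioLM(rtol=1e-8, atol=1e-8)
    lm_sol = optx.least_squares(residuals, lm_solver, jnp.ones(3))
    print("DampedRatioLM solution:", lm_sol.value)

    def rosenbrock(y, args):
        # Scalar minimisation target: the two-dimensional Rosenbrock function.
        del args
        return (1 - y[0]) ** 2 + 100 * (y[1] - y[0] ** 2) ** 2

    bfgs_solver = DampedRatioBFGS(rtol=1e-8, atol=1e-8)
    bfgs_sol = optx.minimise(rosenbrock, bfgs_solver, jnp.zeros(2))
    print("DampedRatioBFGS solution:", bfgs_sol.value)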