Created
October 11, 2023 14:06
-
-
Save packquickly/b6f473ea5555d4fb62b1dc0c4c10cf70 to your computer and use it in GitHub Desktop.
The damped trust region `search`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools as ft | |
from collections.abc import Callable | |
import equinox as eqx | |
import jax | |
import jax.numpy as jnp | |
import jax.tree_util as jtu | |
import lineax as lx | |
import optimistix as optx | |
from equinox.internal import ω | |
from jaxtyping import Array, ArrayLike, Bool, Inexact, PyTree, Scalar, ScalarLike | |
from optimistix import FunctionInfo | |
def _default_floating_dtype():
    """Return the floating dtype matching JAX's current `jax_enable_x64` setting.

    **Returns**:

    `jnp.float64` when 64-bit mode is enabled in the JAX config, otherwise
    `jnp.float32`.
    """
    x64_enabled = jax.config.jax_enable_x64  # pyright: ignore
    return jnp.float64 if x64_enabled else jnp.float32
def _tree_dot(tree1: PyTree[Array], tree2: PyTree[Array]) -> Inexact[Array, ""]:
    """Compute the dot product of two PyTrees with the same PyTree structure.

    Each pair of corresponding leaves is flattened to one dimension and
    contracted with `jnp.dot` (the leaf from `tree2` is complex-conjugated
    first); the per-leaf results are then added together.

    **Arguments**:

    - `tree1`: the first PyTree of arrays.
    - `tree2`: the second PyTree of arrays; its leaves are conjugated.

    **Returns**:

    The scalar sum of all leaf-wise dot products. For PyTrees without any
    leaves this is a zero of the default floating dtype.

    **Raises**:

    `ValueError` if the two PyTrees do not share the same structure.
    """
    flat1, structure1 = jtu.tree_flatten(tree1)
    flat2, structure2 = jtu.tree_flatten(tree2)
    if structure1 != structure2:
        raise ValueError("Tree must have the same PyTree structure.")

    # Identical tree structures imply identical leaf counts, so `zip` pairs
    # every leaf with its counterpart.
    leaf_dots = [
        jnp.dot(
            jnp.reshape(a, -1),
            jnp.reshape(jnp.conj(b), -1),
            precision=jax.lax.Precision.HIGHEST,  # pyright: ignore
        )
        for a, b in zip(flat1, flat2)
    ]
    if not leaf_dots:
        return jnp.array(0, _default_floating_dtype())
    return ft.reduce(jnp.add, leaf_dots)
def _sum_squares(x: PyTree[ArrayLike]) -> Scalar:
    """Return the real part of the dot product of the PyTree `x` with itself.

    Because `_tree_dot` conjugates its second argument, this is the sum of the
    squared magnitudes of every entry in `x`.
    """
    self_dot = _tree_dot(x, x)
    return self_dot.real
class _DampedTrustRegionState(eqx.Module):
    """State carried between steps of the damped trust-region search."""

    # Current step-size; its reciprocal is used as the damping parameter when
    # predicting the reduction.
    step_size: Scalar
    # Kept in the state purely so that the identity operator does not have to
    # be re-instantiated on every step of the optimisation.
    identity: lx.IdentityLinearOperator
# WARNING: This search only makes theoretical sense when used with the damped Newton
# descent. If I (packquickly) were implementing it for something more serious, I would
# make it "private" (add a leading underscore) to prevent users from mistakenly using
# it in an unsupported way.
# NOTE: this is taken and only slightly modified from `_AbstractTrustRegion` in
# Optimistix.
class DampedTrustRegion(optx.AbstractSearch):
    """The damped trust-region update algorithm.

    Trust region line searches compute the ratio
    `true_reduction/predicted_reduction`, where `true_reduction` is the decrease in
    `fn` between `y` and `new_y`, and `predicted_reduction` is how much we expected
    the function to decrease using an approximation to `fn`.

    The trust-region ratio determines whether to accept or reject a step and the
    next choice of step-size. Specifically:

    - reject the step and decrease the step-size if the ratio is smaller than a
        cutoff `low_cutoff`
    - accept the step and increase the step-size if the ratio is greater than
        another cutoff `high_cutoff` with `low_cutoff < high_cutoff`.
    - else, accept the step and make no change to the step-size.

    This is different than a classical trust region update because the reduction
    `predicted_reduction` uses a damped model function.

    **Attributes**:

    - `high_cutoff`: ratio above which the step-size is increased.
    - `low_cutoff`: ratio below which the step is rejected.
    - `high_constant`: factor multiplying the step-size after a good step.
    - `low_constant`: factor multiplying the step-size after a rejected step.
    """

    high_cutoff: ScalarLike = 0.99
    low_cutoff: ScalarLike = 0.01
    high_constant: ScalarLike = 3.5
    low_constant: ScalarLike = 0.25

    def __post_init__(self):
        """Validate the cutoffs and constants with `eqx.error_if`."""
        # You would not expect `self.low_cutoff` or `self.high_cutoff` to
        # be below zero, but this is technically not incorrect so we don't
        # require it.
        self.low_cutoff, self.high_cutoff = eqx.error_if(  # pyright: ignore
            (self.low_cutoff, self.high_cutoff),
            self.low_cutoff > self.high_cutoff,  # pyright: ignore
            "`low_cutoff` must be below `high_cutoff` in `DampedTrustRegion`",
        )
        self.low_constant = eqx.error_if(  # pyright: ignore
            self.low_constant,
            self.low_constant < 0,  # pyright: ignore
            "`low_constant` must be greater than `0` in `DampedTrustRegion`",
        )
        self.high_constant = eqx.error_if(  # pyright: ignore
            self.high_constant,
            self.high_constant < 0,  # pyright: ignore
            "`high_constant` must be greater than `0` in `DampedTrustRegion`",
        )

    # NOTE: this is taken and only slightly modified from the `ClassicalTrustRegion`
    # implementation in Optimistix.
    def predict_reduction(
        self,
        y_diff: PyTree[Array],
        f_info: optx.FunctionInfo,
        state: _DampedTrustRegionState,
    ) -> Scalar:
        """Compute the expected decrease in loss from taking the step `y_diff`.

        The true reduction is
        ```
        fn(y0 + y_diff) - fn(y0)
        ```
        so if `B` is the approximation to the Hessian coming from the quasi-Newton
        method at `y`, `g` is the gradient of `fn` at `y`, and `δ` is the step-size,
        then the predicted reduction is
        ```
        g^T y_diff + 1/2 y_diff^T (B + (1/δ) I) y_diff
        ```

        **Arguments**:

        - `y_diff`: the proposed step by the descent method.
        - `f_info`: the derivative information (on the gradient and Hessian)
            provided by the outer loop.
        - `state`: the state of the solver.

        **Returns**:

        The expected decrease in loss from moving from `y0` to `y0 + y_diff`.

        **Raises**:

        `ValueError` if `f_info` is neither `FunctionInfo.EvalGradHessian` nor
        `FunctionInfo.ResidualJac`.
        """
        # The damping parameter is `1/δ`. When the step-size is at or below
        # machine epsilon the division is avoided and the largest finite value
        # of the dtype is used instead.
        step_dtype = state.step_size.dtype
        pred = state.step_size > jnp.finfo(step_dtype).eps
        safe_step_size = jnp.where(pred, state.step_size, 1)
        lm_param = jnp.where(pred, 1 / safe_step_size, jnp.finfo(step_dtype).max)

        if isinstance(f_info, FunctionInfo.EvalGradHessian):
            # Minimisation algorithm. Directly compute the quadratic approximation.
            damped_hessian = f_info.hessian + lm_param * state.identity
            return _tree_dot(
                y_diff,
                (f_info.grad**ω + 0.5 * damped_hessian.mv(y_diff) ** ω).ω,
            )
        elif isinstance(f_info, FunctionInfo.ResidualJac):
            # Least-squares algorithm. So instead of considering fn (which returns the
            # residuals), instead consider `0.5*fn(y)^2`, and then apply the logic as
            # for minimisation.
            # We get that `g = J^T f0` and `B = J^T J + dJ/dx^T J`.
            # (Here, `f0 = fn(y0)` are the residuals, and `J = dfn/dy(y0)` is the
            # Jacobian of the residuals wrt y.)
            # Then neglect the second term in B (the usual Gauss--Newton approximation)
            # and complete the square.
            # We find that the undamped predicted reduction is
            # `0.5 * ((J y_diff + f0)^T (J y_diff + f0) - f0^T f0)`
            # and this is what is written below.
            #
            # The reason we go through this hassle is because this now involves only
            # a single Jacobian-vector product, rather than the three we would have to
            # make by naively substituting `B = J^T J `and `g = J^T f0` into the general
            # algorithm used for minimisation.
            rtr = _sum_squares(f_info.residual)
            jacobian_term = _sum_squares(
                (f_info.jac.mv(y_diff) ** ω + f_info.residual**ω).ω
            )
            # The damping contributes `1/2 * (1/δ) * y_diff^T y_diff`, exactly as in
            # the minimisation branch above; the factor 0.5 applies to it as well.
            damping_term = lm_param * _sum_squares(y_diff)
            return 0.5 * (jacobian_term - rtr) + 0.5 * damping_term
        else:
            raise ValueError(
                "Cannot use `DampedTrustRegion` with this solver. This is because "
                "`DampedTrustRegion` requires (an approximation to) the Hessian of "
                "the target function, but this solver does not make any estimate of "
                "that information."
            )

    def init(self, y: PyTree[Array], f_info_struct) -> _DampedTrustRegionState:
        """Create the initial search state.

        **Arguments**:

        - `y`: the initial point; only the shape and dtype of its leaves are used,
            to build the identity operator.
        - `f_info_struct`: unused.

        **Returns**:

        A state with step-size `1.0` and an identity operator over the structure
        of `y`.
        """
        del f_info_struct
        y_struct = jtu.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), y)
        return _DampedTrustRegionState(
            jnp.array(1.0), lx.IdentityLinearOperator(y_struct, y_struct)
        )

    def step(
        self,
        first_step: Bool[Array, ""],
        y: PyTree[Array],
        y_eval: PyTree[Array],
        f_info: optx.FunctionInfo,
        f_eval_info: optx.FunctionInfo,
        state: _DampedTrustRegionState,
    ) -> tuple[Scalar, Bool[Array, ""], optx.RESULTS, _DampedTrustRegionState]:
        """Accept or reject the proposed point and update the step-size.

        **Arguments**:

        - `first_step`: whether this is the first step; the first step is always
            accepted and never increases the step-size.
        - `y`: the current accepted point.
        - `y_eval`: the proposed point.
        - `f_info`: function information at `y`.
        - `f_eval_info`: function information at `y_eval`.
        - `state`: the current search state.

        **Returns**:

        A 4-tuple of the new step-size, whether the step was accepted,
        `optx.RESULTS.successful`, and the updated state.
        """
        y_diff = (y_eval**ω - y**ω).ω
        predicted_reduction = self.predict_reduction(y_diff, f_info, state)
        # We never actually compute the ratio
        # `true_reduction/predicted_reduction`. Instead, we rewrite the conditions as
        # `true_reduction < const * predicted_reduction` instead, where the inequality
        # flips because we assume `predicted_reduction` is negative.
        # This avoids an expensive division.
        f_min = f_info.as_min()
        f_min_eval = f_eval_info.as_min()
        f_min_diff = f_min_eval - f_min  # This number is probably negative
        accept = f_min_diff <= self.low_cutoff * predicted_reduction
        good = f_min_diff < self.high_cutoff * predicted_reduction
        good = good & (predicted_reduction < 0)
        good = good & jnp.invert(first_step)
        accept = accept | first_step
        # Grow the step-size after a good step, shrink it after a rejected step,
        # and otherwise leave it unchanged.
        mul = jnp.where(good, self.high_constant, 1)
        mul = jnp.where(accept, mul, self.low_constant)
        new_step_size = mul * state.step_size
        new_state = _DampedTrustRegionState(
            step_size=new_step_size, identity=state.identity
        )
        return new_step_size, accept, optx.RESULTS.successful, new_state
class DampedRatioBFGS(optx.AbstractBFGS):
    """BFGS-type solver using `optx.DampedNewtonDescent` with `DampedTrustRegion`.

    **Arguments**:

    - `rtol`: relative tolerance, stored on the solver.
    - `atol`: absolute tolerance, stored on the solver.
    - `norm`: the norm stored on the solver; defaults to `optx.two_norm`.
    - `linear_solver`: the Lineax solver handed to `optx.DampedNewtonDescent`;
        defaults to `lx.QR()`.
    """

    rtol: float
    atol: float
    norm: Callable[[PyTree], Scalar]
    use_inverse: bool
    descent: optx.AbstractDescent
    search: optx.AbstractSearch

    def __init__(
        self,
        rtol: float,
        atol: float,
        norm: Callable[[PyTree], Scalar] = optx.two_norm,
        linear_solver: lx.AbstractLinearSolver = lx.QR(),
    ):
        # Tolerances and norm are stored as given.
        self.rtol = rtol
        self.atol = atol
        self.norm = norm
        # `use_inverse` is not configurable here; it is always `False`.
        self.use_inverse = False
        # The damped trust-region search is paired with the damped Newton descent.
        self.search = DampedTrustRegion()
        self.descent = optx.DampedNewtonDescent(linear_solver)
class DampedRatioLM(optx.AbstractGaussNewton):
    """Gauss--Newton-type solver using `optx.DampedNewtonDescent` with
    `DampedTrustRegion`.

    **Arguments**:

    - `rtol`: relative tolerance, stored on the solver.
    - `atol`: absolute tolerance, stored on the solver.
    - `norm`: the norm stored on the solver; defaults to `optx.two_norm`.
    - `linear_solver`: the Lineax solver handed to `optx.DampedNewtonDescent`;
        defaults to `lx.QR()`.
    - `verbose`: a frozenset of strings stored on the solver; defaults to empty.
    """

    rtol: float
    atol: float
    norm: Callable[[PyTree], Scalar]
    use_inverse: bool
    descent: optx.AbstractDescent
    search: optx.AbstractSearch
    verbose: frozenset[str]

    def __init__(
        self,
        rtol: float,
        atol: float,
        norm: Callable[[PyTree], Scalar] = optx.two_norm,
        linear_solver: lx.AbstractLinearSolver = lx.QR(),
        verbose: frozenset[str] = frozenset(),
    ):
        # Tolerances, norm and verbosity are stored as given.
        self.rtol = rtol
        self.atol = atol
        self.norm = norm
        self.verbose = verbose
        # `use_inverse` is not configurable here; it is always `False`.
        self.use_inverse = False
        # The damped trust-region search is paired with the damped Newton descent.
        self.search = DampedTrustRegion()
        self.descent = optx.DampedNewtonDescent(linear_solver)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment