Created November 1, 2023 13:06
-
-
Save packquickly/a8848bc92e12bcfe6876ee804b48b06a to your computer and use it in GitHub Desktop.
An example of splitting an optimisation problem over two solvers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import equinox as eqx # https://github.com/patrick-kidger/equinox | |
import jax | |
import jax.numpy as jnp | |
import optax | |
import optimistix as optx | |
# HIMMELBG test problem from the CUTE collection.
@eqx.filter_jit
def toy_problem(y1, y2, constants):
    """Evaluate the HIMMELBG objective at the point ``(y1, y2)``.

    Computes ``(c1 * y1**2 + c2 * y2**2) * exp(-y1 - y2)`` where
    ``(c1, c2) = constants``.

    Returns:
        A pair ``(value, None)``: the objective value together with an empty
        (``None``) auxiliary output.
    """
    coeff1, coeff2 = constants
    quadratic = coeff1 * y1**2 + coeff2 * y2**2
    decay = jnp.exp(-y1 - y2)
    return quadratic * decay, None
# Set up both solvers: BFGS is used for `y1`, Adam (via Optax) for `y2`.
bfgs_solver = optx.BFGS(rtol=1e-3, atol=1e-3)
optax_solver = optx.OptaxMinimiser(optax.adam(learning_rate=1e-3), rtol=1e-3, atol=1e-3)
# The initial guess for the solution, as a `(y1, y2)` pair.
y = (jnp.array(1.5), jnp.array(1.5))
# Extra arguments passed through to `fn`: the coefficients `(c1, c2)` of `toy_problem`.
args = (jnp.array(2.0), jnp.array(3.0))
# Shape/dtype of the objective value and of its auxiliary output. These are derived
# from `toy_problem` itself instead of being hard-coded, so the dtype follows the
# dtypes of `y` and `args` (e.g. float64 when JAX's x64 mode is enabled).
f_struct, aux_struct = jax.eval_shape(lambda: toy_problem(*y, args))
# Any Lineax tags describing the structure of the Jacobian matrix d(fn)/dy.
# (In this case it's a scalar, so these don't matter.)
tags = frozenset()
def _objective_in_y1(y2):
    """Return the objective as a function of ``y1`` alone, with ``y2`` held fixed.

    The returned callable has the ``fn(var, args)`` form passed to the solver
    methods in ``solve``.
    """
    return lambda var, fn_args: toy_problem(var, y2, fn_args)


def _objective_in_y2(y1):
    """Return the objective as a function of ``y2`` alone, with ``y1`` held fixed."""
    return lambda var, fn_args: toy_problem(y1, var, fn_args)


def solve(y, bfgs_solver, optax_solver, *, max_steps=None):
    """Minimise ``toy_problem`` by splitting its two variables over two solvers.

    ``bfgs_solver`` optimises ``y1`` with ``y2`` held fixed, and ``optax_solver``
    optimises ``y2`` with ``y1`` held fixed. On every iteration both solvers take
    one step starting from the same ``(y1, y2)`` pair. The loop ends once both
    solvers report termination (or after ``max_steps`` iterations, if given).

    Reads the module-level ``args``, ``tags``, ``f_struct`` and ``aux_struct``.

    Args:
        y: initial guess ``(y1, y2)``.
        bfgs_solver: solver used for ``y1``.
        optax_solver: solver used for ``y2``.
        max_steps: optional cap on the number of iterations. ``None`` (the
            default) keeps iterating until both solvers terminate.

    Returns:
        The postprocessed solution ``(y1, y2)``.
    """
    y1, y2 = y

    # Step/terminate functions for each solver. Each takes the variable its solver
    # optimises first, then the solver state, then the other variable, which is
    # closed over inside `fn` (BFGS closes over `y2`, Optax closes over `y1`).
    @eqx.filter_jit
    def bfgs_step(y1, state, y2):
        return bfgs_solver.step(
            fn=_objective_in_y1(y2),
            y=y1,
            args=args,
            options={},
            state=state,
            tags=tags,
        )

    @eqx.filter_jit
    def bfgs_terminate(y1, state, y2):
        return bfgs_solver.terminate(
            fn=_objective_in_y1(y2),
            y=y1,
            args=args,
            options={},
            state=state,
            tags=tags,
        )

    @eqx.filter_jit
    def optax_step(y2, state, y1):
        return optax_solver.step(
            fn=_objective_in_y2(y1),
            y=y2,
            args=args,
            options={},
            state=state,
            tags=tags,
        )

    @eqx.filter_jit
    def optax_terminate(y2, state, y1):
        return optax_solver.terminate(
            fn=_objective_in_y2(y1),
            y=y2,
            args=args,
            options={},
            state=state,
            tags=tags,
        )

    # Initial state before we start solving.
    bfgs_state = bfgs_solver.init(
        _objective_in_y1(y2), y1, args, {}, f_struct, aux_struct, tags
    )
    optax_state = optax_solver.init(
        _objective_in_y2(y1), y2, args, {}, f_struct, aux_struct, tags
    )
    bfgs_done, bfgs_result = bfgs_terminate(y1, bfgs_state, y2)
    optax_done, optax_result = optax_terminate(y2, optax_state, y1)

    # Bound up front so `postprocess` below still receives an aux value when both
    # solvers report termination before the first step. `toy_problem` returns
    # `None` as its auxiliary output, so `None` is the matching placeholder.
    bfgs_aux = optax_aux = None
    num_steps = 0

    # Alright, enough setup. Let's do the solve!
    while not bfgs_done or not optax_done:
        if max_steps is not None and num_steps >= max_steps:
            print(f"Stopping after {num_steps} steps before both solvers terminated.")
            break
        print(f"Evaluating point {y1, y2} with value {toy_problem(y1, y2, args)[0]}.")
        # Both solvers step from the same (y1, y2) pair. Keep the new `y1` in a
        # temporary so the Optax step still sees the old value.
        # NOTE: because the other variable moves between iterations, the objective
        # each solver's state was built against changes from step to step.
        y1_new, bfgs_state, bfgs_aux = bfgs_step(y1, bfgs_state, y2)
        y2, optax_state, optax_aux = optax_step(y2, optax_state, y1)
        y1 = y1_new
        num_steps += 1
        bfgs_done, bfgs_result = bfgs_terminate(y1, bfgs_state, y2)
        optax_done, optax_result = optax_terminate(y2, optax_state, y1)

    if bfgs_result != optx.RESULTS.successful:
        print(f"Oh no! BFGS found an error {bfgs_result}.")
    if optax_result != optx.RESULTS.successful:
        print(f"Oh no! Optax found an error {optax_result}.")

    y1_final, _, _ = bfgs_solver.postprocess(
        _objective_in_y1(y2),
        y1,
        bfgs_aux,
        args,
        {},
        bfgs_state,
        tags,
        bfgs_result,
    )
    y2_final, _, _ = optax_solver.postprocess(
        _objective_in_y2(y1),
        y2,
        optax_aux,
        args,
        {},
        optax_state,
        tags,
        optax_result,
    )
    # Report (and return) the postprocessed values, not the raw loop iterates.
    print(
        f"Found solution {(y1_final, y2_final)} with value "
        f"{toy_problem(y1_final, y2_final, args)[0]}."
    )
    return y1_final, y2_final
# Run the split solve from the initial guess `y` using the two solvers set up above.
solve(y, bfgs_solver=bfgs_solver, optax_solver=optax_solver)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.