Skip to content

Instantly share code, notes, and snippets.

@packquickly
Created November 1, 2023 13:06
Show Gist options
  • Save packquickly/a8848bc92e12bcfe6876ee804b48b06a to your computer and use it in GitHub Desktop.
Save packquickly/a8848bc92e12bcfe6876ee804b48b06a to your computer and use it in GitHub Desktop.
An example of splitting an optimisation problem over two solvers
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import optax
import optimistix as optx
# The HIMMELBG test problem from the CUTE collection.
@eqx.filter_jit
def toy_problem(y1, y2, constants):
    """Evaluate the HIMMELBG objective (c1 * y1**2 + c2 * y2**2) * exp(-y1 - y2).

    Args:
        y1: first variable.
        y2: second variable.
        constants: pair ``(c1, c2)`` weighting the two squared terms.

    Returns:
        ``(value, None)``: the objective value paired with an empty auxiliary
        output, the ``(f, aux)`` form the solvers in this script are given as
        their ``fn``.
    """
    weight_1, weight_2 = constants
    weighted_squares = weight_1 * y1**2 + weight_2 * y2**2
    decay = jnp.exp(-y1 - y2)
    value = weighted_squares * decay
    return value, None
# Set up both solvers.
bfgs_solver = optx.BFGS(rtol=1e-3, atol=1e-3)
optax_solver = optx.OptaxMinimiser(
    optax.adam(learning_rate=1e-3), rtol=1e-3, atol=1e-3
)
# The initial guess for the solution `(y1, y2)`.
y = (jnp.array(1.5), jnp.array(1.5))
# Extra arguments passed to `fn`: here the constants `(c1, c2)` of the objective.
args = (jnp.array(2.0), jnp.array(3.0))
# Shape/dtype structures of the objective value and of its auxiliary output.
# They are derived by abstractly tracing `toy_problem` (no actual computation)
# rather than hard-coding `float32`, so they stay consistent with the dtype of
# `y` and `args` (e.g. float64 when `jax_enable_x64` is on). `toy_problem`
# returns `None` as its aux, so `aux_struct` comes out as `None`.
f_struct, aux_struct = jax.eval_shape(toy_problem, *y, args)
# Any Lineax tags describing the structure of the Jacobian matrix d(fn)/dy.
# (In this case it's a scalar, so these don't matter.)
tags = frozenset()
def solve(y, bfgs_solver, optax_solver, *, max_steps=None):
    """Minimise `toy_problem` by splitting the variables over two solvers.

    Each iteration takes one BFGS step in `y1` (holding `y2` fixed) and one
    Optax step in `y2` (holding `y1` fixed), both taken from the same previous
    iterate. Iteration continues until both solvers' `terminate` report done,
    or until `max_steps` iterations have been taken.

    Relies on the module-level `args`, `f_struct`, `aux_struct` and `tags`.

    Args:
        y: pair `(y1, y2)` of initial values.
        bfgs_solver: Optimistix minimiser used for `y1`.
        optax_solver: Optimistix minimiser used for `y2`.
        max_steps: optional cap on the number of iterations. `None` (the
            default) means no cap.

    Returns:
        The postprocessed solution `(y1_final, y2_final)`.
    """
    y1, y2 = y

    # Set up the step and termination methods for each solver. The BFGS methods
    # close over the parameter they are not optimising, namely `y2`, inside
    # their `fn`. Likewise the Optax methods close over `y1`.
    bfgs_step = eqx.filter_jit(
        lambda y1, state, y2: bfgs_solver.step(
            fn=lambda var, args: toy_problem(var, y2, args),
            y=y1,
            args=args,
            options={},
            state=state,
            tags=tags,
        )
    )
    bfgs_terminate = eqx.filter_jit(
        lambda y1, state, y2: bfgs_solver.terminate(
            fn=lambda var, args: toy_problem(var, y2, args),
            y=y1,
            args=args,
            options={},
            state=state,
            tags=tags,
        )
    )
    optax_step = eqx.filter_jit(
        lambda y2, state, y1: optax_solver.step(
            fn=lambda var, args: toy_problem(y1, var, args),
            y=y2,
            args=args,
            options={},
            state=state,
            tags=tags,
        )
    )
    optax_terminate = eqx.filter_jit(
        lambda y2, state, y1: optax_solver.terminate(
            fn=lambda var, args: toy_problem(y1, var, args),
            y=y2,
            args=args,
            options={},
            state=state,
            tags=tags,
        )
    )

    # Initial state before we start solving.
    bfgs_state = bfgs_solver.init(
        lambda var, args: toy_problem(var, y2, args),
        y1,
        args,
        {},
        f_struct,
        aux_struct,
        tags,
    )
    optax_state = optax_solver.init(
        lambda var, args: toy_problem(y1, var, args),
        y2,
        args,
        {},
        f_struct,
        aux_struct,
        tags,
    )
    bfgs_done, bfgs_result = bfgs_terminate(y1, bfgs_state, y2)
    optax_done, optax_result = optax_terminate(y2, optax_state, y1)

    # `step` returns the auxiliary output of `fn`, which `postprocess` needs.
    # Initialise it so the postprocess calls below remain valid even if both
    # solvers report termination before any step is taken. `toy_problem`
    # returns `None` as its aux, so `None` is the correct placeholder.
    bfgs_aux = None
    optax_aux = None

    # Alright, enough setup. Let's do the solve!
    num_steps = 0
    while not bfgs_done or not optax_done:
        if max_steps is not None and num_steps >= max_steps:
            print(f"Stopped after {num_steps} steps before both solvers converged.")
            break
        print(f"Evaluating point {y1, y2} with value {toy_problem(y1, y2, args)[0]}.")
        # Both steps are taken from the same point, so don't accidentally pass
        # the new value of `y1` to the Optax step; hence the temporary name
        # `y1_new`.
        y1_new, bfgs_state, bfgs_aux = bfgs_step(y1, bfgs_state, y2)
        y2, optax_state, optax_aux = optax_step(y2, optax_state, y1)
        y1 = y1_new
        num_steps += 1
        bfgs_done, bfgs_result = bfgs_terminate(y1, bfgs_state, y2)
        optax_done, optax_result = optax_terminate(y2, optax_state, y1)

    if bfgs_result != optx.RESULTS.successful:
        print(f"Oh no! BFGS found an error {bfgs_result}.")
    if optax_result != optx.RESULTS.successful:
        print(f"Oh no! Optax found an error {optax_result}.")

    y1_final, _, _ = bfgs_solver.postprocess(
        lambda var, args: toy_problem(var, y2, args),
        y1,
        bfgs_aux,
        args,
        {},
        bfgs_state,
        tags,
        bfgs_result,
    )
    y2_final, _, _ = optax_solver.postprocess(
        lambda var, args: toy_problem(y1, var, args),
        y2,
        optax_aux,
        args,
        {},
        optax_state,
        tags,
        optax_result,
    )
    # Report (and return) the postprocessed values rather than the raw iterates.
    print(
        f"Found solution {(y1_final, y2_final)} with value "
        f"{toy_problem(y1_final, y2_final, args)[0]}."
    )
    return y1_final, y2_final
# Only run the example solve when executed as a script, not on import.
if __name__ == "__main__":
    solve(y, bfgs_solver, optax_solver)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment