@slinderman
Last active May 20, 2024 07:29
A simple wrapper for scipy.optimize.minimize using JAX. UPDATE: This is obsolete now that `jax.scipy.optimize.minimize` exists!
"""
A collection of helper functions for optimization with JAX.
UPDATE: This is obsolete now that `jax.scipy.optimize.minimize` exists!
"""
import numpy as onp
import scipy.optimize
from jax import grad, jit
from jax.flatten_util import ravel_pytree


def minimize(fun, x0,
             method=None,
             args=(),
             bounds=None,
             constraints=(),
             tol=None,
             callback=None,
             options=None):
"""
A simple wrapper for scipy.optimize.minimize using JAX.
Args:
fun: The objective function to be minimized, written in JAX code
so that it is automatically differentiable. It is of type,
```fun: x, *args -> float```
where `x` is a PyTree and args is a tuple of the fixed parameters needed
to completely specify the function.
x0: Initial guess represented as a JAX PyTree.
args: tuple, optional. Extra arguments passed to the objective function
and its derivative. Must consist of valid JAX types; e.g. the leaves
of the PyTree must be floats.
_The remainder of the keyword arguments are inherited from
`scipy.optimize.minimize`, and their descriptions are copied here for
convenience._
method : str or callable, optional
Type of solver. Should be one of
- 'Nelder-Mead' :ref:`(see here) <optimize.minimize-neldermead>`
- 'Powell' :ref:`(see here) <optimize.minimize-powell>`
- 'CG' :ref:`(see here) <optimize.minimize-cg>`
- 'BFGS' :ref:`(see here) <optimize.minimize-bfgs>`
- 'Newton-CG' :ref:`(see here) <optimize.minimize-newtoncg>`
- 'L-BFGS-B' :ref:`(see here) <optimize.minimize-lbfgsb>`
- 'TNC' :ref:`(see here) <optimize.minimize-tnc>`
- 'COBYLA' :ref:`(see here) <optimize.minimize-cobyla>`
- 'SLSQP' :ref:`(see here) <optimize.minimize-slsqp>`
- 'trust-constr':ref:`(see here) <optimize.minimize-trustconstr>`
- 'dogleg' :ref:`(see here) <optimize.minimize-dogleg>`
- 'trust-ncg' :ref:`(see here) <optimize.minimize-trustncg>`
- 'trust-exact' :ref:`(see here) <optimize.minimize-trustexact>`
- 'trust-krylov' :ref:`(see here) <optimize.minimize-trustkrylov>`
- custom - a callable object (added in version 0.14.0),
see below for description.
If not given, chosen to be one of ``BFGS``, ``L-BFGS-B``, ``SLSQP``,
depending if the problem has constraints or bounds.
bounds : sequence or `Bounds`, optional
Bounds on variables for L-BFGS-B, TNC, SLSQP, Powell, and
trust-constr methods. There are two ways to specify the bounds:
1. Instance of `Bounds` class.
2. Sequence of ``(min, max)`` pairs for each element in `x`. None
is used to specify no bound.
Note that in order to use `bounds` you will need to manually flatten
them in the same order as your inputs `x0`.
constraints : {Constraint, dict} or List of {Constraint, dict}, optional
Constraints definition (only for COBYLA, SLSQP and trust-constr).
Constraints for 'trust-constr' are defined as a single object or a
list of objects specifying constraints to the optimization problem.
Available constraints are:
- `LinearConstraint`
- `NonlinearConstraint`
Constraints for COBYLA, SLSQP are defined as a list of dictionaries.
Each dictionary with fields:
type : str
Constraint type: 'eq' for equality, 'ineq' for inequality.
fun : callable
The function defining the constraint.
jac : callable, optional
The Jacobian of `fun` (only for SLSQP).
args : sequence, optional
Extra arguments to be passed to the function and Jacobian.
Equality constraint means that the constraint function result is to
be zero whereas inequality means that it is to be non-negative.
Note that COBYLA only supports inequality constraints.
Note that in order to use `constraints` you will need to manually flatten
them in the same order as your inputs `x0`.
tol : float, optional
Tolerance for termination. For detailed control, use solver-specific
options.
options : dict, optional
A dictionary of solver options. All methods accept the following
generic options:
maxiter : int
Maximum number of iterations to perform. Depending on the
method each iteration may use several function evaluations.
disp : bool
Set to True to print convergence messages.
For method-specific options, see :func:`show_options()`.
callback : callable, optional
Called after each iteration. For 'trust-constr' it is a callable with
the signature:
``callback(xk, OptimizeResult state) -> bool``
where ``xk`` is the current parameter vector represented as a PyTree,
and ``state`` is an `OptimizeResult` object, with the same fields
as the ones from the return. If callback returns True the algorithm
execution is terminated.
For all the other methods, the signature is:
```callback(xk)```
where `xk` is the current parameter vector, represented as a PyTree.
Returns:
res : The optimization result represented as a ``OptimizeResult`` object.
Important attributes are:
``x``: the solution array, represented as a JAX PyTree
``success``: a Boolean flag indicating if the optimizer exited successfully
``message``: describes the cause of the termination.
See `scipy.optimize.OptimizeResult` for a description of other attributes.
"""
    # Use ravel_pytree to convert the params x0 from a PyTree to a flat array
    x0_flat, unravel = ravel_pytree(x0)

    # Wrap the objective function to consume flat _original_
    # numpy arrays and produce scalar outputs.
    def fun_wrapper(x_flat, *args):
        x = unravel(x_flat)
        return float(fun(x, *args))

    # Wrap the gradient in a similar manner
    jac = jit(grad(fun))

    def jac_wrapper(x_flat, *args):
        x = unravel(x_flat)
        g_flat, _ = ravel_pytree(jac(x, *args))
        return onp.array(g_flat)

    # Wrap the callback to consume a PyTree
    def callback_wrapper(x_flat, *args):
        if callback is not None:
            x = unravel(x_flat)
            return callback(x, *args)

    # Minimize with scipy
    results = scipy.optimize.minimize(fun_wrapper,
                                      x0_flat,
                                      args=args,
                                      method=method,
                                      jac=jac_wrapper,
                                      callback=callback_wrapper,
                                      bounds=bounds,
                                      constraints=constraints,
                                      tol=tol,
                                      options=options)

    # Pack the output back into a PyTree
    results["x"] = unravel(results["x"])
    return results
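
As a quick illustration of the wrapper (not part of the original gist; the objective `loss` and the initial PyTree `init` below are made up for the example):

import jax.numpy as jnp

# Hypothetical objective: a quadratic in the leaves of a PyTree.
def loss(params):
    return jnp.sum(params["w"] ** 2) + (params["b"] - 1.0) ** 2

init = {"w": jnp.ones(3), "b": jnp.array(0.0)}
res = minimize(loss, init, method="BFGS")
print(res["x"])  # a PyTree with the same structure as init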
@hblunck

hblunck commented Feb 18, 2021

Great, thanks a lot! It worked out of the box for me using the trust-constr optimization method. Not having much jax experience, I failed to provide a properly formatted Hessian along the lines of jac_wrapper. Any pointers? For now, I passed in BFGS(), which already helped with solution speed.
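
One possible pointer, sketched here and untested: mirror `jac_wrapper`, but differentiate the objective as a function of the flat vector, so that `jax.hessian` returns a dense matrix rather than a nested PyTree of PyTrees. (`fun` and `unravel` refer to the variables in the wrapper's scope; `fun_flat` and `hess_wrapper` are hypothetical names.)

from jax import hessian, jit
import numpy as onp

# Differentiate with respect to the flat vector so the Hessian comes out
# as a dense (n, n) array.
def fun_flat(x_flat, *args):
    return fun(unravel(x_flat), *args)

hess = jit(hessian(fun_flat))

def hess_wrapper(x_flat, *args):
    return onp.array(hess(x_flat, *args))

# ...then pass hess=hess_wrapper through to scipy.optimize.minimize.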

@notwa

notwa commented Mar 28, 2021

Heads up: L-BFGS-B (the default method for bounded problems) will fail because it expects dtype=float64, but JAX defaults to float32. Here's a brutish workaround for ipython:

%env JAX_ENABLE_X64=1
%env JAX_PLATFORM_NAME=cpu
# jax imports go here
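
The same workaround can be applied outside ipython with JAX's config API, before any arrays are created:

import jax
jax.config.update("jax_enable_x64", True)  # make float64 the default dtype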

@murphyk

murphyk commented Aug 12, 2021

@slinderman
Author

I don't think jax.scipy.optimize.minimize existed when I made this gist! Now that it does, I would definitely use it instead. I've updated the gist to say that.

@pabloduque0

I think this is still useful when using any method other than BFGS since that is the only one that jax.scipy.optimize.minimize currently supports (at the time of writing this).
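
For comparison, a minimal call to the built-in version; note that it takes a flat array rather than a PyTree, and `method` must be given:

import jax.numpy as jnp
from jax.scipy.optimize import minimize as jax_minimize

# BFGS is the only method the built-in minimize supports.
res = jax_minimize(lambda x: jnp.sum(x ** 2), jnp.ones(3), method="BFGS")
print(res.x)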

@ShubhankarKapoor

It would still let you pass in bounds and other options which aren't implemented in JAX yet.

@murphyk

murphyk commented May 19, 2022

See https://jaxopt.github.io/ for a new library that might be useful.
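
For example, jaxopt's `ScipyMinimize` wraps scipy in much the same spirit as this gist, PyTree in and PyTree out. A rough sketch based on the jaxopt API (`loss` and `init` are the made-up objective and initial PyTree from the usage example above):

from jaxopt import ScipyMinimize

solver = ScipyMinimize(fun=loss, method="L-BFGS-B")
res = solver.run(init)  # res.params has the same PyTree structure as init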

@richinex

Hi @slinderman, thanks for the wrapper. Is there a way to make the code vmappable? Currently I can use vmap on jax.scipy.optimize.minimize, but the downside is that it only supports the BFGS algorithm, and the scipy minimize wrapper in jaxopt is not vmappable either. When I run the code below, I get the JAX tracer conversion error:

def do_minimize(p, x, y, z, lb, ub, smf):
    return minimize(cost_fun, p, args=(x, y, z, lb, ub, smf), method='TNC',
                    tol=1e-12, options={'maxiter': 20000})

sol = jax.vmap(do_minimize)(par_log, F, Y, sigma_Y, lb_mat, ub_mat, smoothing_mat)

# Traceback (abridged): the error is raised at the wrapper's internal call
#     results = scipy.optimize.minimize(fun_wrapper, x0_flat, args=args, ...)
# when scipy tries to convert a traced batch array (batch_dim = 0) into a
# concrete numpy array:
# jax.errors.TracerArrayConversionError
# See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

@richinex

I just realized that using `list(map(func, *args))` instead of vmap works well.
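
That is, something along these lines, looping over the batch in plain Python since scipy.optimize runs outside JAX's tracing:

# Batch by a plain Python map instead of jax.vmap.
sols = list(map(do_minimize, par_log, F, Y, sigma_Y, lb_mat, ub_mat, smoothing_mat))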
