
@slinderman
Last active May 20, 2024 07:29
A simple wrapper for scipy.optimize.minimize using JAX. UPDATE: This is obsolete now that `jax.scipy.optimize.minimize` exists!
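For comparison, here is a minimal sketch of calling the built-in `jax.scipy.optimize.minimize` that the update refers to. As far as I know it expects a flat 1-D `x0` (not a PyTree) and essentially supports only `method="BFGS"`, so it is not a drop-in replacement for every option of the wrapper below. The quadratic objective and `target` are illustrative assumptions.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# Illustrative convex quadratic with its minimum at x = target.
target = jnp.array([1.0, 2.0, 3.0])

def objective(x):
    return jnp.sum((x - target) ** 2)

# x0 must be a 1-D array; `method` is a required keyword argument.
res = minimize(objective, jnp.zeros(3), method="BFGS")
print(res.x, res.success)
```

Because this solver is written entirely in JAX, it can be used under `jit` and `vmap`, which is not possible when the optimization loop runs inside scipy.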
"""
A collection of helper functions for optimization with JAX.
UPDATE: This is obsolete now that `jax.scipy.optimize.minimize` exists!
"""
import numpy as onp
import scipy.optimize
from jax import grad, jit
from jax.flatten_util import ravel_pytree
def minimize(fun, x0,
             method=None,
             args=(),
             bounds=None,
             constraints=(),
             tol=None,
             callback=None,
             options=None):
    """
    A simple wrapper for scipy.optimize.minimize using JAX.

    Args:
        fun: The objective function to be minimized, written in JAX code
            so that it is automatically differentiable. It has the signature

                ```fun: x, *args -> float```

            where `x` is a PyTree and `args` is a tuple of the fixed parameters
            needed to completely specify the function.

        x0: Initial guess represented as a JAX PyTree.

        args: tuple, optional. Extra arguments passed to the objective function
            and its derivative. Must consist of valid JAX types; e.g., the leaves
            of the PyTree must be floats.

        _The remainder of the keyword arguments are inherited from
        `scipy.optimize.minimize`, and their descriptions are copied here for
        convenience._

        method : str or callable, optional
            Type of solver. Should be one of

                - 'Nelder-Mead' :ref:`(see here) <optimize.minimize-neldermead>`
                - 'Powell'      :ref:`(see here) <optimize.minimize-powell>`
                - 'CG'          :ref:`(see here) <optimize.minimize-cg>`
                - 'BFGS'        :ref:`(see here) <optimize.minimize-bfgs>`
                - 'Newton-CG'   :ref:`(see here) <optimize.minimize-newtoncg>`
                - 'L-BFGS-B'    :ref:`(see here) <optimize.minimize-lbfgsb>`
                - 'TNC'         :ref:`(see here) <optimize.minimize-tnc>`
                - 'COBYLA'      :ref:`(see here) <optimize.minimize-cobyla>`
                - 'SLSQP'       :ref:`(see here) <optimize.minimize-slsqp>`
                - 'trust-constr':ref:`(see here) <optimize.minimize-trustconstr>`
                - 'dogleg'      :ref:`(see here) <optimize.minimize-dogleg>`
                - 'trust-ncg'   :ref:`(see here) <optimize.minimize-trustncg>`
                - 'trust-exact' :ref:`(see here) <optimize.minimize-trustexact>`
                - 'trust-krylov':ref:`(see here) <optimize.minimize-trustkrylov>`
                - custom - a callable object (added in version 0.14.0);
                  see the `scipy.optimize.minimize` documentation for details.

            If not given, chosen to be one of ``BFGS``, ``L-BFGS-B``, ``SLSQP``,
            depending on whether the problem has constraints or bounds.
            This wrapper supplies the gradient (``jac``) but no Hessian, so
            'dogleg', 'trust-ncg', 'trust-exact' and 'trust-krylov', which
            require a Hessian (or Hessian-vector product), will raise an error.

        bounds : sequence or `Bounds`, optional
            Bounds on variables for L-BFGS-B, TNC, SLSQP, Powell, and
            trust-constr methods. There are two ways to specify the bounds:

                1. Instance of `Bounds` class.
                2. Sequence of ``(min, max)`` pairs for each element in `x`. None
                   is used to specify no bound.

            Note that bounds are passed to scipy unchanged, so you will need to
            flatten them manually, in the same order in which `ravel_pytree`
            flattens `x0`.

        constraints : {Constraint, dict} or List of {Constraint, dict}, optional
            Constraints definition (only for COBYLA, SLSQP and trust-constr).
            Constraints for 'trust-constr' are defined as a single object or a
            list of objects specifying constraints to the optimization problem.
            Available constraints are:

                - `LinearConstraint`
                - `NonlinearConstraint`

            Constraints for COBYLA, SLSQP are defined as a list of dictionaries.
            Each dictionary has the fields:

                type : str
                    Constraint type: 'eq' for equality, 'ineq' for inequality.
                fun : callable
                    The function defining the constraint.
                jac : callable, optional
                    The Jacobian of `fun` (only for SLSQP).
                args : sequence, optional
                    Extra arguments to be passed to the function and Jacobian.

            An equality constraint means that the constraint function result is
            to be zero, whereas an inequality means that it is to be non-negative.
            Note that COBYLA only supports inequality constraints.

            Note that constraints are also passed to scipy unchanged: constraint
            functions receive the flat parameter vector, so they must index it
            in the same order in which `ravel_pytree` flattens `x0`.

        tol : float, optional
            Tolerance for termination. For detailed control, use solver-specific
            options.

        options : dict, optional
            A dictionary of solver options. All methods accept the following
            generic options:

                maxiter : int
                    Maximum number of iterations to perform. Depending on the
                    method, each iteration may use several function evaluations.
                disp : bool
                    Set to True to print convergence messages.

            For method-specific options, see :func:`scipy.optimize.show_options`.

        callback : callable, optional
            Called after each iteration. For 'trust-constr' it is a callable with
            the signature:

                ``callback(xk, OptimizeResult state) -> bool``

            where ``xk`` is the current parameter vector represented as a PyTree,
            and ``state`` is an `OptimizeResult` object, with the same fields
            as the ones from the return. If callback returns True, the algorithm
            execution is terminated.
            For all the other methods, the signature is:

                ```callback(xk)```

            where `xk` is the current parameter vector, represented as a PyTree.

    Returns:
        res : The optimization result represented as an ``OptimizeResult`` object.
            Important attributes are:

                ``x``: the solution, represented as a JAX PyTree
                ``success``: a Boolean flag indicating if the optimizer exited successfully
                ``message``: describes the cause of the termination

            See `scipy.optimize.OptimizeResult` for a description of other attributes.
    """
    # Flatten the PyTree x0 into a single 1-D array with ravel_pytree;
    # `unravel` maps a flat array back to the original PyTree structure.
    x0_flat, unravel = ravel_pytree(x0)

    # Wrap the objective function to consume flat _original_
    # numpy arrays and produce scalar outputs.
    def fun_wrapper(x_flat, *args):
        x = unravel(x_flat)
        return float(fun(x, *args))

    # Wrap the gradient in a similar manner
    jac = jit(grad(fun))

    def jac_wrapper(x_flat, *args):
        x = unravel(x_flat)
        g_flat, _ = ravel_pytree(jac(x, *args))
        return onp.array(g_flat)

    # Wrap the callback so that it receives a PyTree instead of a flat array
    def callback_wrapper(x_flat, *args):
        x = unravel(x_flat)
        return callback(x, *args)

    # Minimize with scipy
    results = scipy.optimize.minimize(fun_wrapper,
                                      x0_flat,
                                      args=args,
                                      method=method,
                                      jac=jac_wrapper,
                                      callback=callback_wrapper if callback is not None else None,
                                      bounds=bounds,
                                      constraints=constraints,
                                      tol=tol,
                                      options=options)

    # Pack the output back into a PyTree
    results["x"] = unravel(results["x"])
    return results
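As a usage example, here is a condensed, standalone sketch of the same flatten-then-call-scipy pattern (it calls `scipy.optimize.minimize` directly so the block runs on its own). It also shows one way to satisfy the docstring's requirement that `bounds` be flattened in the same order as `x0`: build the bounds as a PyTree with the same structure and ravel it too. Note that JAX flattens dicts in sorted key order, so `"b"` comes before `"w"` in the flat vector. The parameter names, loss, and targets are hypothetical.

```python
import numpy as onp
import scipy.optimize
import jax.numpy as jnp
from jax import grad, jit
from jax.flatten_util import ravel_pytree

# Hypothetical PyTree of parameters: a vector "w" and a scalar "b".
params0 = {"w": jnp.zeros(2), "b": jnp.array(0.0)}

def loss(params, target_w, target_b):
    # Quadratic loss whose unconstrained minimum is w = target_w, b = target_b.
    return jnp.sum((params["w"] - target_w) ** 2) + (params["b"] - target_b) ** 2

x0_flat, unravel = ravel_pytree(params0)
grad_loss = jit(grad(loss))

def fun_flat(x_flat, *args):
    return float(loss(unravel(x_flat), *args))

def jac_flat(x_flat, *args):
    g_flat, _ = ravel_pytree(grad_loss(unravel(x_flat), *args))
    return onp.asarray(g_flat, dtype=onp.float64)

# Raveling PyTrees with the same structure as params0 guarantees the bounds
# line up with x0_flat. Here b is restricted to [0, 0.5]; w is unbounded.
lower, _ = ravel_pytree({"w": jnp.full(2, -jnp.inf), "b": jnp.array(0.0)})
upper, _ = ravel_pytree({"w": jnp.full(2, jnp.inf), "b": jnp.array(0.5)})
bounds = scipy.optimize.Bounds(onp.asarray(lower, dtype=onp.float64),
                               onp.asarray(upper, dtype=onp.float64))

args = (jnp.array([1.0, -2.0]), 1.0)
res = scipy.optimize.minimize(fun_flat, onp.asarray(x0_flat, dtype=onp.float64),
                              args=args, jac=jac_flat, method="L-BFGS-B",
                              bounds=bounds)
params_opt = unravel(res.x)
print(params_opt)
```

Because the unconstrained optimum for `b` (1.0) lies outside [0, 0.5] and the loss is convex, the solution should have `b` at its upper bound, while `w` is free to reach its target.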
@murphyk

murphyk commented May 19, 2022

See https://jaxopt.github.io/ for a new library that might be useful.

@richinex

Hi Slinderman, thanks for the wrapper. Is there a way to make the code vmappable? Currently I can use vmap on jax.scipy.optimize.minimize, but the downside is that it only supports the BFGS algorithm. The scipy minimize wrapper in jaxopt is not vmappable either. When I run the code below, I get a JAX conversion error (TracerArrayConversionError).

def do_minimize(p, x, y, z, lb, ub, smf):
    return minimize(cost_fun, p, args=(x, y, z, lb, ub, smf), method='TNC', tol=1e-12, options={'maxiter': 20000})

sol = jax.vmap(do_minimize)(par_log, F, Y, sigma_Y, lb_mat, ub_mat, smoothing_mat)

#     154     # Minimize with scipy
# --> 155     results = scipy.optimize.minimize(fun_wrapper, 
#     156                                       x0_flat,
#     157                                       args=args,

# ~/anaconda3/envs/simulation/lib/python3.10/site-packages/scipy/optimize/_minimize.py in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
#     494 
# ...
#               4.3169071 , 4.3169071 ],
#              [2.67476726, 2.67476726, 2.67476726, ..., 4.3169071 ,
#               4.3169071 , 4.3169071 ]], dtype=float64)
#   batch_dim = 0
# See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
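The error arises because, under `jax.vmap`, the wrapper's inputs are abstract tracers, while `scipy.optimize.minimize` needs concrete NumPy arrays (it converts `x0` with NumPy, and the wrapper calls `float(...)` and `onp.array(...)` on JAX values). A minimal reproduction of the same error class, independent of scipy, with a hypothetical `to_numpy` helper standing in for that conversion:

```python
import numpy as onp
import jax
import jax.numpy as jnp

def to_numpy(x):
    # Converting to a concrete NumPy array, which is what handing a JAX
    # value to scipy ultimately requires.
    return onp.asarray(x)

def vmap_numpy_conversion_fails():
    try:
        jax.vmap(to_numpy)(jnp.ones((4, 3)))
    except jax.errors.TracerArrayConversionError:
        return True
    return False

# A plain Python loop over rows passes concrete arrays, so it works.
loop_results = list(map(to_numpy, jnp.ones((4, 3))))
print(vmap_numpy_conversion_fails(), len(loop_results))
```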

@richinex

I just realized that using 'list(map(func, *args))' instead of vmap works well.
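A minimal sketch comparing the two batching strategies discussed in this thread, assuming a hypothetical per-problem quadratic `cost`. Option 1 loops with `map` over separate scipy solves (scipy is called directly instead of the wrapper so the block is self-contained, and L-BFGS-B is an arbitrary choice). Option 2 uses `jax.vmap` over `jax.scipy.optimize.minimize`, which stays inside JAX but is limited to BFGS.

```python
import numpy as onp
import scipy.optimize
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize as jax_minimize

def cost(x, target):
    # Hypothetical per-problem objective: squared distance to `target`.
    return jnp.sum((x - target) ** 2)

grad_cost = jax.jit(jax.grad(cost))
targets = jnp.array([[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]])
x0 = jnp.zeros(2)

# Option 1: a Python-level loop; each scipy call sees concrete arrays.
def solve_one_scipy(target):
    res = scipy.optimize.minimize(
        lambda x: float(cost(x, target)),
        onp.zeros(2),
        jac=lambda x: onp.asarray(grad_cost(x, target), dtype=onp.float64),
        method="L-BFGS-B",
    )
    return res.x

scipy_solutions = onp.stack(list(map(solve_one_scipy, targets)))

# Option 2: vmap the pure-JAX BFGS solver over the batch of targets.
solve_batch = jax.vmap(lambda t: jax_minimize(cost, x0, args=(t,), method="BFGS").x)
jax_solutions = solve_batch(targets)
print(scipy_solutions, jax_solutions)
```

The loop keeps access to every scipy method but runs the problems sequentially; the vmapped version can be compiled and batched but only offers BFGS.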
