slinderman / jax_minimize_wrapper.py
Last active May 20, 2024 07:29
A simple wrapper for `scipy.optimize.minimize` using JAX. UPDATE: This is obsolete now that `jax.scipy.optimize.minimize` exists!
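The general idea of such a wrapper can be sketched as follows. This is a minimal illustration, not the gist's own code (its function bodies are not shown here). It assumes the wrapper flattens a pytree of parameters with `ravel_pytree`, differentiates and compiles the objective with `grad` and `jit` (all three are imported by the file), and passes NumPy-valued objective and gradient functions to `scipy.optimize.minimize`. The name `minimize` and its signature are hypothetical.

```python
import numpy as onp
import scipy.optimize
import jax.numpy as jnp
from jax import grad, jit
from jax.flatten_util import ravel_pytree


def minimize(fun, x0, method=None, **kwargs):
    """Hypothetical sketch: minimize a JAX-differentiable `fun(pytree)` with scipy.

    Flattens the pytree `x0` into one vector, gives scipy a NumPy objective
    and a JAX-computed gradient, then restores the pytree structure of the
    solution in `result.x`.
    """
    # Flatten the initial pytree; `unravel` maps a flat vector back to it.
    x0_flat, unravel = ravel_pytree(x0)

    # Objective on the flat vector, plus its compiled gradient.
    def fun_flat(x_flat):
        return fun(unravel(x_flat))

    fun_jit = jit(fun_flat)
    grad_jit = jit(grad(fun_flat))

    # scipy passes float64 NumPy arrays; cast to the dtype `unravel` expects.
    def fun_np(x):
        return float(fun_jit(jnp.asarray(x, dtype=x0_flat.dtype)))

    def jac_np(x):
        g = grad_jit(jnp.asarray(x, dtype=x0_flat.dtype))
        return onp.asarray(g, dtype=onp.float64)

    result = scipy.optimize.minimize(
        fun_np,
        onp.asarray(x0_flat, dtype=onp.float64),
        jac=jac_np,
        method=method,
        **kwargs,
    )
    # Put the flat solution back into the original pytree structure.
    result.x = unravel(jnp.asarray(result.x, dtype=x0_flat.dtype))
    return result
```

For example, `minimize(loss, {"a": jnp.zeros(3), "b": jnp.array(0.0)}, method="BFGS")` returns a scipy `OptimizeResult` whose `x` is a dict with the same keys and shapes as the initial parameters. As the update notes, `jax.scipy.optimize.minimize` now covers the BFGS case directly inside JAX.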
"""
A collection of helper functions for optimization with JAX.
UPDATE: This is obsolete now that `jax.scipy.optimize.minimize` exists!
"""
import numpy as onp
import scipy.optimize
from jax import grad, jit
from jax.tree_util import tree_flatten, tree_unflatten
from jax.flatten_util import ravel_pytree