Skip to content

Instantly share code, notes, and snippets.

Created June 4, 2021 20:10
Show Gist options
  • Save Joshuaalbert/39e5e44a06cb00e7e154b37504e30fa1 to your computer and use it in GitHub Desktop.
Save Joshuaalbert/39e5e44a06cb00e7e154b37504e30fa1 to your computer and use it in GitHub Desktop.
Regression test: BFGS speed benchmark across jax and jaxlib versions.
def _mean_time(run, repeats):
    """Return (mean seconds per call, last result) for ``run()``.

    ``run()`` is called once untimed first, so JIT compilation and other
    one-off warm-up cost is excluded. It is then called ``repeats`` more
    times inside the timed region.
    """
    from timeit import default_timer

    result = run()  # warm-up, deliberately outside the timed region
    t0 = default_timer()
    for _ in range(repeats):
        result = run()
    return (default_timer() - t0) / repeats, result


def _plot_timings(N_array, timings, jax_version, jaxlib_version, save_path):
    """Plot mean run time versus N for each configuration and save it to ``save_path``."""
    # Imported here so that running with make_plot=False needs no matplotlib.
    import matplotlib.pyplot as plt

    plt.figure(figsize=(8, 5))
    for label, times in timings.items():
        plt.plot(N_array, times, label=label)
    plt.title(f"(jax: {jax_version}, jaxlib: {jaxlib_version}) Run time of BFGS on N-D Least squares + L1 regularisation.")
    plt.xlabel('N')
    plt.ylabel('Time [s]')
    plt.legend()
    plt.savefig(save_path)
    plt.close()


def speed_test_jax(*, S=3, N_array=(2, 10, 50, 100, 200, 400), save_path=None, make_plot=True):
    """Benchmark BFGS with SciPy and JAX for the installed jax/jaxlib versions.

    For each problem size ``N``, ``A`` is drawn with
    ``random.normal(random.PRNGKey(0), shape=(N, N))``. The objective
    ``sum((u - A @ x) ** 2) + 0.1 * sum(|x|)`` is rescaled as described below
    and minimised with BFGS from ``x0 = -2 * ones(N)``, ``u = ones(N)``, in
    four configurations:

    * scipy+numpy: SciPy BFGS, NumPy objective, finite-difference gradient.
    * scipy+jitted(func): SciPy BFGS, jitted JAX objective, finite-difference gradient.
    * scipy+jitted(func and grad): SciPy BFGS, jitted JAX value-and-grad.
    * pure JAX: ``jax.scipy.optimize.minimize`` compiled with ``jit``.

    Each configuration runs once untimed (warm-up / compilation). The mean
    wall-clock time over ``S`` further runs is then recorded.

    Args:
        S: number of timed repetitions per configuration and size.
        N_array: problem sizes to benchmark.
        save_path: output path for the plot. If None, a file name tagged with
            the jax and jaxlib versions is used, so sweeps across versions do
            not overwrite each other.
        make_plot: if False, skip plotting entirely.

    Returns:
        dict mapping each configuration label to a list of mean times in
        seconds, one entry per element of ``N_array``.
    """
    import numpy as np
    import jax
    import jaxlib
    from jax import jit, value_and_grad, random, numpy as jnp
    from jax.scipy.optimize import minimize as minimize_jax
    from scipy.optimize import minimize as minimize_np

    JAX_VERSION = jax.__version__
    JAXLIB_VERSION = jaxlib.__version__

    N_array = list(N_array)
    t_scipy_halfjax, t_scipy_jax, t_jax, t_numpy = [], [], [], []

    for N in N_array:
        print("Working on N={}".format(N))
        A = random.normal(random.PRNGKey(0), shape=(N, N))
        # Host copy, so that the "numpy" objective never dispatches JAX operations.
        A_np = np.asarray(A)
        u = jnp.ones(N)
        x0 = -2. * jnp.ones(N)

        def f_prescale(x, u):
            y = A @ x
            dx = u - y
            return jnp.sum(dx ** 2) + 0.1 * jnp.sum(jnp.abs(x))

        # Divide the loss by its gradient norm at x0, so the initial gradient has
        # unit norm and the scipy and jax line searches behave comparably.
        jac_norm = jnp.linalg.norm(value_and_grad(f_prescale)(x0, u)[1])
        jac_norm_np = np.array(jac_norm)

        def f(x, u):
            return f_prescale(x, u) / jac_norm

        def f_np(x, u):
            y = A_np @ x
            dx = u - y
            return (np.sum(dx ** 2) + 0.1 * np.sum(np.abs(x))) / jac_norm_np

        print("Testing scipy+numpy")
        x0_np, u_np = np.array(x0), np.array(u)
        dt, results_np = _mean_time(
            lambda: minimize_np(f_np, x0_np, (u_np,), method='BFGS'), S)
        t_numpy.append(dt)
        print("nfev", results_np.nfev, "njev", results_np.njev)
        print("Time for scipy + numpy", t_numpy[-1])

        print("Testing scipy + jitted function and numeric grad")
        f_jit = jit(f)
        dt, results_np = _mean_time(
            lambda: minimize_np(f_jit, x0, (u,), method='BFGS'), S)
        t_scipy_halfjax.append(dt)
        print("nfev", results_np.nfev, "njev", results_np.njev)
        print("Time for scipy + jitted function and numeric grad", t_scipy_halfjax[-1])

        print("Testing scipy + jitted function and grad")
        f_and_grad_jit = jit(value_and_grad(f))
        dt, results_np = _mean_time(
            lambda: minimize_np(f_and_grad_jit, x0, (u,), method='BFGS', jac=True), S)
        t_scipy_jax.append(dt)
        print("nfev", results_np.nfev, "njev", results_np.njev)
        print("Time for scipy + jitted function and grad", t_scipy_jax[-1])

        print("Testing pure JAX implementation")

        @jit
        def do_minimize_jax(x0, u):
            results = minimize_jax(f, x0, args=(u,), method='BFGS')
            return results.x

        # Un-jitted run, used only to report the optimum and evaluation counts.
        results_jax = minimize_jax(f, x0, args=(u,), method='BFGS')
        print("JAX f(optimal)", results_jax.fun, "scipy+jax f(optimal)", results_np.fun)
        # block_until_ready: JAX dispatches asynchronously, so wait for the result.
        dt, _ = _mean_time(lambda: do_minimize_jax(x0, u).block_until_ready(), S)
        t_jax.append(dt)
        print("nfev", results_jax.nfev, "njev", results_jax.njev)
        print("Time for pure JAX implementation", t_jax[-1])

    timings = {
        'scipy+jitted(func and grad)': t_scipy_jax,
        'scipy+jitted(func)': t_scipy_halfjax,
        'pure JAX': t_jax,
        'scipy+numpy': t_numpy,
    }
    if make_plot:
        if save_path is None:
            save_path = "bfgs_speed_jax_{}_jaxlib_{}.png".format(JAX_VERSION, JAXLIB_VERSION)
        _plot_timings(N_array, timings, JAX_VERSION, JAXLIB_VERSION, save_path)
    return timings
if __name__ == '__main__':
    # Run the benchmark when this file is executed as a script.
    speed_test_jax()
# Parallel arrays: the jax version at each index is installed together with the
# jaxlib version at the same index (the loop below indexes both with $index).
JAX_VERSIONS=(0.1.74 0.1.75 0.1.76 0.1.77 0.2.0 0.2.1 0.2.2 0.2.3 0.2.4 0.2.5 0.2.6 0.2.7 0.2.8 0.2.9 0.2.10 0.2.11 0.2.12 0.2.13)
JAXLIB_VERSIONS=(0.1.52 0.1.52 0.1.52 0.1.55 0.1.55 0.1.55 0.1.56 0.1.56 0.1.56 0.1.56 0.1.57 0.1.57 0.1.58 0.1.59 0.1.61 0.1.64 0.1.64 0.1.65)
# One shared Python 3.8 environment, created once and reused for every version pair.
# NOTE(review): `conda create` without `-y` asks for confirmation, and `conda activate`
# in a non-interactive script may first need `eval "$(conda shell.bash hook)"` — confirm
# how this script is meant to be run.
conda create -n test_env python=3.8
conda activate test_env
# Dependencies that stay fixed across the sweep.
pip install matplotlib scipy numpy
# ${!JAX_VERSIONS[@]} expands to the array indices 0..len-1.
for index in ${!JAX_VERSIONS[@]}; do
echo Running $((index+1))/${#JAX_VERSIONS[@]} with jax=="${JAX_VERSIONS[index]}" and jaxlib=="${JAXLIB_VERSIONS[index]}"
# --ignore-installed installs this pinned pair even though the previous iteration left other versions installed.
pip install --ignore-installed jax=="${JAX_VERSIONS[index]}" jaxlib=="${JAXLIB_VERSIONS[index]}"
# NOTE(review): the loop appears truncated here. No benchmark invocation and no closing
# `done` are visible; confirm against the original script.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment