Tests speed of N-D Least squares + L1 regularisation with various backends
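
The objective minimized by every backend below (taken directly from the loss functions defined in the script) is:

    f(x) = ||A @ x - u||_2^2 + 0.1 * ||x||_1

rescaled by the norm of its gradient at the starting point x0 so that the scipy and JAX line searches behave comparably.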
def speed_test_jax():
    import numpy as np
    import jax.numpy as jnp
    from jax import jit, value_and_grad, random
    from jax.scipy.optimize import minimize as minimize_jax
    from scipy.optimize import minimize as minimize_np
    import pylab as plt
    from timeit import default_timer

    S = 3  # number of timed repetitions per backend
    t_scipy_halfjax, t_scipy_jax, t_jax, t_numpy = [], [], [], []
    N_array = [2, 10, 50, 100, 200, 400]
    for N in N_array:
        print("Working on N={}".format(N))
        A = random.normal(random.PRNGKey(0), shape=(N, N))
        u = jnp.ones(N)
        x0 = -2. * jnp.ones(N)
        def f_prescale(x, u):
            # Unscaled loss: least-squares residual plus L1 penalty.
            y = A @ x
            dx = u - y
            return jnp.sum(dx ** 2) + 0.1 * jnp.sum(jnp.abs(x))

        # Due to https://github.com/google/jax/issues/4594 we scale the loss
        # so that the scipy and jax line searches perform similarly.
        jac_norm = jnp.linalg.norm(value_and_grad(f_prescale)(x0, u)[1])
        jac_norm_np = np.array(jac_norm)

        def f(x, u):
            # Scaled loss used by the JAX-based tests.
            y = A @ x
            dx = u - y
            return (jnp.sum(dx ** 2) + 0.1 * jnp.sum(jnp.abs(x))) / jac_norm

        def f_np(x, u):
            # Pure-numpy version of the same scaled loss.
            y = A @ x
            dx = u - y
            return (np.sum(dx ** 2) + 0.1 * np.sum(np.abs(x))) / jac_norm_np
print("Testing scipy+numpy")
t0 = default_timer()
args= (np.array(x0), (np.array(u),))
results_np = minimize_np(f_np, *args, method='BFGS')
for _ in range(S):
results_np = minimize_np(f_np, *args, method='BFGS')
t_numpy.append((default_timer() - t0) / S)
print("nfev",results_np.nfev, "njev", results_np.njev)
print("Time for scipy + numpy", t_numpy[-1])
print("Testing scipy + jitted function and numeric grad")
@jit
def _f(x0, u):
return f(x0, u)
_f(x0, u).block_until_ready()
t0 = default_timer()
for _ in range(S):
results_np = minimize_np(_f, x0, (u,), method='BFGS')
t_scipy_halfjax.append((default_timer() - t0) / S)
print("nfev",results_np.nfev, "njev", results_np.njev)
print("Time for scipy + jitted function and numeric grad", t_scipy_halfjax[-1])
print("Testing scipy + jitted function and grad")
@jit
def _f(x0, u):
v, g = value_and_grad(f)(x0, u)
return v, g
_f(x0, u)[1].block_until_ready()
t0 = default_timer()
for _ in range(S):
results_np = minimize_np(_f, x0, (u,), method='BFGS', jac=True)
t_scipy_jax.append((default_timer() - t0) / S)
print("nfev",results_np.nfev, "njev", results_np.njev)
print("Time for scipy + jitted function and grad", t_scipy_jax[-1])
print("Testing pure JAX implementation")
@jit
def do_minimize_jax(x0, u):
results = minimize_jax(f, x0, args=(u,),method='BFGS')
return results.x
results_jax = minimize_jax(f, x0, args=(u,),method='BFGS')
print("JAX f(optimal)",results_jax.fun,"scipy+jax f(optimal)", results_np.fun)
do_minimize_jax(x0, u).block_until_ready()
t0 = default_timer()
for _ in range(S):
do_minimize_jax(x0, u).block_until_ready()
t_jax.append((default_timer() - t0)/S)
print("nfev", results_jax.nfev, "njev", results_jax.njev)
print("Time for pure JAX implementation", t_jax[-1])
    plt.figure(figsize=(8, 5))
    plt.plot(N_array, t_scipy_jax, label='scipy+jitted(func and grad)')
    plt.plot(N_array, t_scipy_halfjax, label='scipy+jitted(func)')
    plt.plot(N_array, t_jax, label='pure JAX')
    plt.plot(N_array, t_numpy, label='scipy+numpy')
    plt.yscale('log')
    plt.legend()
    plt.title("Run time of BFGS on N-D least squares + L1 regularisation")
    plt.ylabel('Time [s]')
    plt.xlabel("N")
    plt.show()

if __name__ == '__main__':
    speed_test_jax()