Last active September 27, 2023 14:25
Comparison between JAX, Numpy and Scipy

To run the benchmark (depends on numpy, scipy, jax[cpu], ipython and matplotlib)

  1. Install poetry
  2. Clone the gist
  3. Run in the directory
poetry install
  4. Then run the benchmark
poetry run ipython3

F. Orieux (L2S, Univ. Paris-Saclay)

"""This Gist compares JAX, Numpy and Scipy. The comparison is about the computation
time (CPU only) of the gradient of the criterion
$$J(x) = \|y - H x\|^2$$
where $H = F^* \Lambda F W$ is a linear operator with $F$ the Discrete Fourier
transform, $\Lambda$ a complex diagonal matrix and $W$ a diagonal matrix.
The gradient is
$$\\nabla J (x) = 2 H^T(Hx - y).$$
Second, since we have access to the mathematical expression, we can write an optimized version
$$\\nabla J (x) = 2 (W F^* |\Lambda|^2 F W x - b)$$
with $b = H^T y$.
- `jax`, `numpy` and `scipy` are the libraries used for numerical computation (DFT, sum, ...).
- With `auto`, the gradient is generated by JAX from $J(x)$ with the `grad` function.
- With `naive`, the gradient is coded by hand.
- With `jit`, the JAX code is jitted.
- With `brain`, the gradient is coded by hand in its optimized version.
To run the benchmark
1. Install [poetry](https://python-poetry.org/)
2. Clone the gist
3. Run in the directory
poetry install
4. Then run the benchmark
poetry run ipython3
5. The script saves the figure in 'bench_jax_numpy_scipy.jgp'
F. Orieux (L2S, Univ. Paris-Saclay) <>
"""
import time
import jax
import jax.numpy as jnp
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.fft
from IPython import get_ipython
from jax import grad, jit
ipython = get_ipython()
if ipython is None:
    # The timings below drive IPython's %timeit magic through
    # ``ipython.run_line_magic``; outside an IPython session get_ipython()
    # returns None and the first timing call would fail with an opaque
    # AttributeError after all the setup work. Fail early and explain instead.
    raise RuntimeError(
        "This benchmark uses IPython's %timeit magic: run it from IPython, "
        "e.g. `poetry run ipython3` then `%run <this script>`."
    )
# Colorblind-friendly palette; the "C1".."C3" colors used for the bars come
# from the active style.
plt.style.use("tableau-colorblind10")
#%% \
# Jax version of H, H^T, crit and grad coded by hand
def jax_H(x, w, h):
    """Apply the forward operator H = F^* diag(h) F diag(w) with JAX.

    ``x`` is first weighted elementwise by ``w``, then filtered in the
    ``rfft2`` half-spectrum domain by ``h``. Passing ``s=x.shape`` makes
    ``irfft2`` return an array with the same shape as ``x``.
    """
    weighted = x * w
    filtered = h * jnp.fft.rfft2(weighted)
    return jnp.fft.irfft2(filtered, s=x.shape)
def jax_Ht(x, w, h):
    """Apply the adjoint operator H^T = diag(w) F^* diag(conj(h)) F with JAX.

    The input is filtered in the ``rfft2`` half-spectrum domain by the
    conjugate of ``h`` and the result is weighted elementwise by ``w``.
    """
    spectrum = jnp.fft.rfft2(x)
    back = jnp.fft.irfft2(h.conj() * spectrum, s=x.shape)
    return w * back
def jax_crit(x, w, h, y):
    """Return the least-squares criterion J(x) = ||y - H x||^2 (JAX).

    This is the function differentiated by ``jax.grad`` for the "auto"
    variant of the benchmark.
    """
    residual = y - jax_H(x, w, h)
    return jnp.sum(jnp.abs(residual) ** 2)
def jax_hand(x, w, h, y):
    """Hand-written ("naive") gradient 2 H^T (H x - y) of J, using JAX."""
    residual = jax_H(x, w, h) - y
    return 2 * jax_Ht(residual, w, h)
def jax_hand_opt(x, w, h2, hty):
    """Optimized ("brain") gradient 2 (W F^* |Lambda|^2 F W x - H^T y), using JAX.

    ``h2`` is expected to be |h|^2 in the ``rfft2`` layout and ``hty`` the
    precomputed H^T y, so one FFT pair replaces the two of the naive version.
    """
    spectrum = jnp.fft.rfft2(x * w)
    hthx = w * jnp.fft.irfft2(h2 * spectrum, s=x.shape)
    return 2 * (hthx - hty)
#%% \
# Numpy version of H, H^T, crit and grad coded by hand
def np_H(x, w, h):
    """Apply the forward operator H = F^* diag(h) F diag(w) with NumPy.

    ``x`` is weighted elementwise by ``w``, then filtered in the ``rfft2``
    half-spectrum domain by ``h``; ``s=x.shape`` restores the shape of ``x``.
    """
    weighted = x * w
    filtered = h * np.fft.rfft2(weighted)
    return np.fft.irfft2(filtered, s=x.shape)
def np_Ht(x, w, h):
    """Apply the adjoint operator H^T = diag(w) F^* diag(conj(h)) F with NumPy."""
    spectrum = np.fft.rfft2(x)
    back = np.fft.irfft2(h.conj() * spectrum, s=x.shape)
    return w * back
def np_hand(x, w, h, y):
    """Hand-written ("naive") gradient 2 H^T (H x - y) of J, using NumPy."""
    residual = np_H(x, w, h) - y
    return 2 * np_Ht(residual, w, h)
def np_hand_opt(x, w, h2, hty):
    """Optimized ("brain") gradient 2 (W F^* |Lambda|^2 F W x - H^T y), using NumPy.

    Parameters
    ----------
    x : ndarray
        Point where the gradient is evaluated.
    w : ndarray
        Diagonal of W, broadcastable against ``x``.
    h2 : ndarray
        |h|^2, in the ``rfft2`` half-spectrum layout of ``x``.
    hty : ndarray
        Precomputed H^T y, same shape as ``x``.

    Returns
    -------
    ndarray
        The gradient, same shape as ``x``.
    """
    # The parameters are named h2/hty (as in the jax/scipy versions) so the
    # body uses the arrays it is given, not module-level globals of the same
    # names.
    hthx = w * np.fft.irfft2(h2 * np.fft.rfft2(x * w), s=x.shape)
    return 2 * (hthx - hty)
#%% \
# Scipy version of H, H^T, crit and grad coded by hand
def sp_H(x, w, h):
    """Apply the forward operator H = F^* diag(h) F diag(w) with scipy.fft.

    ``workers=-1`` lets scipy.fft use all available CPU cores.
    """
    spectrum = sp.fft.rfft2(x * w, workers=-1)
    return sp.fft.irfft2(h * spectrum, s=x.shape, workers=-1)
def sp_Ht(x, w, h):
    """Apply the adjoint operator H^T = diag(w) F^* diag(conj(h)) F with scipy.fft.

    ``workers=-1`` lets scipy.fft use all available CPU cores.
    """
    spectrum = sp.fft.rfft2(x, workers=-1)
    back = sp.fft.irfft2(h.conj() * spectrum, s=x.shape, workers=-1)
    return w * back
def sp_hand(x, w, h, y):
    """Hand-written ("naive") gradient 2 H^T (H x - y) of J, using scipy.fft."""
    residual = sp_H(x, w, h) - y
    return 2 * sp_Ht(residual, w, h)
def sp_hand_opt(x, w, h2, hty):
    """Optimized ("brain") gradient 2 (W F^* |Lambda|^2 F W x - H^T y), using scipy.fft.

    ``h2`` is expected to be |h|^2 in the ``rfft2`` layout and ``hty`` the
    precomputed H^T y; ``workers=-1`` lets scipy.fft use all CPU cores.
    """
    spectrum = sp.fft.rfft2(x * w, workers=-1)
    hthx = w * sp.fft.irfft2(h2 * spectrum, s=x.shape, workers=-1)
    return 2 * (hthx - hty)
#%% \
shape = (512, 512)
# Benchmark inputs (float32 / complex64): x is the point where the gradient is
# evaluated, y the data, w the diagonal of W and h the transfer function
# (diagonal of Lambda) of a 5x5 uniform kernel in the rfft2 layout.
x_np = np.ones(shape, dtype=np.float32)
y_np = np.asarray(np.random.standard_normal(shape)).astype(np.float32)
w_np = np.asarray(np.random.standard_normal(shape)).astype(np.float32)
h_np = np.asarray(np.fft.rfft2(np.ones((5, 5)), shape)).astype(np.complex64)
h2_np = np.abs(h_np) ** 2
# b = H^T y for the "brain" versions: grad J(x) = 2 (H^T H x - H^T y), so the
# constant term must be built from the data y, not from x.
hty_np = sp_Ht(y_np, w_np, h_np)
x = jnp.array(x_np)
y = jnp.array(y_np)
w = jnp.array(w_np)
h = jnp.array(h_np)
h2 = jnp.abs(h) ** 2
hty = jax_Ht(y, w, h)
#%% Auto. diff
jax_auto = grad(jax_crit)
# Sanity check: the autodiff gradient must match the hand-written one. The
# tolerances are loose because the inputs are float32. Printed so the result
# is actually visible instead of being a discarded expression.
print(
    "Auto allclose to Hand ? -> ",
    np.allclose(jax_auto(x, w, h, y), jax_hand(x, w, h, y), rtol=1e-3, atol=1e-6),
)
#%% Jit
jit_jax_hand = jit(jax_hand)
jit_jax_hand_opt = jit(jax_hand_opt)
jit_jax_auto = jit(jax_auto)
# Warm-up: jit traces and compiles on the first call with given shapes/dtypes.
# Calling each function once here keeps compilation time out of the %timeit
# measurements that follow.
jit_jax_hand(x, w, h, y).block_until_ready()
jit_jax_hand_opt(x, w, h2, hty).block_until_ready()
jit_jax_auto(x, w, h, y).block_until_ready()
#%% Timer
# Each JAX statement ends with ``.block_until_ready()``: JAX dispatches work
# asynchronously, so without it %timeit may measure only the dispatch, not the
# computation. NumPy / SciPy calls return finished results and need nothing.
# ``-o`` makes %timeit return a TimeitResult (``.average``, ``.stdev``).
print("jit jax hand")
T_jit_jax_hand = ipython.run_line_magic(
    "timeit", "-o jit_jax_hand(x, w, h, y).block_until_ready()"
)
print("jax hand")
T_jax_hand = ipython.run_line_magic(
    "timeit", "-o jax_hand(x, w, h, y).block_until_ready()"
)
print("jit jax auto")
T_jit_jax_auto = ipython.run_line_magic(
    "timeit", "-o jit_jax_auto(x, w, h, y).block_until_ready()"
)
print("jax auto")
T_jax_auto = ipython.run_line_magic(
    "timeit", "-o jax_auto(x, w, h, y).block_until_ready()"
)
print("numpy hand")
T_np_hand = ipython.run_line_magic("timeit", "-o np_hand(x_np, w_np, h_np, y_np)")
print("scipy hand")
T_sp_hand = ipython.run_line_magic("timeit", "-o sp_hand(x_np, w_np, h_np, y_np)")
print("scipy hand opt")
T_sp_hand_opt = ipython.run_line_magic(
    "timeit", "-o sp_hand_opt(x_np, w_np, h2_np, hty_np)"
)
print("numpy hand opt")
T_np_hand_opt = ipython.run_line_magic(
    "timeit", "-o np_hand_opt(x_np, w_np, h2_np, hty_np)"
)
print("jit jax hand opt")
T_jit_jax_hand_opt = ipython.run_line_magic(
    "timeit", "-o jit_jax_hand_opt(x, w, h2, hty).block_until_ready()"
)
#%% \
# The figure
# (TimeitResult, library, gradient variant) for every benchmark, ordered by
# increasing mean time.
timer = [
    (T_np_hand, "numpy", "naive"),
    (T_jax_auto, "jax", "auto"),
    (T_jax_hand, "jax", "naive"),
    (T_jit_jax_auto, "jaxjit", "auto"),
    (T_jit_jax_hand, "jaxjit", "naive"),
    (T_sp_hand, "scipy", "naive"),
    (T_sp_hand_opt, "scipy", "brain"),
    (T_np_hand_opt, "numpy", "brain"),
    (T_jit_jax_hand_opt, "jaxjit", "brain"),
]
timer.sort(key=lambda entry: entry[0].average)
# Plot encoding shared by the bars and the legend: library names map to hatch
# patterns, gradient variants map to matplotlib cycle colors.
colorcode = dict(
    numpy="+",
    jax="/",
    jaxjit="//",
    scipy="*",
    naive="C1",
    auto="C2",
    brain="C3",
)
# (mean time, library, variant) and (std dev, library, variant), in the same
# order as ``timer``.
average = [(t.average, lib, code) for (t, lib, code) in timer]
std = [(t.stdev, lib, code) for (t, lib, code) in timer]
# Reference for the "x N" labels is the smallest mean time. Taking the minimum
# does not depend on how ``timer`` is sorted (the last entry of an ascending
# sort would be the slowest run, not the fastest).
fastest = min(average, key=lambda entry: entry[0])
text_above = [f"x {t[0] / fastest[0]:.1f}" for t in average]
fig = plt.figure(1, figsize=(1.6 * 8, 8))
ax = fig.add_subplot(111)
# One bar per benchmark: the hatch encodes the library, the fill color the
# gradient variant, and the error bar is the timeit standard deviation.
bar = ax.bar(
    [it[1] + " " + it[2] for it in average],
    [it[0] for it in average],
    hatch=[colorcode[it[1]] for it in average],
    color=[colorcode[it[2]] for it in average],
    yerr=[it[0] for it in std],
)
# Write the ``text_above`` ratio labels on top of the bars.
ax.bar_label(bar, text_above)
# Title: what is measured, the operator definition and the library versions.
ax.set_title(
    "CPU Time [s] (lower is better, average on 100 runs) for evaluation of\n"
    + r"$\nabla J(x) = 2 H^T (H \cdot x - y)$ "
    + "\n when \n"
    + r"$H = F^* \Lambda F W$ ($F$ is DFT, $\Lambda$ complex diag and $W$ diag)."
    + f"\n Version: JAX {jax.__version__}, Numpy {np.__version__}, Scipy {sp.__version__}"
    + "\n F. Orieux (L2S)"
)
# Legend proxies built from ``colorcode``: plain colored patches for the
# gradient variants, hatched white patches for the libraries.
leglib1 = mpatches.Patch(facecolor=colorcode["naive"], label="Naive")
leglib2 = mpatches.Patch(facecolor=colorcode["auto"], label="Autodiff")
leglib3 = mpatches.Patch(facecolor=colorcode["brain"], label="Brain")
legcode1 = mpatches.Patch(
    facecolor="white", edgecolor="black", hatch=colorcode["numpy"], label="Numpy"
)
legcode2 = mpatches.Patch(
    facecolor="white", edgecolor="black", hatch=colorcode["scipy"], label="Scipy"
)
legcode3 = mpatches.Patch(
    facecolor="white", edgecolor="black", hatch=colorcode["jax"], label="JAX"
)
legcode4 = mpatches.Patch(
    facecolor="white", edgecolor="black", hatch=colorcode["jaxjit"], label="JAX + jit"
)
# Enlarged handles so the hatch patterns are readable in the legend.
ax.legend(
    handles=[legcode1, legcode2, legcode3, legcode4, leglib1, leglib2, leglib3],
    loc="upper right",
    handleheight=0.7 * 4,
    handlelength=2.0 * 4,
)
# Poetry project file (pyproject.toml) for the benchmark.
[tool.poetry]
name = "jax_vs_np_vs_sp"
version = "0.1.0"
description = ""
authors = ["François Orieux <>"]
license = "Public Domaine"

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
numpy = "^1.22.3"
scipy = "^1.8.0"
jax = {extras = ["cpu"], version = "^0.3.1"}
ipython = "^8.1.1"
matplotlib = "^3.5.1"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
