Comparison between JAX, Numpy and Scipy

This Gist compares JAX, Numpy and Scipy. The comparison is about the computation time (CPU only) of the gradient of the criterion

$J(x) = \|y - H x\|^2$

where $H = F^H \Lambda F W$ is a linear operator with $F$ the discrete Fourier transform, $\Lambda$ a complex diagonal matrix (so that $F^H \Lambda F$ is a circular convolution), and $W$ a diagonal matrix; $x$ and $y$ are images.
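As a minimal sketch, applying $H$ costs one forward and one inverse FFT; this mirrors the `np_H` function in the benchmark code below (`w` holds the diagonal of $W$, `h` the rfft2-sized diagonal of $\Lambda$, and the name `apply_H` is illustrative):

```python
import numpy as np

def apply_H(x, w, h):
    # H x = F^H (Λ (F (W x))): pointwise weighting by w, then filtering in Fourier space.
    return np.fft.irfft2(h * np.fft.rfft2(x * w), s=x.shape)
```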

Gradients

The gradient is

$$\nabla J(x) = 2 H^T (H x - y).$$

Second, we can derive an optimized version thanks to access to the mathematical expression

$$\nabla J(x) = 2 \left( W F^H |\Lambda|^2 F W x - b \right)$$

with $b = H^T y$.
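Expanding the naive gradient shows where the saving comes from (a short derivation; $H^T$ here denotes the adjoint, which coincides with $H^H$ since $H$ maps real images to real images):

$$\nabla J(x) = 2 H^T (H x - y) = 2 \left( H^T H x - b \right), \qquad b = H^T y,$$

$$H^T H = (F^H \Lambda F W)^H (F^H \Lambda F W) = W F^H \Lambda^H \Lambda F W = W F^H |\Lambda|^2 F W,$$

using $F F^H = I$ (unitary DFT, up to normalization). Precomputing $b$ thus saves one forward/inverse FFT pair per gradient evaluation.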

Tests

There are two axes of experiment.

  1. The backend for the computation, notably of the FFT but also of the sums, etc.
    • numpy.
    • scipy (with multithreading).
    • jax.
    • jaxjit : jax with jit.
  2. The formula of the gradient.
    • naive : the naive one $\nabla J(x) = 2 H^T(H x - y)$ coded by hand.
    • auto : gradient generated by JAX from $J(x)$ with the grad function (see the sketch after this list). The jax or jaxjit backend is required in that case.
    • brain : the optimized one $\nabla J(x) = 2 (W F^H |\Lambda|^2 F W x - b)$ coded by hand.
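As a minimal sketch of how the auto variants are built (names here are illustrative; the benchmark uses `jax_crit` below):

```python
import jax
import jax.numpy as jnp

def crit(x, w, h, y):
    # J(x) = ||y - Hx||^2 with Hx computed via FFTs.
    Hx = jnp.fft.irfft2(h * jnp.fft.rfft2(x * w), s=x.shape)
    return jnp.sum(jnp.abs(y - Hx) ** 2)

grad_auto = jax.grad(crit)               # "jax + auto"
grad_auto_jit = jax.jit(jax.grad(crit))  # "jaxjit + auto"
```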

Conclusion

  • A good library (scipy) with hand-optimized code (brain) beats JAX in all cases.
  • Even without the brain mode, jax with jit beats the numpy code in all cases.
  • There is no good reason not to use jit: JAX without jit and with autodiff is the slowest.
  • The best JAX variant is still 1.8× slower than scipy + brain.

Run

To run the benchmark (depends on numpy, scipy, jax[cpu], ipython and matplotlib):

  1. Install poetry
  2. Clone the gist
  3. Run in the directory

```
poetry install
```

  4. Then run the benchmark

```
poetry run ipython3 bench.py
```

F. Orieux (L2S, Univ. Paris-Saclay) francois.orieux@universite-paris-saclay.fr

"""This Gist compares JAX, Numpy and Scipy. The comparison is about the computation
time (CPU only) of gradient of the criterion
$$J(x) = \|y - H x\|^2$$
where $H = F^* \Lambda F W$ is a linear operator with $F$ the Discrete Fourier
transform, $\Lambda$ a complex diagonal matrix and $W$ a diagonal matrix.
The gradient is
$$\\Nabla J (x) = 2 H^T(Hx - y)$$.
Second we can have a optimized version thank to acces to mathematical expression
$$\\Nabla J (x) = 2 W F^* |\Lambda|^2 F W x - b$$
with $b = H^T t$.
- `jax`, `numpy` and `scipy` are library used for numerical computation (DFT, sum, ...).
- With `auto`, the gradient is generated by JAX from $J(x)$ with `grad` function..
- With `naive`, the gradient is coded by hand.
- With `jit`, JAX code is jitted.
- With `brain`, the gradient is coded by hand in it's optimized version.
To run the benchmark
1. Install [poetry](https://python-poetry.org/)
2. Clone the gist
3. Run in the directory
```
poetry install
```
4. Then run the benchmark
```
poetry run ipython3 bench.py
```
5. The script saves the figure in 'bench_jax_numpy_scipy.jpg'
F. Orieux (L2S, Univ. Paris-Saclay) <francois.orieux@universite-paris-saclay.fr>
"""
import time
import jax
import jax.numpy as jnp
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.fft
from IPython import get_ipython
from jax import grad, jit
ipython = get_ipython()
plt.style.use("tableau-colorblind10")
#%%
# Jax version of H, H^T, crit and grad coded by hand
def jax_H(x, w, h):
    return jnp.fft.irfft2(h * jnp.fft.rfft2(x * w), s=x.shape)


def jax_Ht(x, w, h):
    return w * jnp.fft.irfft2(h.conj() * jnp.fft.rfft2(x), s=x.shape)


def jax_crit(x, w, h, y):
    return jnp.sum(jnp.abs(y - jax_H(x, w, h)) ** 2)


def jax_hand(x, w, h, y):
    return 2 * jax_Ht(jax_H(x, w, h) - y, w, h)


def jax_hand_opt(x, w, h2, hty):
    return 2 * (w * jnp.fft.irfft2(h2 * jnp.fft.rfft2(x * w), s=x.shape) - hty)
#%%
# Numpy version of H, H^T and grad coded by hand
def np_H(x, w, h):
    return np.fft.irfft2(h * np.fft.rfft2(x * w), s=x.shape)


def np_Ht(x, w, h):
    return w * np.fft.irfft2(h.conj() * np.fft.rfft2(x), s=x.shape)


def np_hand(x, w, h, y):
    return 2 * np_Ht(np_H(x, w, h) - y, w, h)


def np_hand_opt(x, w, h2, hty):
    # h2 and hty were free (global) variables in the original; they are
    # parameters here, matching how the function is called below.
    return 2 * (w * np.fft.irfft2(h2 * np.fft.rfft2(x * w), s=x.shape) - hty)
#%%
# Scipy version of H, H^T and grad coded by hand
def sp_H(x, w, h):
    return sp.fft.irfft2(h * sp.fft.rfft2(x * w, workers=-1), s=x.shape, workers=-1)


def sp_Ht(x, w, h):
    return w * sp.fft.irfft2(
        h.conj() * sp.fft.rfft2(x, workers=-1), s=x.shape, workers=-1
    )


def sp_hand(x, w, h, y):
    return 2 * sp_Ht(sp_H(x, w, h) - y, w, h)


def sp_hand_opt(x, w, h2, hty):
    return 2 * (
        w * sp.fft.irfft2(h2 * sp.fft.rfft2(x * w, workers=-1), s=x.shape, workers=-1)
        - hty
    )
#%%
shape = (512, 512)
x_np = np.ones(shape, dtype=np.float32)
y_np = np.random.standard_normal(shape).astype(np.float32)
w_np = np.random.standard_normal(shape).astype(np.float32)
h_np = np.fft.rfft2(np.ones((5, 5)), shape).astype(np.complex64)
h2_np = np.abs(h_np) ** 2  # |Λ|^2, precomputed for the "brain" gradient
hty_np = sp_Ht(y_np, w_np, h_np)  # b = H^T y (the original passed x_np by mistake)

x = jnp.array(x_np)
y = jnp.array(y_np)
w = jnp.array(w_np)
h = jnp.array(h_np)
h2 = jnp.abs(h) ** 2
hty = jax_Ht(y, w, h)
#%% Auto. diff
jax_auto = grad(jax_crit)
print(
    "Auto allclose to Hand ? -> ",
    np.allclose(jax_auto(x, w, h, y), jax_hand(x, w, h, y), rtol=1e-3, atol=1e-6),
)
#%% Jit
jit_jax_hand = jit(jax_hand)
jit_jax_hand_opt = jit(jax_hand_opt)
jit_jax_auto = jit(jax_auto)
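# Aside (not in the original script): the first call to a jitted function
# triggers compilation, and JAX dispatches work asynchronously, so %timeit on
# jitted functions can under-report the actual compute time. A stricter timing
# would warm each function up once and block on the result, e.g.:
#   jit_jax_hand(x, w, h, y).block_until_ready()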
#%% Timer
print("jit jax hand")
T_jit_jax_hand = ipython.run_line_magic("timeit", "-o jit_jax_hand(x, w, h, y)")
print("jax hand")
T_jax_hand = ipython.run_line_magic("timeit", "-o jax_hand(x, w, h, y)")
print("jit jax auto")
T_jit_jax_auto = ipython.run_line_magic("timeit", "-o jit_jax_auto(x, w, h, y)")
print("jax auto")
T_jax_auto = ipython.run_line_magic("timeit", "-o jax_auto(x, w, h, y)")
print("numpy hand")
T_np_hand = ipython.run_line_magic("timeit", "-o np_hand(x_np, w_np, h_np, y_np)")
print("scipy hand")
T_sp_hand = ipython.run_line_magic("timeit", "-o sp_hand(x_np, w_np, h_np, y_np)")
print("scipy hand opt")
T_sp_hand_opt = ipython.run_line_magic(
    "timeit", "-o sp_hand_opt(x_np, w_np, h2_np, hty_np)"
)
print("numpy hand opt")
T_np_hand_opt = ipython.run_line_magic(
    "timeit", "-o np_hand_opt(x_np, w_np, h2_np, hty_np)"
)
print("jit jax hand opt")
T_jit_jax_hand_opt = ipython.run_line_magic(
    "timeit", "-o jit_jax_hand_opt(x, w, h2, hty)"
)
#%%
# The figure
timer = sorted(
    [
        (T_np_hand, "numpy", "naive"),
        (T_jax_auto, "jax", "auto"),
        (T_jax_hand, "jax", "naive"),
        (T_jit_jax_auto, "jaxjit", "auto"),
        (T_jit_jax_hand, "jaxjit", "naive"),
        (T_sp_hand, "scipy", "naive"),
        (T_sp_hand_opt, "scipy", "brain"),
        (T_np_hand_opt, "numpy", "brain"),
        (T_jit_jax_hand_opt, "jaxjit", "brain"),
    ],
    key=lambda i: i[0].average,
    reverse=True,
)
colorcode = {
    # hatches encode the backend, colors encode the gradient formula
    "numpy": "+",
    "jax": "/",
    "jaxjit": "//",
    "scipy": "*",
    "naive": "C1",
    "auto": "C2",
    "brain": "C3",
}
average = [(t.average, lib, code) for (t, lib, code) in timer]
std = [(t.stdev, lib, code) for (t, lib, code) in timer]
fastest = average[-1]
text_above = [f"x {t[0] / fastest[0]:.1f}" for t in average]
fig = plt.figure(1, figsize=(1.6 * 8, 8))
plt.clf()
ax = fig.add_subplot(111)
ax.grid(visible=False)
bar = ax.bar(
    [it[1] + " " + it[2] for it in average],
    [it[0] for it in average],
    zorder=3,
    hatch=[colorcode[it[1]] for it in average],
    color=[colorcode[it[2]] for it in average],
    yerr=[it[0] for it in std],
)
ax.bar_label(bar, text_above)
ax.set_title(
    "CPU Time [s] (lower is better, average on 100 runs) for evaluation of\n"
    + r"$\nabla J(x) = 2 H^T (H \cdot x - y)$ "
    + "\n when \n"
    + r"$H = F^* \Lambda F W$ ($F$ is DFT, $\Lambda$ complex diag and $W$ diag)."
    + f"\n Version: JAX {jax.__version__}, Numpy {np.__version__}, Scipy {sp.__version__}"
    + "\n F. Orieux (L2S) francois.orieux@universite-paris-saclay.fr"
)
leglib1 = mpatches.Patch(facecolor=colorcode["naive"], label="Naive")
leglib2 = mpatches.Patch(facecolor=colorcode["auto"], label="Autodiff")
leglib3 = mpatches.Patch(facecolor=colorcode["brain"], label="Brain")
legcode1 = mpatches.Patch(
    facecolor="white", edgecolor="black", hatch=colorcode["numpy"], label="Numpy"
)
legcode2 = mpatches.Patch(
    facecolor="white", edgecolor="black", hatch=colorcode["scipy"], label="Scipy"
)
legcode3 = mpatches.Patch(
    facecolor="white", edgecolor="black", hatch=colorcode["jax"], label="JAX"
)
legcode4 = mpatches.Patch(
    facecolor="white", edgecolor="black", hatch=colorcode["jaxjit"], label="JAX + jit"
)
ax.legend(
    handles=[legcode1, legcode2, legcode3, legcode4, leglib1, leglib2, leglib3],
    ncols=2,
    loc="upper right",
    handleheight=0.7 * 4,
    handlelength=2.0 * 4,
)
plt.tight_layout()
plt.savefig("bench_jax_numpy_scipy.jpg")
[tool.poetry]
name = "jax_vs_np_vs_sp"
version = "0.1.0"
description = ""
authors = ["François Orieux <orieux@l2s.centralesupelec.fr>"]
license = "Public Domain"

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
numpy = "^1.22.3"
scipy = "^1.8.0"
jax = {extras = ["cpu"], version = "^0.3.1"}
ipython = "^8.1.1"
matplotlib = "^3.5.1"

[tool.poetry.dev-dependencies]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"