|
"""This Gist compares JAX, Numpy and Scipy. The comparison is about the computation |
|
time (CPU only) of the gradient of the criterion
|
|
|
$$J(x) = \\|y - H x\\|^2$$
|
|
|
where $H = F^* \\Lambda F W$ is a linear operator with $F$ the Discrete Fourier
transform, $\\Lambda$ a complex diagonal matrix and $W$ a diagonal matrix.
|
|
|
The gradient is |
|
|
|
$$\\nabla J(x) = 2 H^T (Hx - y).$$
|
|
|
Second, an optimized version is possible thanks to access to the mathematical expression
|
|
|
$$\\nabla J(x) = 2 (W F^* |\\Lambda|^2 F W x - b)$$

with $b = H^T y$.
|
|
|
- `jax`, `numpy` and `scipy` are the libraries used for numerical computation (DFT, sum, ...).
|
- With `auto`, the gradient is generated by JAX from $J(x)$ with the `grad` function.
|
- With `naive`, the gradient is coded by hand. |
|
- With `jit`, JAX code is jitted. |
|
- With `brain`, the gradient is coded by hand in its optimized version.
|
|
|
To run the benchmark |
|
|
|
1. Install [poetry](https://python-poetry.org/) |
|
2. Clone the gist |
|
3. Run in the directory |
|
``` |
|
poetry install |
|
``` |
|
4. Then run the benchmark
```
poetry run ipython3 bench.py
```
5. The script saves the figure in 'bench_jax_numpy_scipy.jpg'
|
|
|
F. Orieux (L2S, Univ. Paris-Saclay) <francois.orieux@universite-paris-saclay.fr> |
|
|
|
""" |
|
|
|
import time |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
import matplotlib.patches as mpatches |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import scipy as sp |
|
import scipy.fft |
|
from IPython import get_ipython |
|
from jax import grad, jit |
|
|
|
# Handle on the running IPython shell; the benchmark below drives the
# `%timeit` magic through it.
ipython = get_ipython()
if ipython is None:
    # get_ipython() returns None under a plain Python interpreter. Fail here
    # with an actionable message instead of an AttributeError on the first
    # `ipython.run_line_magic(...)` call further down.
    raise RuntimeError(
        "This benchmark relies on the IPython %timeit magic; "
        "run it with `ipython3 bench.py`."
    )

# Colour cycle used by the "C1", "C2", "C3" bar colours of the figure.
plt.style.use("tableau-colorblind10")
|
|
|
|
|
#%%\ |
|
# Jax version of H, H^T, crit and grad coded by hand |
|
def jax_H(x, w, h):
    """Apply the forward operator H to `x` with `jax.numpy`.

    `x` is weighted elementwise by `w`, transformed with a real 2-D FFT,
    multiplied by `h` in the Fourier domain, and transformed back with the
    inverse real 2-D FFT. This follows H = F^* Lambda F W from the module
    docstring, with `w` as W and `h` as Lambda.

    Parameters
    ----------
    x : array
        Input array; its shape is also the output shape (`s=x.shape`).
    w : array
        Elementwise weights applied to `x` before the transform.
    h : array
        Factor multiplied with the output of `rfft2`.

    Returns
    -------
    array
        The filtered array, with shape `x.shape`.
    """
    weighted = x * w
    spectrum = jnp.fft.rfft2(weighted)
    filtered = h * spectrum
    return jnp.fft.irfft2(filtered, s=x.shape)
|
|
|
|
|
def jax_Ht(x, w, h):
    """Apply the adjoint operator H^T to `x` with `jax.numpy`.

    `x` is transformed with a real 2-D FFT, multiplied by the complex
    conjugate of `h`, transformed back with the inverse real 2-D FFT, and
    finally weighted elementwise by `w`.

    Parameters
    ----------
    x : array
        Input array; its shape is also the output shape (`s=x.shape`).
    w : array
        Elementwise weights applied after the inverse transform.
    h : array
        Factor whose conjugate multiplies the output of `rfft2`.

    Returns
    -------
    array
        The back-projected array, with shape `x.shape`.
    """
    spectrum = jnp.fft.rfft2(x)
    back_projected = jnp.fft.irfft2(h.conj() * spectrum, s=x.shape)
    return w * back_projected
|
|
|
|
|
def jax_crit(x, w, h, y):
    """Evaluate the criterion J(x) = ||y - H x||^2 with `jax.numpy`.

    Parameters
    ----------
    x, w, h : array
        Forwarded unchanged to `jax_H`.
    y : array
        Data compared with `jax_H(x, w, h)`.

    Returns
    -------
    scalar array
        Sum of the squared absolute values of the residual.
    """
    residual = y - jax_H(x, w, h)
    squared_error = jnp.abs(residual) ** 2
    return jnp.sum(squared_error)
|
|
|
|
|
def jax_hand(x, w, h, y):
    """Hand-coded gradient 2 H^T (H x - y), built on `jax_H` and `jax_Ht`.

    Parameters
    ----------
    x, w, h : array
        Forwarded to `jax_H`; `w` and `h` are also forwarded to `jax_Ht`.
    y : array
        Data subtracted from `jax_H(x, w, h)`.

    Returns
    -------
    array
        Twice the adjoint operator applied to the residual.
    """
    residual = jax_H(x, w, h) - y
    return 2 * jax_Ht(residual, w, h)
|
|
|
|
|
def jax_hand_opt(x, w, h2, hty):
    """Optimized hand-coded gradient using one FFT / inverse FFT pair.

    Evaluates 2 * (w * irfft2(h2 * rfft2(x * w)) - hty), i.e. the expression
    2 (W F^* |Lambda|^2 F W x - b) from the module docstring with `h2` as
    |Lambda|^2 and `hty` as b.

    Parameters
    ----------
    x : array
        Point where the gradient is evaluated; its shape is the output shape.
    w : array
        Elementwise weights, applied before and after the transforms.
    h2 : array
        Factor multiplied with the output of `rfft2`.
    hty : array
        Precomputed term subtracted before the final scaling.

    Returns
    -------
    array
        The gradient, with shape `x.shape`.
    """
    spectrum = jnp.fft.rfft2(x * w)
    hth_x = w * jnp.fft.irfft2(h2 * spectrum, s=x.shape)
    return 2 * (hth_x - hty)
|
|
|
|
|
#%%\ |
|
# Numpy version of H, H^T and the gradients coded by hand
|
def np_H(x, w, h):
    """Apply the forward operator H to `x` with `numpy.fft`.

    `x` is weighted elementwise by `w`, transformed with a real 2-D FFT,
    multiplied by `h` in the Fourier domain, and transformed back with the
    inverse real 2-D FFT.

    Parameters
    ----------
    x : array
        Input array; its shape is also the output shape (`s=x.shape`).
    w : array
        Elementwise weights applied to `x` before the transform.
    h : array
        Factor multiplied with the output of `rfft2`.

    Returns
    -------
    array
        The filtered array, with shape `x.shape`.
    """
    weighted = x * w
    spectrum = np.fft.rfft2(weighted)
    filtered = h * spectrum
    return np.fft.irfft2(filtered, s=x.shape)
|
|
|
|
|
def np_Ht(x, w, h):
    """Apply the adjoint operator H^T to `x` with `numpy.fft`.

    `x` is transformed with a real 2-D FFT, multiplied by the complex
    conjugate of `h`, transformed back with the inverse real 2-D FFT, and
    finally weighted elementwise by `w`.

    Parameters
    ----------
    x : array
        Input array; its shape is also the output shape (`s=x.shape`).
    w : array
        Elementwise weights applied after the inverse transform.
    h : array
        Factor whose conjugate multiplies the output of `rfft2`.

    Returns
    -------
    array
        The back-projected array, with shape `x.shape`.
    """
    spectrum = np.fft.rfft2(x)
    back_projected = np.fft.irfft2(h.conj() * spectrum, s=x.shape)
    return w * back_projected
|
|
|
|
|
def np_hand(x, w, h, y):
    """Hand-coded gradient 2 H^T (H x - y), built on `np_H` and `np_Ht`.

    Parameters
    ----------
    x, w, h : array
        Forwarded to `np_H`; `w` and `h` are also forwarded to `np_Ht`.
    y : array
        Data subtracted from `np_H(x, w, h)`.

    Returns
    -------
    array
        Twice the adjoint operator applied to the residual.
    """
    residual = np_H(x, w, h) - y
    return 2 * np_Ht(residual, w, h)
|
|
|
|
|
def np_hand_opt(x, w, h2, hty):
    """Optimized hand-coded gradient with `numpy.fft`.

    Evaluates 2 * (w * irfft2(h2 * rfft2(x * w)) - hty), the same expression
    as `jax_hand_opt` and `sp_hand_opt`.

    The last two parameters were previously named `h` and `y` and were never
    read: the body picked up the module-level `h2` and `hty` instead, so the
    values passed by the caller were silently ignored. They are now named
    after what the body actually uses. Positional callers are unaffected.

    Parameters
    ----------
    x : array
        Point where the gradient is evaluated; its shape is the output shape.
    w : array
        Elementwise weights, applied before and after the transforms.
    h2 : array
        Factor multiplied with the output of `rfft2`.
    hty : array
        Precomputed term subtracted before the final scaling.

    Returns
    -------
    array
        The gradient, with shape `x.shape`.
    """
    spectrum = np.fft.rfft2(x * w)
    hth_x = w * np.fft.irfft2(h2 * spectrum, s=x.shape)
    return 2 * (hth_x - hty)
|
|
|
|
|
#%%\ |
|
# Scipy version of H, H^T and the gradients coded by hand
|
def sp_H(x, w, h):
    """Apply the forward operator H to `x` with `scipy.fft`.

    `x` is weighted elementwise by `w`, transformed with a real 2-D FFT,
    multiplied by `h` in the Fourier domain, and transformed back with the
    inverse real 2-D FFT. Both transforms are called with `workers=-1`.

    Parameters
    ----------
    x : array
        Input array; its shape is also the output shape (`s=x.shape`).
    w : array
        Elementwise weights applied to `x` before the transform.
    h : array
        Factor multiplied with the output of `rfft2`.

    Returns
    -------
    array
        The filtered array, with shape `x.shape`.
    """
    spectrum = sp.fft.rfft2(x * w, workers=-1)
    filtered = h * spectrum
    return sp.fft.irfft2(filtered, s=x.shape, workers=-1)
|
|
|
|
|
def sp_Ht(x, w, h):
    """Apply the adjoint operator H^T to `x` with `scipy.fft`.

    `x` is transformed with a real 2-D FFT, multiplied by the complex
    conjugate of `h`, transformed back with the inverse real 2-D FFT, and
    finally weighted elementwise by `w`. Both transforms are called with
    `workers=-1`.

    Parameters
    ----------
    x : array
        Input array; its shape is also the output shape (`s=x.shape`).
    w : array
        Elementwise weights applied after the inverse transform.
    h : array
        Factor whose conjugate multiplies the output of `rfft2`.

    Returns
    -------
    array
        The back-projected array, with shape `x.shape`.
    """
    spectrum = sp.fft.rfft2(x, workers=-1)
    back_projected = sp.fft.irfft2(h.conj() * spectrum, s=x.shape, workers=-1)
    return w * back_projected
|
|
|
|
|
def sp_hand(x, w, h, y):
    """Hand-coded gradient 2 H^T (H x - y), built on `sp_H` and `sp_Ht`.

    Parameters
    ----------
    x, w, h : array
        Forwarded to `sp_H`; `w` and `h` are also forwarded to `sp_Ht`.
    y : array
        Data subtracted from `sp_H(x, w, h)`.

    Returns
    -------
    array
        Twice the adjoint operator applied to the residual.
    """
    residual = sp_H(x, w, h) - y
    return 2 * sp_Ht(residual, w, h)
|
|
|
|
|
def sp_hand_opt(x, w, h2, hty):
    """Optimized hand-coded gradient with `scipy.fft`.

    Evaluates 2 * (w * irfft2(h2 * rfft2(x * w)) - hty) with a single
    FFT / inverse FFT pair, both called with `workers=-1`.

    Parameters
    ----------
    x : array
        Point where the gradient is evaluated; its shape is the output shape.
    w : array
        Elementwise weights, applied before and after the transforms.
    h2 : array
        Factor multiplied with the output of `rfft2`.
    hty : array
        Precomputed term subtracted before the final scaling.

    Returns
    -------
    array
        The gradient, with shape `x.shape`.
    """
    spectrum = sp.fft.rfft2(x * w, workers=-1)
    hth_x = w * sp.fft.irfft2(h2 * spectrum, s=x.shape, workers=-1)
    return 2 * (hth_x - hty)
|
|
|
|
|
#%%\ |
|
# Problem size shared by every benchmark.
shape = (512, 512)

# NumPy / SciPy inputs, stored in single precision (float32 / complex64).
x_np = np.ones(shape, dtype=np.float32)
y_np = np.random.standard_normal(shape).astype(np.float32)
w_np = np.random.standard_normal(shape).astype(np.float32)
# Frequency response of a 5x5 box of ones, zero-padded to `shape`.
h_np = np.fft.rfft2(np.ones((5, 5)), shape).astype(np.complex64)
# Terms precomputed for the optimized gradient: |Lambda|^2 and b = H^T y.
h2_np = np.abs(h_np) ** 2
# b is the adjoint applied to the data `y`. It was previously computed from
# `x_np`, which made the optimized gradient differ from 2 H^T (H x - y).
hty_np = sp_Ht(y_np, w_np, h_np)

# The same inputs as JAX arrays.
x = jnp.array(x_np)
y = jnp.array(y_np)
w = jnp.array(w_np)
h = jnp.array(h_np)
h2 = jnp.abs(h) ** 2
# Same fix as above: b = H^T y, not H^T x.
hty = jax_Ht(y, w, h)
|
|
|
|
|
#%% Auto. diff |
|
# Gradient of the criterion obtained by automatic differentiation; with the
# default arguments, JAX differentiates with respect to the first argument.
jax_auto = jax.grad(jax_crit)

# Sanity check: the autodiff gradient must agree with the hand-coded one.
print(
    "Auto allclose to Hand ? -> ",
    np.allclose(
        jax_auto(x, w, h, y),
        jax_hand(x, w, h, y),
        atol=1e-6,
        rtol=1e-3,
    ),
)
|
|
|
#%% Jit |
|
# Jitted counterparts of the three JAX gradient implementations:
# hand-coded, hand-coded optimized, and automatic differentiation.
jit_jax_hand = jax.jit(jax_hand)
jit_jax_hand_opt = jax.jit(jax_hand_opt)
jit_jax_auto = jax.jit(jax_auto)
|
|
|
#%% Timer |
|
# Warm-up: call every jitted function once so that compilation happens
# before the timed section. Otherwise the compilation lands in %timeit's
# first calibration run and can skew the number of loops it selects.
jit_jax_hand(x, w, h, y).block_until_ready()
jit_jax_auto(x, w, h, y).block_until_ready()
jit_jax_hand_opt(x, w, h2, hty).block_until_ready()

# JAX dispatches work asynchronously: a call may return before the result is
# computed. Every JAX statement therefore ends with `.block_until_ready()`
# so that %timeit measures the computation and not only its dispatch.
# The `-o` flag makes %timeit return its TimeitResult object.
print("jit jax hand")
T_jit_jax_hand = ipython.run_line_magic(
    "timeit", "-o jit_jax_hand(x, w, h, y).block_until_ready()"
)
print("jax hand")
T_jax_hand = ipython.run_line_magic(
    "timeit", "-o jax_hand(x, w, h, y).block_until_ready()"
)
print("jit jax auto")
T_jit_jax_auto = ipython.run_line_magic(
    "timeit", "-o jit_jax_auto(x, w, h, y).block_until_ready()"
)
print("jax auto")
T_jax_auto = ipython.run_line_magic(
    "timeit", "-o jax_auto(x, w, h, y).block_until_ready()"
)

# NumPy and SciPy compute synchronously, so no extra synchronization is needed.
print("numpy hand")
T_np_hand = ipython.run_line_magic("timeit", "-o np_hand(x_np, w_np, h_np, y_np)")
print("scipy hand")
T_sp_hand = ipython.run_line_magic("timeit", "-o sp_hand(x_np, w_np, h_np, y_np)")

# Optimized ("brain") variants, fed with the precomputed h2 and hty terms.
print("scipy hand opt")
T_sp_hand_opt = ipython.run_line_magic(
    "timeit", "-o sp_hand_opt(x_np, w_np, h2_np, hty_np)"
)
print("numpy hand opt")
T_np_hand_opt = ipython.run_line_magic(
    "timeit", "-o np_hand_opt(x_np, w_np, h2_np, hty_np)"
)
print("jit jax hand opt")
T_jit_jax_hand_opt = ipython.run_line_magic(
    "timeit", "-o jit_jax_hand_opt(x, w, h2, hty).block_until_ready()"
)
|
|
|
|
|
#%%\ |
|
# The figure |
|
|
|
# One (timing result, library, variant) entry per benchmark.
timer = [
    (T_np_hand, "numpy", "naive"),
    (T_jax_auto, "jax", "auto"),
    (T_jax_hand, "jax", "naive"),
    (T_jit_jax_auto, "jaxjit", "auto"),
    (T_jit_jax_hand, "jaxjit", "naive"),
    (T_sp_hand, "scipy", "naive"),
    (T_sp_hand_opt, "scipy", "brain"),
    (T_np_hand_opt, "numpy", "brain"),
    (T_jit_jax_hand_opt, "jaxjit", "brain"),
]
# Slowest first, so the fastest benchmark ends up last.
timer.sort(key=lambda entry: entry[0].average, reverse=True)
|
|
|
# Visual encoding of the bars: the library selects the hatch pattern and the
# code variant selects the fill colour.
colorcode = dict(
    # library -> hatch pattern
    numpy="+",
    jax="/",
    jaxjit="//",
    scipy="*",
    # code variant -> colour of the active colour cycle
    naive="C1",
    auto="C2",
    brain="C3",
)
|
|
|
# Mean and standard deviation of each benchmark, in the order of `timer`.
average = [(result.average, library, variant) for result, library, variant in timer]
std = [(result.stdev, library, variant) for result, library, variant in timer]

# `timer` is sorted slowest first, so the last entry is the fastest one.
fastest = average[-1]
# Bar annotations: slowdown factor relative to the fastest benchmark.
text_above = [f"x {mean / fastest[0]:.1f}" for mean, _, _ in average]
|
|
|
# Bar chart of the mean timings with the standard deviation as error bars.
fig = plt.figure(1, figsize=(1.6 * 8, 8))
plt.clf()
ax = fig.add_subplot(111)
ax.grid(visible=False)
bar = ax.bar(
    [f"{lib} {code}" for _, lib, code in average],
    [mean for mean, _, _ in average],
    zorder=3,
    # hatch encodes the library, colour encodes the code variant
    hatch=[colorcode[lib] for _, lib, _ in average],
    color=[colorcode[code] for _, _, code in average],
    yerr=[dev for dev, _, _ in std],
)
# Slowdown factor relative to the fastest benchmark, above each bar.
ax.bar_label(bar, text_above)
# Title fixes:
# - %timeit is invoked with only `-o`, so it picks the number of runs itself;
#   the title no longer claims a fixed "100 runs".
# - H is the operator F^* Lambda F W; the stray trailing "x" is removed.
ax.set_title(
    "CPU Time [s] (lower is better, average over the timeit runs) for evaluation of\n"
    + r"$\nabla J(x) = 2 H^T (H \cdot x - y)$ "
    + "\n when \n"
    + r"$H = F^* \Lambda F W$ ($F$ is DFT, $\Lambda$ complex diag and $W$ diag)."
    + f"\n Version: JAX {jax.__version__}, Numpy {np.__version__}, Scipy {sp.__version__}"
    + "\n F. Orieux (L2S) francois.orieux@universite-paris-saclay.fr"
)
|
# plt.text( |
|
# 0.15, |
|
# 0.75, |
|
# "Various test\n" |
|
# "- jax, numpy and scipy are library to compute the Fourier transform\n" |
|
# "- with 'auto', the gradient is determined by JAX from " |
|
# + r"$J(x) = ||H x - y||^2$ " |
|
# + "\n" |
|
# "- with 'hand', the gradient is determined by hand\n" |
|
# "- with 'jit', JAX code is jitted\n" |
|
# "- with 'opt', we use the mathematical expression to avoid unecessary FFT\n" |
|
# "\n" |
|
# transform=ax.transAxes, |
|
# fontdict={"size": 14}, |
|
# bbox=dict( |
|
# boxstyle="round", |
|
# facecolor="white", |
|
# # ec=(1.0, 0.5, 0.5), |
|
# # fc=(1.0, 0.8, 0.8), |
|
# ), |
|
# ) |
|
# for rect, accel in zip(bar, text_above): |
|
# height = rect.get_height() |
|
# plt.text( |
|
# rect.get_x() + rect.get_width() / 2.0, |
|
# height, |
|
# accel, |
|
# ha="center", |
|
# va="top", |
|
# color="white", |
|
# fontdict={"size": 14, "weight": "bold"}, |
|
# ) |
|
# Legend entries for the code variants: plain patches in the variant colour.
leglib1, leglib2, leglib3 = (
    mpatches.Patch(facecolor=colorcode[variant], label=label)
    for variant, label in (
        ("naive", "Naive"),
        ("auto", "Autodiff"),
        ("brain", "Brain"),
    )
)

# Legend entries for the libraries: white patches carrying the library hatch.
legcode1, legcode2, legcode3, legcode4 = (
    mpatches.Patch(
        facecolor="white",
        edgecolor="black",
        hatch=colorcode[library],
        label=label,
    )
    for library, label in (
        ("numpy", "Numpy"),
        ("scipy", "Scipy"),
        ("jax", "JAX"),
        ("jaxjit", "JAX + jit"),
    )
)

# Libraries first, then code variants, laid out in two columns.
legend_handles = [legcode1, legcode2, legcode3, legcode4, leglib1, leglib2, leglib3]
ax.legend(
    handles=legend_handles,
    ncols=2,
    loc="upper right",
    handleheight=0.7 * 4,
    handlelength=2.0 * 4,
)

# Write the figure to disk in the current working directory.
plt.tight_layout()
plt.savefig("bench_jax_numpy_scipy.jpg")