@ingmarschuster
Last active June 30, 2021 13:47
Testing JAX: vmap vs. manually vectorized code. With jax.jit, the speed difference vanishes.
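Both benchmarks compare a vmap version, which maps a function of single vectors over the rows of its inputs, with a manually vectorized version that works on whole matrices. For rows $a_i$ of $A$ and $b_j$ of $B$, the manual squared-distance code `se` expands the square so that all cross terms come from one matrix product $AB^\top$. The vmapped softmax `vsm` normalizes each row $x_i$ of its input on its own:

$$
D_{ij} = \lVert a_i - b_j \rVert^2 = \lVert a_i \rVert^2 + \lVert b_j \rVert^2 - 2\, a_i^\top b_j,
\qquad
\operatorname{softmax}(x_i)_k = \frac{e^{x_{ik}}}{\sum_{l} e^{x_{il}}}
$$

In `se`, `einsum('ij,ij->i', a, a)` gives the vector of all $\lVert a_i \rVert^2$ and `a @ b.T` gives every $a_i^\top b_j$ at once; broadcasting then assembles $D$.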

Without jax.jit:

Squared euclidean distance:

  • vmap 1.17 ms ± 11.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
  • vectorized manually 435 µs ± 473 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Softmax:

  • vmap 925 µs ± 1.64 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
  • vectorized manually 273 µs ± 408 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

With jax.jit:

Squared euclidean distance:

  • vmap 2.12 µs ± 6.95 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
  • vectorized manually 2.16 µs ± 7.02 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Softmax:

  • vmap 2.09 µs ± 4.39 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
  • vectorized manually 2.1 µs ± 14.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
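JAX dispatches work asynchronously: a call returns once the computation is queued, so a `%timeit` that never waits on the result can end up timing mostly dispatch overhead, and the ~2 µs jitted figures above may understate the actual compute time. A minimal sketch of a more careful measurement follows. The helper `time_call` is hypothetical, and the `(100, 10)` input is an assumption because the gist does not show how `a` was built. Each result is waited on with `block_until_ready()`, and one warm-up call keeps compilation out of the average:

```python
import time

import jax
import jax.numpy as jnp


def time_call(f, *args, n=100):
    """Hypothetical helper: average seconds per call of f(*args), waiting for results."""
    f(*args).block_until_ready()  # warm-up call: compiles f if it is jitted
    start = time.perf_counter()
    for _ in range(n):
        f(*args).block_until_ready()  # wait for the computation, not just the dispatch
    return (time.perf_counter() - start) / n


def vse(x, y):
    # Pairwise squared distances with nested vmap (outer: rows of x, inner: rows of y).
    return jax.vmap(lambda x1: jax.vmap(lambda y1: jnp.sum((x1 - y1) ** 2))(y))(x)


def se(x, y):
    # Same matrix, vectorized manually with the expanded square.
    x2 = jnp.einsum("ij,ij->i", x, x)
    y2 = jnp.einsum("ij,ij->i", y, y)
    return x2[:, None] + y2 - 2 * x @ y.T


# Hypothetical input size: the gist does not show how `a` was created.
a = jax.random.normal(jax.random.PRNGKey(0), (100, 10))

for name, f in [("vmap", jax.jit(vse)), ("vectorized manually", jax.jit(se))]:
    print(f"{name}: {time_call(f, a, a) * 1e6:.1f} µs per call")
```

Excluding the warm-up call is a deliberate choice: it measures steady-state execution of the compiled function. Timing the first call on its own would show the one-off compilation cost instead.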
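import jax.numpy as np
from jax import jit, vmap, pmap, random

# The benchmarks use the IPython %timeit magic: run this in Jupyter or an IPython shell.

def bkern(kernel, func=vmap):
    """Lift kernel(x1, y1) on single vectors to a matrix K[i, j] = kernel(x[i], y[j])."""
    # Outer map runs over the rows of x, inner map over the rows of y,
    # so the result has shape (len(x), len(y)) like se(x, y).
    return lambda x, y: func(lambda x1: func(lambda y1: kernel(x1, y1))(y))(x)

def veckern(kernel):
    return bkern(kernel, vmap)

def parkern(kernel):
    # pmap maps over devices, so this only works when enough devices are available.
    return bkern(kernel, pmap)

def se(a, b):
    """Squared Euclidean distances via ||a_i||^2 + ||b_j||^2 - 2 a_i . b_j."""
    a_sumrows = np.einsum('ij,ij->i', a, a)  # squared norm of each row of a
    b_sumrows = np.einsum('ij,ij->i', b, b)  # squared norm of each row of b
    return a_sumrows[:, np.newaxis] + b_sumrows - 2 * a @ b.T

vse = veckern(lambda a, b: np.sum((a - b) ** 2))

def sm(x: np.ndarray, ax=-1) -> np.ndarray:
    """Vector-wise softmax along `ax`; the default (last axis) normalizes rows like vsm."""
    return np.exp(x) / np.sum(np.exp(x), axis=ax, keepdims=True)

vsm = vmap(lambda x: np.exp(x) / np.sum(np.exp(x)))  # vmap maps over rows (axis 0)

# Hypothetical input: the gist does not show how `a` was created, so this
# shape is an assumption and will not reproduce the timings above exactly.
a = random.normal(random.PRNGKey(0), (100, 10))

print("Squared euclidean distance:")
print("vmap ", end="")
%timeit vse(a, a)
print("vectorized manually ", end="")
%timeit se(a, a)
print("Softmax:")
print("vmap ", end="")
%timeit vsm(a)
print("vectorized manually ", end="")
%timeit sm(a)

# Compile all four functions with jit and repeat the benchmarks.
vse, se, vsm, sm = [jit(f) for f in (vse, se, vsm, sm)]
print("Squared euclidean distance:")
print("vmap ", end="")
%timeit vse(a, a)
print("vectorized manually ", end="")
%timeit se(a, a)
print("Softmax:")
print("vmap ", end="")
%timeit vsm(a)
print("vectorized manually ", end="")
%timeit sm(a)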
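Speed comparisons are only meaningful when both versions compute the same result. Here is a small sketch of a check, using hand-picked toy inputs `x` and `y` and hypothetical helper names `sm_rows` and `sm_cols` (none of these come from the gist). Nested vmap with the outer map over `x` yields a `(len(x), len(y))` matrix that agrees with the expanded-square formula. A row-wise softmax needs `axis=-1` with `keepdims=True` to match `vsm`, since summing over `axis=0` normalizes columns instead:

```python
import jax
import jax.numpy as jnp


def vse(x, y):
    # K[i, j] = sum((x[i] - y[j]) ** 2): outer vmap over rows of x, inner over rows of y.
    return jax.vmap(lambda x1: jax.vmap(lambda y1: jnp.sum((x1 - y1) ** 2))(y))(x)


def se(x, y):
    # Same matrix from ||x_i||^2 + ||y_j||^2 - 2 x_i . y_j.
    x2 = jnp.einsum("ij,ij->i", x, x)
    y2 = jnp.einsum("ij,ij->i", y, y)
    return x2[:, None] + y2 - 2 * x @ y.T


# vmap maps over axis 0, so each row is normalized on its own.
vsm = jax.vmap(lambda v: jnp.exp(v) / jnp.sum(jnp.exp(v)))


def sm_rows(x):
    # Row-wise softmax: reduce over the last axis and keep it for broadcasting.
    return jnp.exp(x) / jnp.sum(jnp.exp(x), axis=-1, keepdims=True)


def sm_cols(x):
    # Reducing over axis 0 instead normalizes each column.
    return jnp.exp(x) / jnp.sum(jnp.exp(x), axis=0)


# Toy inputs chosen so the distances are small integers.
x = jnp.array([[0.0, 1.0], [2.0, 3.0], [1.0, -1.0]])
y = jnp.array([[1.0, 1.0], [0.0, 2.0]])

print(vse(x, y))  # values [[1, 1], [5, 5], [4, 10]], shape (3, 2)
print(bool(jnp.allclose(vse(x, y), se(x, y))))  # True
print(bool(jnp.allclose(vsm(x), sm_rows(x))))  # True
print(bool(jnp.allclose(vsm(x), sm_cols(x))))  # False
```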