Ingmar Schuster ingmarschuster

ingmarschuster / results.md
Last active June 30, 2021 13:47
Testing JAX: vmap vs. manually vectorized code. With jax.jit, the speed difference vanishes.
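The gist does not show how the timings were taken. The "mean ± std. dev. of 7 runs, 1000 loops each" format matches IPython's `%timeit`. Below is a minimal sketch of an equivalent plain-Python harness, under stated assumptions. `bench` and the toy row-wise `sq_norm` function are hypothetical illustrations, not the benchmarked code. Two JAX details matter. First, JAX dispatches work asynchronously, so each timed call must wait on `block_until_ready()`. Second, the first call to a jitted function includes compilation, so it is done once as a warm-up before timing starts.

```python
import statistics
import timeit

import jax
import jax.numpy as jnp


def bench(fn, *args, number=1000, repeat=7):
    """Return (mean, std) of seconds per call, in the style of IPython's %timeit."""
    # Warm-up call: for a jitted fn this triggers tracing and compilation,
    # which should not be counted in the timings.
    fn(*args).block_until_ready()
    # JAX dispatches work asynchronously; block_until_ready() makes each
    # timed call wait until the result has actually been computed.
    runs = timeit.repeat(lambda: fn(*args).block_until_ready(),
                         number=number, repeat=repeat)
    per_call = [t / number for t in runs]
    return statistics.mean(per_call), statistics.pstdev(per_call)


# Hypothetical toy function (not from the gist): squared norm of one row.
def sq_norm(x):
    return jnp.sum(x ** 2)


sq_norm_vmap = jax.vmap(sq_norm)          # vmap maps sq_norm over the rows


def sq_norm_manual(X):
    return jnp.sum(X ** 2, axis=1)        # hand-vectorized over the rows


X = jax.random.normal(jax.random.PRNGKey(0), (1000, 50))
variants = {
    "vmap": sq_norm_vmap,
    "manual": sq_norm_manual,
    "vmap + jit": jax.jit(sq_norm_vmap),
    "manual + jit": jax.jit(sq_norm_manual),
}
for name, fn in variants.items():
    mean, std = bench(fn, X, number=100)
    print(f"{name}: {mean * 1e6:.1f} µs ± {std * 1e6:.1f} µs per loop")
```

Absolute numbers depend on hardware, array sizes, and the JAX version. This harness only shows how a like-for-like comparison with and without `jax.jit` could be set up.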

Without jax.jit:

Squared euclidean distance:

  • vmap: 1.17 ms ± 11.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
  • manually vectorized: 435 µs ± 473 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
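The benchmarked distance code is not part of the gist. The sketch below shows what the two variants might look like. It assumes pairwise squared Euclidean distances between the rows of `X` (n×d) and `Y` (m×d), and the function names are hypothetical. The `vmap` version writes the distance for a single pair of vectors and lifts it to all pairs with two nested `jax.vmap` calls. The manual version gets the same (n, m) result through broadcasting.

```python
import jax
import jax.numpy as jnp

# Hypothetical reconstruction; the gist's benchmark code is not shown.


def sq_dist(x, y):
    """Squared Euclidean distance between two vectors of shape (d,)."""
    return jnp.sum((x - y) ** 2)


# Inner vmap: one x against every row of Y -> shape (m,)
# Outer vmap: every row of X against Y     -> shape (n, m)
sq_dist_vmap = jax.vmap(jax.vmap(sq_dist, in_axes=(None, 0)), in_axes=(0, None))


def sq_dist_manual(X, Y):
    """Manually vectorized: (n, 1, d) - (1, m, d) broadcasts to (n, m, d)."""
    return jnp.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (100, 10))
Y = jax.random.normal(key_y, (200, 10))

D_vmap = sq_dist_vmap(X, Y)
D_manual = sq_dist_manual(X, Y)
print(D_vmap.shape)  # (100, 200)
print(jnp.allclose(D_vmap, D_manual, atol=1e-5))
```

The broadcasting variant materializes an (n, m, d) intermediate array. For large inputs, an alternative is the expansion ‖x‖² + ‖y‖² − 2x·y, which needs only a matrix product. That trades memory for some floating-point cancellation error.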

Softmax:

  • vmap: 925 µs ± 1.64 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
  • manually vectorized: 273 µs ± 408 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
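For softmax, the gist likewise gives timings only. The sketch below shows one plausible pair of variants. It assumes a row-wise softmax over a matrix of shape (n, d), and the function names are hypothetical. Both variants subtract the maximum before exponentiating for numerical stability. The `vmap` version is written for a single vector. The manual version uses `keepdims=True` so that the row-wise reductions broadcast back against `X`. As a sanity check, the result is also compared with JAX's built-in `jax.nn.softmax`.

```python
import jax
import jax.numpy as jnp

# Hypothetical reconstruction; the gist's benchmark code is not shown.


def softmax(x):
    """Numerically stable softmax of a single vector of shape (d,)."""
    e = jnp.exp(x - jnp.max(x))       # subtract max to avoid overflow in exp
    return e / jnp.sum(e)


softmax_vmap = jax.vmap(softmax)      # applies softmax to each row


def softmax_manual(X):
    """Row-wise softmax of X with shape (n, d), vectorized by hand."""
    E = jnp.exp(X - jnp.max(X, axis=1, keepdims=True))
    return E / jnp.sum(E, axis=1, keepdims=True)


X = jax.random.normal(jax.random.PRNGKey(0), (100, 10))
print(jnp.allclose(softmax_vmap(X), softmax_manual(X), atol=1e-6))
print(jnp.allclose(softmax_manual(X), jax.nn.softmax(X, axis=1), atol=1e-6))
```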

With jax.jit: