@mwitiderrick
Created July 11, 2022 11:31
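import jax
import jax.numpy as jnp

# Assumption: the gist never defines `key`; seed 0 is an arbitrary choice.
key = jax.random.PRNGKey(0)
mat = jax.random.normal(key, (150, 100))
batched_x = jax.random.normal(key, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)

@jax.jit
def vmap_batched_apply_matrix(v_batched):
    return jax.vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
# IPython/Jupyter magic; block_until_ready() waits for JAX's async dispatch to finish.
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()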
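
To sanity-check what the vmapped function computes, the sketch below compares it against a hand-written batched matrix product. This is not part of the original gist. The seed, the split keys, and the names `vmapped` and `manual` are assumptions made for illustration, and the snippet runs as plain Python without the `%timeit` magic.

```python
import jax
import jax.numpy as jnp

# Assumption: seed and split keys are illustrative; the gist does not define `key`.
key = jax.random.PRNGKey(0)
mat_key, x_key = jax.random.split(key)
mat = jax.random.normal(mat_key, (150, 100))
batched_x = jax.random.normal(x_key, (10, 100))

def apply_matrix(v):
    # Single example: (150, 100) @ (100,) -> (150,)
    return jnp.dot(mat, v)

# By default vmap maps over axis 0 of the input and stacks results on axis 0.
vmapped = jax.vmap(apply_matrix)(batched_x)

# Hand-written batched equivalent: one matmul over the whole batch.
manual = jnp.dot(batched_x, mat.T)

print(vmapped.shape)  # (10, 150)
print(float(jnp.max(jnp.abs(vmapped - manual))))  # small float32 rounding difference
```

Both paths should produce one 150-element output row per input row, differing only by floating-point rounding.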