@HamletWantToCode
Last active September 21, 2019 17:55
Vectorize a kernel function in JAX #JAX
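```python
from jax import vmap

def vectorize(kernel):
    """Lift kernel(a, x, y), defined on single points x and y, to a function
    that takes batches X and Y and returns a (len(X), len(Y)) matrix."""
    # Map over axis 0 of the second argument; first and third are not mapped.
    mv_kernel = vmap(kernel, (None, 0, None), 0)
    # Map over axis 0 of the third argument and stack results along axis 1.
    mm_kernel = vmap(mv_kernel, (None, None, 0), 1)

    def wraps(*args):
        return mm_kernel(*args)
    return wraps
```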
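A minimal usage sketch, assuming the kernel takes an unmapped first argument (here a length scale) followed by two single points. The `rbf` kernel, the length scale `ell`, and the sample arrays `X` and `Y` are hypothetical, chosen only for illustration; the vectorized result is checked against an explicit double loop.

```python
import jax.numpy as jnp
from jax import vmap

def vectorize(kernel):
    # Same construction as the gist: inner vmap over x, outer vmap over y.
    mv_kernel = vmap(kernel, (None, 0, None), 0)
    mm_kernel = vmap(mv_kernel, (None, None, 0), 1)

    def wraps(*args):
        return mm_kernel(*args)
    return wraps

# Hypothetical example kernel (not from the gist): squared-exponential kernel
# on one pair of points, with length scale `ell` as the unmapped argument.
def rbf(ell, x, y):
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * ell ** 2))

gram = vectorize(rbf)

X = jnp.arange(6.0).reshape(3, 2)  # 3 points in R^2
Y = jnp.arange(8.0).reshape(4, 2)  # 4 points in R^2
K = gram(1.0, X, Y)
print(K.shape)  # (3, 4)

# Reference: evaluate the kernel pair by pair with Python loops.
K_loop = jnp.array([[rbf(1.0, x, y) for y in Y] for x in X])
print(jnp.allclose(K, K_loop))
```

Because the outer `vmap` uses `out_axes=1`, rows of `K` index `X` and columns index `Y`; with `out_axes=0` the result would instead be the transpose, of shape (4, 3).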