@CamDavidsonPilon
Last active August 28, 2019 20:39
initial thoughts on jax
  • vmap is a very general function, but like einsum, I end up trying a bunch of permutations before it works the way I want. More documentation and examples, or higher-order functions, would be helpful.

  • debugging is much more difficult than in autograd. Ex: tracking down NaNs is harder, and inspecting variables in jax is not possible?

  • It's as fast as advertised. jit is pretty impressive.

  • stax is a neat little sublibrary. I'd like to see more development there, but I understand the possible scope creep.

  • I love the idea of riding the upgrade train of Jax, XLA and GPUs.

  • I see a lot¹ of examples using internal Jax APIs and my code doesn't, so that gives me pause. Am I missing something, or are more higher-order functions needed?

  • 🆕 Is vectorize the right API? I'm not sure. Perhaps some common patterns could be extracted into functions. I had a lot of trouble trying to reproduce elementwise_grad with the grad + vmap primitives. It was much easier in autograd. After reading https://nbviewer.jupyter.org/github/google/jax/blob/master/notebooks/gufuncs.ipynb, this quote stood out:

    Instead of needing to think about broadcasting while writing the entire function, we can write the function assuming the input is always a vector.

    This actually makes things clearer, and rereading the vectorize semantics afterwards made more sense. Of the two, I think vmap's API is the tougher one to decipher.

  • 🆕 Examples of internal APIs:

¹ On further inspection, maybe I was confusing examples with something else. The reality is better than I imagined.
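
For reference, here is a minimal sketch of the elementwise_grad pattern mentioned above, built from grad + vmap. It assumes the function being differentiated maps a scalar to a scalar; the helper name `elementwise_grad` and the test function `f` are illustrative only and are not part of jax's API.

```python
import jax
import jax.numpy as jnp


def elementwise_grad(f):
    # Hypothetical helper (not a jax API): assumes f maps a scalar to a scalar.
    # grad differentiates f at a single point; vmap maps that over the leading axis.
    return jax.vmap(jax.grad(f))


def f(x):
    # Illustrative scalar function.
    return jnp.sin(x) * x ** 2


xs = jnp.linspace(0.0, 2.0, 5)
dfdx = elementwise_grad(f)(xs)

# Analytic derivative, for comparison: x^2 cos(x) + 2 x sin(x)
expected = jnp.cos(xs) * xs ** 2 + 2.0 * xs * jnp.sin(xs)
```

Writing `f` for a single scalar and letting vmap handle the batch axis is the same idea as the quote above: the function body never has to think about broadcasting.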

@mattjj

mattjj commented Aug 28, 2019

Thanks so much for writing this up! It's extremely helpful to us.

One quick thought: jax.ops isn't meant to be internal. There's an explanation of what it is at its documentation page. Any advice for how we can be clearer about what's internal vs external? (Our current best answer is that everything documented at jax.readthedocs.io is external, though I'm not sure if that's a good enough answer!)
