@jpivarski
Last active July 14, 2024 17:22
@asafhaddad

This is wonderful, thanks for sharing!

@mjbaldwin

Thank you @jpivarski -- this notebook has been extremely helpful in getting me started using Jax for fractal computation on a GPU.

However, I just wanted to let you know (and any others) that the JAX routine ends up unrolling the loop: because `jax.jit` traces the Python `for` loop, the 20 iterations are compiled as 20 consecutive copies of the loop body rather than as a single loop.

This isn't noticeable when only performing 20 iterations, but quickly becomes unmanageable with a more realistic fractal calculation of, say, 1,000 iterations -- which turns out to take something like 60 seconds to compile for its first run.
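One way to see the unrolling is to count the operations JAX records when it traces a Python loop. The following is only a minimal sketch, not the notebook's code: `iterate` and `traced_op_count` are hypothetical helper names, and the loop body is a bare `z*z + c` update standing in for the real fractal kernel.

```python
import jax
import jax.numpy as jnp


def iterate(z, c, n_iter):
    # A plain Python loop: JAX traces the body once per iteration,
    # so every iteration adds its own operations to the traced program.
    for _ in range(n_iter):
        z = z * z + c
    return z


def traced_op_count(n_iter):
    # Hypothetical helper: trace `iterate` (no compilation) and count
    # how many primitive operations ended up in the jaxpr.
    c = jnp.zeros(4, dtype=jnp.complex64)
    closed = jax.make_jaxpr(lambda z, c: iterate(z, c, n_iter))(c, c)
    return len(closed.jaxpr.eqns)


# The traced program grows with the iteration count.
print(traced_op_count(20), traced_op_count(1000))
```

The traced program is what gets handed to the compiler, so a program that grows with the iteration count is consistent with the long first-run compile time described above.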

After I asked elsewhere for help, I was informed the solution was to use lax.fori_loop instead of a traditional Python loop, which ensures a genuine loop is actually compiled and performed.
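For anyone who wants to try this, here is a rough sketch of the two loop styles side by side. It is an illustration under assumptions, not the notebook's actual kernel: the function names, the `step` helper, and the escape test `|z| <= 2` are all chosen here for the example.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax


def step(c, state):
    # One update: advance z and count the points that have not escaped yet.
    z, counts = state
    z = z * z + c
    counts = counts + (jnp.abs(z) <= 2.0).astype(jnp.int32)
    return z, counts


@functools.partial(jax.jit, static_argnames="n_iter")
def mandelbrot_unrolled(c, n_iter=20):
    state = (jnp.zeros_like(c), jnp.zeros(c.shape, dtype=jnp.int32))
    for _ in range(n_iter):  # Python loop: unrolled during tracing
        state = step(c, state)
    return state[1]


@functools.partial(jax.jit, static_argnames="n_iter")
def mandelbrot_fori(c, n_iter=20):
    state = (jnp.zeros_like(c), jnp.zeros(c.shape, dtype=jnp.int32))
    # lax.fori_loop: the body is traced once and compiled as a real loop.
    state = lax.fori_loop(0, n_iter, lambda i, s: step(c, s), state)
    return state[1]


# Two easy test points: 0 never escapes, 2+2j escapes on the first step.
c = jnp.array([0.0 + 0.0j, 2.0 + 2.0j], dtype=jnp.complex64)
print(mandelbrot_unrolled(c, n_iter=20))
print(mandelbrot_fori(c, n_iter=20))
```

Both versions should produce the same counts; the difference is in how long the first call takes to compile and how fast later calls run, which depends on the hardware and the JAX version.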

And while this compiled almost instantly, I'll note that this came at a post-compilation performance cost -- instead of the 1,000 iterations executing in ~80 ms, it took ~1,000 ms.

However, I then managed to "get the best of both worlds" by combining the two methods: a lax.fori_loop of 10 iterations whose body runs 100 iterations of a traditional Python loop (which is unrolled for performance).
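The combined approach might look something like the sketch below. This is a guess at the structure described, not the commenter's actual code: `mandelbrot_hybrid`, `n_outer`, and `n_inner` are hypothetical names, and the escape test `|z| <= 2` is an assumption made for the example.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax


@functools.partial(jax.jit, static_argnames=("n_outer", "n_inner"))
def mandelbrot_hybrid(c, n_outer=10, n_inner=100):
    def body(i, state):
        z, counts = state
        # Inner Python loop: unrolled into n_inner copies of the update.
        for _ in range(n_inner):
            z = z * z + c
            counts = counts + (jnp.abs(z) <= 2.0).astype(jnp.int32)
        return z, counts

    init = (jnp.zeros_like(c), jnp.zeros(c.shape, dtype=jnp.int32))
    # Outer lax.fori_loop: a real compiled loop that runs the block n_outer times.
    _, counts = lax.fori_loop(0, n_outer, body, init)
    return counts


# Total iterations = n_outer * n_inner (1,000 with the defaults).
c = jnp.array([0.0 + 0.0j, 2.0 + 2.0j], dtype=jnp.complex64)
print(mandelbrot_hybrid(c, n_outer=2, n_inner=3))
```

The split between `n_outer` and `n_inner` is the tuning knob: a larger `n_inner` means more unrolled code to compile, a larger `n_outer` means more trips through the compiled loop. The best balance is something to measure rather than assume.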

I'm sure you probably don't want to update this guide to compare 1,000 iterations instead of 20, but I thought you might want to know that this adds a wrinkle, at least for JAX; I don't know whether the other GPU solutions are affected as well.

@jpivarski
Author

Thanks for pointing this out! This is worth knowing and I'll mention it whenever I show this example.
