Thank you @jpivarski -- this notebook has been extremely helpful in getting me started using Jax for fractal computation on a GPU.
However, I just wanted to let you (and any other readers) know that JAX unrolls the Python loop at trace time, which converts the 20-iteration loop into 20 consecutive copies of the loop body in the compiled code.
This isn't noticeable with only 20 iterations, but it quickly becomes unmanageable with a more realistic fractal calculation of, say, 1,000 iterations -- which turns out to take something like 60 seconds to compile on its first run.
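For reference, the pattern in question is roughly this (not the notebook's exact code -- just a minimal sketch assuming a Mandelbrot-style update `z → z² + c`; names like `make_kernel` are mine):

```python
import jax
import jax.numpy as jnp

# Sketch of the problematic pattern: a Python for-loop inside a jitted function.
# JAX traces the function, so the loop body is unrolled into `n` consecutive
# copies in the compiled program; compile time grows roughly linearly with `n`.
def make_kernel(n):
    @jax.jit
    def kernel(c):
        z = jnp.zeros_like(c)
        for _ in range(n):  # unrolled at trace time
            # Freeze points once they have escaped (|z| > 2)
            z = jnp.where(jnp.abs(z) <= 2.0, z * z + c, z)
        return jnp.abs(z)
    return kernel

kernel20 = make_kernel(20)      # compiles quickly
# kernel1000 = make_kernel(1000)  # first call takes far longer to compile
```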
After asking elsewhere for help, I learned that the solution is to use `lax.fori_loop` instead of a traditional Python loop, which ensures that a genuine loop is actually compiled and executed.
And while this compiled almost instantly, I'll note that it came at a post-compilation performance cost -- instead of the 1,000 iterations executing in ~80 ms, they took ~1,000 ms.
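A minimal version of the `lax.fori_loop` approach might look like this (again a sketch under the same Mandelbrot-style assumption; the escape-count state layout is my own choice):

```python
import jax
import jax.numpy as jnp
from functools import partial

# lax.fori_loop compiles a real loop, so compile time stays flat
# no matter how many iterations are requested.
@partial(jax.jit, static_argnums=1)
def mandelbrot(c, num_iterations):
    def body(_, state):
        z, count = state
        active = jnp.abs(z) <= 2.0
        z = jnp.where(active, z * z + c, z)        # freeze escaped points
        count = count + active.astype(jnp.int32)   # iterations before escape
        return z, count

    z0 = jnp.zeros_like(c)
    count0 = jnp.zeros(c.shape, dtype=jnp.int32)
    _, count = jax.lax.fori_loop(0, num_iterations, body, (z0, count0))
    return count
```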
However, I then managed to "get the best of both worlds" by combining the two methods: an outer `lax.fori_loop` of 10 steps whose body performs 100 iterations of a traditional Python loop (which gets unrolled for performance).
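That hybrid might be sketched as follows (same Mandelbrot-style assumption; `outer_steps`, `inner_steps`, and the function name are illustrative):

```python
import jax
import jax.numpy as jnp
from functools import partial

# Hybrid sketch: the outer lax.fori_loop compiles as a genuine loop
# (keeping compile time low), while the inner Python for-loop is unrolled
# at trace time (keeping per-iteration overhead low).
@partial(jax.jit, static_argnums=(1, 2))
def mandelbrot_hybrid(c, outer_steps, inner_steps):
    def outer_body(_, state):
        z, count = state
        for _ in range(inner_steps):  # unrolled at trace time
            active = jnp.abs(z) <= 2.0
            z = jnp.where(active, z * z + c, z)       # freeze escaped points
            count = count + active.astype(jnp.int32)  # iterations before escape
        return z, count

    z0 = jnp.zeros_like(c)
    count0 = jnp.zeros(c.shape, dtype=jnp.int32)
    _, count = jax.lax.fori_loop(0, outer_steps, outer_body, (z0, count0))
    return count  # total iterations run is outer_steps * inner_steps
```

With `outer_steps=10` and `inner_steps=100` this performs the same 1,000 iterations, trading a 100-body unroll (cheap to compile) against a short compiled loop.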
I'm sure you probably don't want to update this guide to compare 1,000 iterations instead of 20, but I thought you might want to know that this adds a wrinkle at least to the Jax solution -- I don't know whether the other GPU solutions behave similarly.
Thanks for pointing this out! This is worth knowing and I'll mention it whenever I show this example.
This is wonderful, thanks for sharing!