JAX released a persistent compilation cache for TPU VMs! When enabled, the cache writes compiled JAX computations to disk so they don’t have to be re-compiled the next time you start your JAX program. This can save startup time if any of y’all have long compilation times.
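If it helps to picture what "writes compiled computations to disk" means, here is a toy sketch of the general idea in plain Python. It is not JAX's implementation: `DiskCache`, `get_or_build`, and `fake_compile` are hypothetical names made up for illustration, and the real cache keys on much more than a source string.

```python
import hashlib
import os
import pickle
import tempfile


class DiskCache:
    """Toy persistent cache (hypothetical, not JAX's): key string -> pickled value on disk."""

    def __init__(self, cache_dir):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def _path(self, key):
        # Hash the key so any string maps to a safe, fixed-length file name.
        digest = hashlib.sha256(key.encode("utf-8")).hexdigest()
        return os.path.join(self.cache_dir, digest)

    def get_or_build(self, key, build):
        """Return (value, was_cache_hit). Calls build() only on a miss."""
        path = self._path(key)
        if os.path.exists(path):
            with open(path, "rb") as f:
                return pickle.load(f), True
        value = build()
        # Write to a temp file first, then rename, so a crash mid-write
        # does not leave a half-written cache entry behind.
        tmp_path = path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(value, f)
        os.replace(tmp_path, path)
        return value, False


def fake_compile(source):
    # Stand-in for an expensive compilation step (assumption for the demo).
    return source.upper()


if __name__ == "__main__":
    cache_dir = tempfile.mkdtemp()
    # First "program start": nothing on disk yet, so we compile and store.
    _, hit = DiskCache(cache_dir).get_or_build("f(x)", lambda: fake_compile("f(x)"))
    print("first run hit:", hit)
    # Second "program start": a fresh object finds the entry on disk.
    _, hit = DiskCache(cache_dir).get_or_build("f(x)", lambda: fake_compile("f(x)"))
    print("second run hit:", hit)
```

The second lookup uses a brand-new `DiskCache` object on purpose: the entry survives because it lives on disk, which is the same reason the JAX cache can help across program restarts.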
First, upgrade to a jax release that includes the cache (0.2.18 or later):
pip install -U "jax[tpu]>=0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Then use the following to enable the cache in your JAX code:
from jax.experimental.compilation_cache import compilation_cache as cc