JAX released a persistent compilation cache for TPU VMs! When enabled, the cache writes compiled JAX computations to disk so they don’t have to be re-compiled the next time you start your JAX program. This can save startup time if any of y’all have long compilation times.
First, upgrade to jax 0.2.18 or later:
pip install -U "jax[tpu]>=0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Then use the following to enable the cache in your JAX code:
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache("/path/to/cache/directory")
This creates the directory if it doesn't exist, and reuses any cache files already in it if it does.
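To make the "write to disk, reuse on the next start" idea concrete, here's a toy sketch in plain Python. It is NOT how JAX implements the cache: `get_or_compile` and `fake_compile` are hypothetical helpers, and hashing the computation text with SHA-256 is just an assumption standing in for whatever key JAX actually uses.

```python
import hashlib
import os
import tempfile


# Toy illustration only (hypothetical helper, not JAX's implementation):
# a persistent on-disk cache keyed by a hash of the computation.
def get_or_compile(cache_dir, computation_text, compile_fn):
    """Return (compiled_bytes, was_cache_hit), compiling only on a miss."""
    # Like initialize_cache: create the directory if missing, reuse it otherwise.
    os.makedirs(cache_dir, exist_ok=True)
    key = hashlib.sha256(computation_text.encode("utf-8")).hexdigest()
    path = os.path.join(cache_dir, key)
    if os.path.exists(path):
        with open(path, "rb") as f:
            return f.read(), True  # cache hit: no recompilation needed
    compiled = compile_fn(computation_text)
    with open(path, "wb") as f:
        f.write(compiled)
    return compiled, False  # cache miss: compiled, then written to disk


if __name__ == "__main__":
    calls = []

    def fake_compile(text):
        # Stand-in for an expensive compilation step.
        calls.append(text)
        return text.upper().encode("utf-8")

    with tempfile.TemporaryDirectory() as tmp:
        cache_dir = os.path.join(tmp, "cache")
        print(get_or_compile(cache_dir, "add(x, y)", fake_compile))  # (b'ADD(X, Y)', False)
        # A "second run" of the program finds the file on disk and skips compiling.
        print(get_or_compile(cache_dir, "add(x, y)", fake_compile))  # (b'ADD(X, Y)', True)
        print(len(calls))  # 1
```

The point is that the second lookup never calls the compile function, which is where the startup-time savings come from.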
The cache uses an LRU eviction policy so it doesn't eat all your disk space. It defaults to a max size of 32 GiB, but you can adjust this by passing max_cache_size_bytes=<bytes> to initialize_cache.
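If it helps to picture the eviction behavior, here's a rough in-memory model of size-bounded LRU eviction. `SizeBoundedLRU` is a hypothetical class written for illustration, not JAX's code; it only assumes what's stated above: least recently used entries are evicted once the total size would exceed the limit (32 GiB by default).

```python
from collections import OrderedDict

DEFAULT_MAX_CACHE_SIZE_BYTES = 32 * 2**30  # 32 GiB, the default mentioned above


class SizeBoundedLRU:
    """Hypothetical model of size-bounded LRU eviction (not JAX's implementation)."""

    def __init__(self, max_size_bytes=DEFAULT_MAX_CACHE_SIZE_BYTES):
        self.max_size_bytes = max_size_bytes
        self.total_bytes = 0
        self._entries = OrderedDict()  # key -> size in bytes, least recent first

    def get(self, key):
        """Return True on a hit and mark the entry as most recently used."""
        if key not in self._entries:
            return False
        self._entries.move_to_end(key)
        return True

    def put(self, key, size_bytes):
        """Insert an entry, evicting least recently used entries until it fits."""
        if size_bytes > self.max_size_bytes:
            return  # can never fit, so don't cache it at all
        if key in self._entries:
            self.total_bytes -= self._entries.pop(key)
        while self.total_bytes + size_bytes > self.max_size_bytes:
            _, evicted_size = self._entries.popitem(last=False)
            self.total_bytes -= evicted_size
        self._entries[key] = size_bytes
        self.total_bytes += size_bytes

    def keys(self):
        return list(self._entries)


if __name__ == "__main__":
    cache = SizeBoundedLRU(max_size_bytes=10)
    cache.put("a", 4)
    cache.put("b", 4)
    cache.get("a")      # "a" becomes the most recently used entry
    cache.put("c", 4)   # 8 + 4 > 10, so the least recently used entry "b" is evicted
    print(cache.keys())  # ['a', 'c']
    # A smaller limit for initialize_cache, e.g. 8 GiB, expressed in bytes:
    print(8 * 2**30)  # 8589934592
```

Since max_cache_size_bytes takes a raw byte count, writing it as `N * 2**30` is an easy way to get N GiB.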
Note: this also works on GPU, but isn't enabled there by default.