
@shawwn
Last active January 2, 2024 15:46
JAX persistent compilation cache

JAX released a persistent compilation cache for TPU VMs! When enabled, the cache writes compiled JAX computations to disk so they don’t have to be recompiled the next time you start your JAX program. If any of y’all have long compilation times, this can save a lot of startup time.

First, upgrade JAX to a release that includes the cache (0.2.18 or later):

pip install -U "jax[tpu]>=0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Then enable the cache near the start of your JAX program:

from jax.experimental.compilation_cache import compilation_cache as cc

cc.initialize_cache("/path/to/cache/directory")

This creates the directory if it doesn’t exist yet and reuses any cache files already in it.
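To make the "create or reuse" behavior concrete, here is a minimal sketch of a directory-backed cache in plain Python. It is not JAX’s implementation: the class `DiskCompilationCache`, the helper `get_or_compile`, and the use of a SHA-256 hash of the computation’s text as the file name are all hypothetical stand-ins (JAX computes its own cache key), and `fake_compile` stands in for the expensive compile step.

```python
import hashlib
import tempfile
from pathlib import Path


class DiskCompilationCache:
    """Hypothetical on-disk cache mapping a hash of a computation to bytes."""

    def __init__(self, path):
        self.path = Path(path)
        # Create the directory if it is missing; existing entries are kept.
        self.path.mkdir(parents=True, exist_ok=True)

    def _key(self, computation_text):
        # Hypothetical key: SHA-256 of the computation's text form.
        return hashlib.sha256(computation_text.encode("utf-8")).hexdigest()

    def get(self, computation_text):
        entry = self.path / self._key(computation_text)
        return entry.read_bytes() if entry.exists() else None

    def put(self, computation_text, executable_bytes):
        (self.path / self._key(computation_text)).write_bytes(executable_bytes)


def get_or_compile(cache, computation_text, compile_fn):
    """Return (bytes, was_cache_hit); compile and store only on a miss."""
    cached = cache.get(computation_text)
    if cached is not None:
        return cached, True
    compiled = compile_fn(computation_text)
    cache.put(computation_text, compiled)
    return compiled, False


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        # Stand-in for a slow compiler.
        fake_compile = lambda text: text.upper().encode("utf-8")
        first = get_or_compile(DiskCompilationCache(tmp), "add(x, y)", fake_compile)
        # A fresh cache object on the same directory simulates a restarted program.
        second = get_or_compile(DiskCompilationCache(tmp), "add(x, y)", fake_compile)
        print(first[1], second[1])  # False True
```

The second `DiskCompilationCache` finds the entry the first one wrote because both point at the same directory, which is the same reason a restarted program can skip recompilation.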

The cache uses an LRU (least-recently-used) eviction policy so it doesn’t fill up your disk. The maximum size defaults to 32 GiB; to change it, pass max_cache_size_bytes=<bytes> to initialize_cache.
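The eviction rule can be modeled in a few lines. This is an in-memory sketch of a size-bounded LRU cache, assuming eviction is driven by the total bytes stored; `LRUByteCache` and its methods are hypothetical names, and JAX’s cache manages files on disk rather than an `OrderedDict`. Only the `max_cache_size_bytes` name and the 32 GiB default come from the description above.

```python
from collections import OrderedDict


class LRUByteCache:
    """Hypothetical model of a size-bounded LRU cache.

    Once the total size of stored values exceeds max_cache_size_bytes,
    entries are evicted least-recently-used first.
    """

    def __init__(self, max_cache_size_bytes=32 * 2**30):  # 32 GiB default
        self.max_cache_size_bytes = max_cache_size_bytes
        self._entries = OrderedDict()  # key -> bytes, least recently used first
        self.total_bytes = 0

    def get(self, key):
        value = self._entries.get(key)
        if value is not None:
            self._entries.move_to_end(key)  # mark as most recently used
        return value

    def put(self, key, value):
        if len(value) > self.max_cache_size_bytes:
            return  # can never fit, so don't cache it
        if key in self._entries:
            self.total_bytes -= len(self._entries.pop(key))
        self._entries[key] = value
        self.total_bytes += len(value)
        # Evict from the least-recently-used end until back under budget.
        while self.total_bytes > self.max_cache_size_bytes:
            _, evicted = self._entries.popitem(last=False)
            self.total_bytes -= len(evicted)

    def keys(self):
        return list(self._entries)


if __name__ == "__main__":
    cache = LRUByteCache(max_cache_size_bytes=10)
    cache.put("a", b"xxxx")  # 4 bytes stored
    cache.put("b", b"yyyy")  # 8 bytes stored
    cache.get("a")           # "a" becomes the most recently used entry
    cache.put("c", b"zzzz")  # 12 > 10, so "b" (least recently used) is evicted
    print(cache.keys(), cache.total_bytes)  # ['a', 'c'] 8
```

Reading an entry refreshes it, so programs you start often keep their compiled computations while rarely used ones age out first.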

nouiz commented May 12, 2023

Note: this also works on GPU, but it isn't enabled by default.
