@jeremiecoullon
jeremiecoullon / Jax_progress_bar_tqdm.py
Last active October 12, 2022 14:51
Add a tqdm progress bar to a JAX scan or fori_loop. This code is from this blog post: https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/
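```python
from functools import partial  # `jax.partial` was removed from JAX; use functools.partial

from tqdm.auto import tqdm
from jax import jit, vmap, grad, random, lax
import jax.numpy as jnp
from jax.experimental import host_callback  # deprecated in recent JAX releases (see jax.debug.callback)

def progress_bar_scan(num_samples, message=None):
    "Progress bar for a JAX scan"
    if message is None:
        # Default bar description, e.g. "Running for 10,000 iterations"
        message = f"Running for {num_samples:,} iterations"
    # ...
```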
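Because the host_callback module used above is deprecated in recent JAX releases, a comparable tqdm bar can be sketched with jax.debug.callback. This is a minimal sketch under that assumption, not the gist's implementation: the helper name scan_with_tqdm is hypothetical, it assumes a JAX version that provides jax.debug.callback and jax.effects_barrier, and it advances the bar on every scan step rather than at a reduced rate.

```python
import jax
import jax.numpy as jnp
from jax import lax
from tqdm.auto import tqdm


def scan_with_tqdm(f, init, xs, num_samples, message=None):
    """Hypothetical helper: lax.scan(f, init, xs) plus a host-side tqdm bar.

    Uses jax.debug.callback instead of the deprecated host_callback module.
    """
    if message is None:
        message = f"Running for {num_samples:,} iterations"
    bar = tqdm(total=num_samples, desc=message)

    def _advance(_x):
        # Runs on the host once per scan step; the argument is ignored.
        bar.update(1)

    def body(carry, x):
        carry, y = f(carry, x)
        # Side-effecting callback: JAX keeps it inside the compiled loop body.
        jax.debug.callback(_advance, x)
        return carry, y

    result = lax.scan(body, init, xs)
    # Wait for any pending (unordered) callbacks before closing the bar.
    jax.effects_barrier()
    bar.close()
    return result


# Usage: a running sum over 5 steps; `history` holds the carry before each step.
final, history = scan_with_tqdm(
    lambda c, x: (c + x, c), jnp.float32(0.0), jnp.arange(5.0), num_samples=5
)
print(float(final))  # 10.0
```

Calling back on every step is the simplest choice; for very long loops the callback could be guarded with lax.cond so the host is contacted less often.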
jeremiecoullon / jax_progress_bar_basic.py
Last active February 7, 2021 19:15
Add a basic progress bar to a JAX scan or fori_loop. This code is from this blog post: https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/
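```python
from functools import partial  # `jax.partial` was removed from JAX; use functools.partial

import jax.numpy as jnp
from jax import jit, vmap, grad, random, lax
from jax.experimental import host_callback  # deprecated in recent JAX releases (see jax.debug.callback)

# ========
# define progress bar
def _print_consumer(arg, transform):
    # host_callback consumer: `arg` is the tapped value, `transform` the applied JAX transforms
    iter_num, num_samples = arg
    # ...
```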
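The basic gist reports progress by calling back to the host with the iteration number and the total. A minimal sketch of the same idea for lax.fori_loop, assuming jax.debug.callback in place of host_callback, is below; progress_printer, sum_with_progress, print_rate and PRINTED are hypothetical names, and lax.cond is used so the host is only contacted every print_rate iterations.

```python
import jax
import jax.numpy as jnp
from jax import lax

PRINTED = []  # Hypothetical log of reported iterations, kept for inspection.


def progress_printer(num_samples, print_rate):
    """Hypothetical helper: returns a function to call once per loop iteration."""

    def _print(iter_num):
        # Host-side code: iter_num arrives as a concrete array.
        i = int(iter_num)
        PRINTED.append(i)
        print(f"Iteration {i:,}/{num_samples:,}")

    def _report(iter_num):
        jax.debug.callback(_print, iter_num)

    def _skip(iter_num):
        return None

    def maybe_print(iter_num):
        # Only call back to the host every `print_rate` iterations.
        lax.cond(iter_num % print_rate == 0, _report, _skip, iter_num)

    return maybe_print


def sum_with_progress(num_samples=100, print_rate=20):
    maybe_print = progress_printer(num_samples, print_rate)

    def body(i, total):
        maybe_print(i)
        return total + i

    total = lax.fori_loop(0, num_samples, body, jnp.int32(0))
    jax.effects_barrier()  # make sure every pending print has run
    return int(total)


print(sum_with_progress())  # 5 progress lines (iterations 0, 20, ..., 80), then 4950
```

The callbacks here are unordered, so the progress lines are not guaranteed to appear in iteration order; jax.debug.callback also accepts ordered=True when strict ordering matters.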