# Add a tqdm progress bar to a JAX scan or fori_loop.
# This code is from this blog post: https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/
from functools import partial

import jax.numpy as jnp
from jax import grad, jit, lax, random, vmap
from jax.experimental import host_callback  # NOTE: deprecated in recent JAX releases
from tqdm.auto import tqdm
def progress_bar_scan(num_samples, message=None):
    """Build a decorator that shows a tqdm progress bar for a ``lax.scan`` body.

    The bar is created on the first iteration and closed on the last one. In
    between it is advanced about 20 times in total (every iteration when
    ``num_samples <= 20``). Host-side bar updates are issued with
    ``jax.debug.callback(..., ordered=True)``. That keeps them in program order
    and stops them from being dropped as dead code. It replaces the deprecated
    ``jax.experimental.host_callback.id_tap``.

    Args:
        num_samples: Total number of scan iterations; also the bar's total.
        message: Description shown next to the bar. Defaults to
            ``"Running for {num_samples:,} iterations"``.

    Returns:
        A decorator for a scan body ``func(carry, x)``; see the inner
        ``_progress_bar_scan`` docstring for the requirements on ``x``.
    """
    from functools import wraps

    from jax.debug import callback as debug_callback

    if message is None:
        message = f"Running for {num_samples:,} iterations"
    tqdm_bars = {}
    # Refresh the bar ~20 times in total rather than every step, which keeps
    # the number of host callbacks small.
    if num_samples > 20:
        print_rate = num_samples // 20
    else:
        print_rate = 1
    # Size of the final chunk that does not fill a whole `print_rate`.
    remainder = num_samples % print_rate

    # Host-side callbacks. Each receives the (unused) iteration number.
    def _define_tqdm(_iter_num):
        tqdm_bars[0] = tqdm(total=num_samples)
        tqdm_bars[0].set_description(message, refresh=False)

    def _advance_by_rate(_iter_num):
        tqdm_bars[0].update(print_rate)

    def _advance_by_remainder(_iter_num):
        tqdm_bars[0].update(remainder)

    def _close_tqdm(_iter_num):
        tqdm_bars[0].close()

    def _callback_when(pred, host_fn, iter_num):
        """Run ``host_fn(iter_num)`` on the host, in program order, iff ``pred``."""
        lax.cond(
            pred,
            lambda i: debug_callback(host_fn, i, ordered=True),
            lambda i: None,
            iter_num,
        )

    def _update_progress_bar(iter_num):
        "Updates tqdm progress bar of a JAX scan or loop"
        _callback_when(iter_num == 0, _define_tqdm, iter_num)
        # update tqdm every multiple of `print_rate` except at the end
        _callback_when(
            (iter_num % print_rate == 0) & (iter_num != num_samples - remainder),
            _advance_by_rate,
            iter_num,
        )
        # update tqdm by `remainder` (never fires when remainder == 0, because
        # iter_num then never reaches num_samples)
        _callback_when(
            iter_num == num_samples - remainder,
            _advance_by_remainder,
            iter_num,
        )

    def close_tqdm(result, iter_num):
        """Close the bar on the final iteration; return ``result`` unchanged."""
        _callback_when(iter_num == num_samples - 1, _close_tqdm, iter_num)
        return result

    def _progress_bar_scan(func):
        """Decorator that adds a progress bar to `body_fun` used in `lax.scan`.
        Note that `body_fun` must either be looping over `np.arange(num_samples)`,
        or be looping over a tuple (or list) whose first element is
        `np.arange(num_samples)`.
        This means that `iter_num` is the current iteration number.
        """

        @wraps(func)
        def wrapper_progress_bar(carry, x):
            if isinstance(x, (tuple, list)):
                iter_num, *_ = x
            else:
                iter_num = x
            _update_progress_bar(iter_num)
            result = func(carry, x)
            return close_tqdm(result, iter_num)

        return wrapper_progress_bar

    return _progress_bar_scan
# ======== | |
# define Gaussian log-posterior | |
@jit
def log_posterior(x):
    """Unnormalised log-density of a standard Gaussian, ``-(x . x) / 2``."""
    squared_norm = jnp.dot(x, x)
    return -(squared_norm / 2)


# jit-wrapped gradient of `log_posterior` (analytically -x for a 1-D x).
grad_log_post = jit(grad(log_posterior))
# ======== | |
# define ULA sampler | |
@partial(jit, static_argnums=(2,))
def ula_kernel(key, param, grad_log_post, dt):
    """Take one unadjusted Langevin algorithm (ULA) step.

    Computes ``param + dt * grad_log_post(param) + sqrt(2 * dt) * noise`` where
    ``noise`` is standard-normal with the same shape as ``param``.

    Args:
        key: PRNG key. It is split; the carried-forward half is returned.
        param: Current position (array).
        grad_log_post: Gradient of the log-posterior. Static argument
            (argnum 2), so it must be hashable, e.g. a plain function.
        dt: Step size.

    Returns:
        ``(key, param)``: the new PRNG key and the updated position.
    """
    key, subkey = random.split(key)
    param_grad = grad_log_post(param)
    if jnp.issubdtype(param.dtype, jnp.floating):
        # Draw noise in `param`'s own float dtype. Otherwise the default float
        # noise would promote e.g. a bfloat16 `param`, and the `lax.scan` carry
        # dtype would no longer match between iterations.
        noise = random.normal(subkey, shape=param.shape, dtype=param.dtype)
    else:
        # Non-float params keep the original behaviour (default float noise).
        noise = random.normal(subkey, shape=param.shape)
    param = param + dt * param_grad + jnp.sqrt(2 * dt) * noise
    return key, param
@partial(jit, static_argnums=(1, 2))
def ula_sampler_pbar(key, grad_log_post, num_samples, dt, x_0):
    """Run ``num_samples`` ULA steps from ``x_0`` with a tqdm progress bar.

    ``grad_log_post`` and ``num_samples`` are static jit arguments. The former
    is forwarded to ``ula_kernel``. The latter fixes the scan length and the
    progress-bar schedule.

    Returns:
        The position after every step, stacked along a new leading axis of
        length ``num_samples``.
    """

    def _one_step(state, _step_index):
        # `_step_index` is consumed by the progress-bar wrapper, not here.
        next_state = ula_kernel(*state, grad_log_post, dt)
        _, position = next_state
        return next_state, position

    body = progress_bar_scan(num_samples)(_one_step)
    _, samples = lax.scan(body, (key, x_0), jnp.arange(num_samples))
    return samples
# ======== | |
# run sampler | |
# Sample a 100-dimensional standard Gaussian with ULA, starting at the origin.
num_samples = 1_000_000
dt = 0.01
key = random.PRNGKey(0)
samples = ula_sampler_pbar(key, grad_log_post, num_samples, dt, jnp.zeros(100))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment