Last active
October 12, 2022 14:51
-
-
Save jeremiecoullon/f6a658be4c98f8a7fd1710418cca0856 to your computer and use it in GitHub Desktop.
Add a tqdm progress bar to a JAX scan or fori_loop. This code is from this blog post: https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial, wraps

import jax.numpy as jnp
from jax import grad, jit, lax, random, vmap
from jax.experimental import host_callback
from tqdm.auto import tqdm
def progress_bar_scan(num_samples, message=None):
    """Return a decorator that adds a tqdm progress bar to a `lax.scan` body.

    The decorated body must scan over ``jnp.arange(num_samples)``, or over a
    tuple whose first element is ``jnp.arange(num_samples)``, so that the
    current iteration number can be read from the scanned-over value ``x``.

    The bar is created on iteration 0 and advanced roughly 20 times in total.
    Each advance happens on the last iteration of a chunk, so the bar never
    reports more iterations than have been reached. It is closed on the final
    iteration.

    NOTE: this relies on ``jax.experimental.host_callback``. That module is
    deprecated in recent JAX releases, where ``jax.debug.callback`` is the
    replacement.

    Args:
        num_samples: Total number of scan iterations (the bar's total).
        message: Description shown next to the bar. Defaults to
            ``"Running for {num_samples:,} iterations"``.

    Returns:
        A decorator for a scan body ``func(carry, x) -> (carry, y)``.
    """
    if message is None:
        message = f"Running for {num_samples:,} iterations"

    # The host callbacks below cannot hand Python objects back to traced code,
    # so the bar is kept in this closure-level dict between callbacks.
    tqdm_bars = {}

    # Update about 20 times in total rather than every iteration: each update
    # is a device -> host round trip.
    print_rate = max(num_samples // 20, 1)
    # Iterations in the final, partial chunk (0 when chunks divide evenly).
    remainder = num_samples % print_rate

    def _define_tqdm(arg, transform):
        tqdm_bars[0] = tqdm(total=num_samples)
        tqdm_bars[0].set_description(message, refresh=False)

    def _update_tqdm(arg, transform):
        # `arg` arrives on the host as a 0-d array; give tqdm a plain int.
        tqdm_bars[0].update(int(arg))

    def _update_progress_bar(iter_num):
        """Issue the host callbacks that create/advance the bar for `iter_num`."""
        # Create the bar on the first iteration.
        _ = lax.cond(
            iter_num == 0,
            lambda _: host_callback.id_tap(_define_tqdm, None, result=iter_num),
            lambda _: iter_num,
            operand=None,
        )
        # Advance by a full chunk on the last iteration of each chunk
        # (iterations print_rate-1, 2*print_rate-1, ...). The original code
        # fired at the *start* of each chunk, so the bar reached 100% early.
        _ = lax.cond(
            (iter_num + 1) % print_rate == 0,
            lambda _: host_callback.id_tap(_update_tqdm, print_rate, result=iter_num),
            lambda _: iter_num,
            operand=None,
        )
        if remainder:
            # Top the bar up to exactly `num_samples` on the final iteration.
            # This never coincides with the chunk update above, because
            # num_samples % print_rate == remainder != 0.
            _ = lax.cond(
                iter_num == num_samples - 1,
                lambda _: host_callback.id_tap(_update_tqdm, remainder, result=iter_num),
                lambda _: iter_num,
                operand=None,
            )

    def _close_tqdm(arg, transform):
        tqdm_bars[0].close()

    def close_tqdm(result, iter_num):
        """Close the bar on the last iteration, threading `result` through the tap."""
        return lax.cond(
            iter_num == num_samples - 1,
            lambda _: host_callback.id_tap(_close_tqdm, None, result=result),
            lambda _: result,
            operand=None,
        )

    def _progress_bar_scan(func):
        """Wrap scan body `func` so each call also drives the progress bar."""

        @wraps(func)
        def wrapper_progress_bar(carry, x):
            # The iteration number is `x` itself, or the first element of `x`.
            if isinstance(x, tuple):
                iter_num, *_ = x
            else:
                iter_num = x
            _update_progress_bar(iter_num)
            result = func(carry, x)
            return close_tqdm(result, iter_num)

        return wrapper_progress_bar

    return _progress_bar_scan
# ========
# Target density: Gaussian log-posterior and its gradient


def log_posterior(x):
    """Return ``-0.5 * jnp.dot(x, x)``.

    For a 1-D ``x`` this is the log-density of a standard Gaussian, up to an
    additive constant.
    """
    squared_norm = jnp.dot(x, x)
    return -0.5 * squared_norm


log_posterior = jit(log_posterior)

# Compiled gradient of the log-posterior, passed to the sampler below.
grad_log_post = jit(grad(log_posterior))
# ========
# Unadjusted Langevin algorithm (ULA) sampler


def ula_kernel(key, param, grad_log_post, dt):
    """Take a single ULA step.

    Computes ``param + dt * grad_log_post(param) + sqrt(2 * dt) * xi``, where
    ``xi`` is standard-normal noise with the same shape as ``param``.

    Args:
        key: JAX PRNG key. It is split in two: the first half is returned for
            the next step and the second half draws this step's noise.
        param: Current state of the chain.
        grad_log_post: Gradient of the log-posterior. It is a static argument
            under ``jit``, so it must be hashable.
        dt: Step size.

    Returns:
        ``(key, param)``: the PRNG key to carry forward and the new state.
    """
    next_key, noise_key = random.split(key)
    drift = dt * grad_log_post(param)
    noise = jnp.sqrt(2 * dt) * random.normal(key=noise_key, shape=param.shape)
    return next_key, param + drift + noise


ula_kernel = jit(ula_kernel, static_argnums=(2,))
def ula_sampler_pbar(key, grad_log_post, num_samples, dt, x_0):
    """Run ``num_samples`` ULA steps from ``x_0`` while showing a progress bar.

    ``grad_log_post`` and ``num_samples`` are static arguments under ``jit``.
    Changing either one triggers a recompile.

    Returns:
        The chain: the state after every step, stacked by ``lax.scan`` along
        a new leading axis of length ``num_samples``.
    """

    @progress_bar_scan(num_samples)
    def ula_step(carry, iter_num):
        rng, state = carry
        rng, state = ula_kernel(rng, state, grad_log_post, dt)
        return (rng, state), state

    # The scanned-over values are the iteration numbers that the progress
    # bar decorator reads.
    iterations = jnp.arange(num_samples)
    _, samples = lax.scan(ula_step, (key, x_0), iterations)
    return samples


ula_sampler_pbar = jit(ula_sampler_pbar, static_argnums=(1, 2))
# ========
# Run the sampler on a 100-dimensional state starting at the origin.
dt = 1e-2
num_samples = 1_000_000
key = random.PRNGKey(0)
samples = ula_sampler_pbar(
    key, grad_log_post, num_samples, dt, jnp.zeros(100)
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment