Skip to content

Instantly share code, notes, and snippets.

@jeremiecoullon
Last active October 12, 2022 14:51
Show Gist options
  • Save jeremiecoullon/f6a658be4c98f8a7fd1710418cca0856 to your computer and use it in GitHub Desktop.
Save jeremiecoullon/f6a658be4c98f8a7fd1710418cca0856 to your computer and use it in GitHub Desktop.
Add a tqdm progress bar to a JAX scan or fori_loop. This code is from this blog post: https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/
from tqdm.auto import tqdm
from jax import jit, vmap, grad, random, lax, partial
import jax.numpy as jnp
from jax.experimental import host_callback
def progress_bar_scan(num_samples, message=None):
    """Return a decorator that adds a tqdm progress bar to a `lax.scan` body.

    The bar is created on iteration 0 and advanced in chunks of `print_rate`
    iterations (about num_samples / 20, at least 1). Any leftover
    `remainder` iterations are added on the last chunk boundary, and the bar
    is closed on iteration `num_samples - 1`. All host-side work goes through
    `host_callback.id_tap`, so the decorated body can be traced and jitted.

    Args:
        num_samples: Total number of scan iterations. Must equal the length of
            the iteration counter the scan loops over (see the decorator).
        message: Description shown on the bar. Defaults to
            "Running for {num_samples:,} iterations".

    Returns:
        A decorator for a scan body `func(carry, x)`.
    """
    from functools import wraps

    if message is None:
        message = f"Running for {num_samples:,} iterations"
    # Host-side holder for the bar, so the separate callbacks can share it.
    tqdm_bars = {}
    # Chunk size between bar updates: about num_samples / 20, never below 1.
    # Integer division avoids float rounding for very large num_samples.
    print_rate = max(1, num_samples // 20)
    # Iterations left after the last full chunk; added in a single update.
    remainder = num_samples % print_rate

    # id_tap callbacks are invoked as tap_func(arg, transforms);
    # `transform` is unused here.
    def _define_tqdm(arg, transform):
        tqdm_bars[0] = tqdm(total=num_samples)
        tqdm_bars[0].set_description(message, refresh=False)

    def _update_tqdm(arg, transform):
        # `arg` may arrive on the host as a NumPy array scalar; keep tqdm's
        # counter a plain int.
        tqdm_bars[0].update(int(arg))

    def _close_tqdm(arg, transform):
        tqdm_bars[0].close()

    def _update_progress_bar(iter_num):
        """Fire the create/advance callbacks that are due at `iter_num`."""
        _ = lax.cond(
            iter_num == 0,
            lambda _: host_callback.id_tap(_define_tqdm, None, result=iter_num),
            lambda _: iter_num,
            operand=None,
        )
        _ = lax.cond(
            # Advance by `print_rate` on each chunk boundary, except on the
            # boundary where `remainder` is added instead. When remainder == 0
            # that boundary equals num_samples, which is never reached.
            (iter_num % print_rate == 0) & (iter_num != num_samples - remainder),
            lambda _: host_callback.id_tap(_update_tqdm, print_rate, result=iter_num),
            lambda _: iter_num,
            operand=None,
        )
        _ = lax.cond(
            # Add the final partial chunk of `remainder` iterations.
            iter_num == num_samples - remainder,
            lambda _: host_callback.id_tap(_update_tqdm, remainder, result=iter_num),
            lambda _: iter_num,
            operand=None,
        )

    def close_tqdm(result, iter_num):
        """Close the bar on the last iteration and return `result` unchanged."""
        return lax.cond(
            iter_num == num_samples - 1,
            lambda _: host_callback.id_tap(_close_tqdm, None, result=result),
            lambda _: result,
            operand=None,
        )

    def _progress_bar_scan(func):
        """Wrap the scan body `func(carry, x)` so it drives the progress bar.

        The scan must loop over `jnp.arange(num_samples)`, or over a tuple (or
        list) whose first element is `jnp.arange(num_samples)`. That way `x`,
        or its first element, is the current iteration number.
        """

        @wraps(func)
        def wrapper_progress_bar(carry, x):
            # isinstance rather than `type(x) is tuple`, so NamedTuple and
            # list xs also yield their first element as the counter.
            if isinstance(x, (tuple, list)):
                iter_num, *_ = x
            else:
                iter_num = x
            _update_progress_bar(iter_num)
            result = func(carry, x)
            return close_tqdm(result, iter_num)

        return wrapper_progress_bar

    return _progress_bar_scan
# ========
# define Gaussian log-posterior
@jit
def log_posterior(x):
    """Standard Gaussian log-density of `x`, up to an additive constant.

    Returns -0.5 * <x, x>.
    """
    squared_norm = jnp.dot(x, x)
    return -0.5 * squared_norm


# Jitted gradient of `log_posterior` (analytically equal to -x).
grad_log_post = jit(grad(log_posterior))
# ========
# define the ULA (unadjusted Langevin algorithm) sampler
@partial(jit, static_argnums=(2,))
def ula_kernel(key, param, grad_log_post, dt):
    """Take one unadjusted Langevin (ULA) step.

    Computes ``param + dt * grad_log_post(param) + sqrt(2 * dt) * xi``, where
    ``xi`` is standard-normal noise with the same shape as ``param``.

    Args:
        key: JAX PRNG key. It is split, and the first half is returned for the
            next step.
        param: Current position.
        grad_log_post: Gradient of the log-posterior (static under ``jit``).
        dt: Step size.

    Returns:
        Tuple ``(key, param)`` with the advanced PRNG key and the new position.
    """
    next_key, noise_key = random.split(key)
    drift = dt * grad_log_post(param)
    diffusion = jnp.sqrt(2 * dt) * random.normal(noise_key, param.shape)
    return next_key, param + drift + diffusion
@partial(jit, static_argnums=(1, 2))
def ula_sampler_pbar(key, grad_log_post, num_samples, dt, x_0):
    """Run `num_samples` ULA steps from `x_0` while showing a progress bar.

    `grad_log_post` and `num_samples` are static arguments to `jit`, so a new
    value of either triggers a recompile. `dt` and `x_0` are traced.

    Returns:
        The position after every step, stacked by `lax.scan` along a leading
        axis of length `num_samples`.
    """

    @progress_bar_scan(num_samples)
    def _langevin_step(state, step_idx):
        # `step_idx` is read by the progress-bar wrapper, not by the update.
        rng, theta = state
        next_state = ula_kernel(rng, theta, grad_log_post, dt)
        return next_state, next_state[1]

    initial_state = (key, x_0)
    step_indices = jnp.arange(num_samples)
    return lax.scan(_langevin_step, initial_state, step_indices)[1]
# ========
# run sampler
# Demo: 1,000,000 ULA steps on a 100-dimensional standard Gaussian,
# starting from the origin.
# NOTE(review): `samples` holds num_samples x 100 values (presumably ~400 MB at
# JAX's default float32). Reduce num_samples for a quick check.
dt = 0.01
num_samples = 1_000_000
key = random.PRNGKey(0)
samples = ula_sampler_pbar(key, grad_log_post, num_samples, dt, jnp.zeros(100))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment