Skip to content

Instantly share code, notes, and snippets.

@jeremiecoullon
Last active February 7, 2021 19:15
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save jeremiecoullon/4ae89676e650370936200ec04a4e3bef to your computer and use it in GitHub Desktop.
Add a basic progress bar to a JAX `lax.scan` or `lax.fori_loop`. This code accompanies the blog post: https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/
import jax.numpy as jnp
from jax import jit, vmap, grad, random, lax, partial
from jax.experimental import host_callback
# ========
# define progress bar
def _print_consumer(arg, transform):
    """Host-side consumer for `host_callback.id_tap`: print the current progress.

    Args:
        arg: the tapped ``(iter_num, num_samples)`` pair.
        transform: supplied by ``id_tap``; not used here.
    """
    current, total = arg
    message = f"Iteration {current:,} / {total:,}"
    print(message)
@jit
def progress_bar(arg, result):
    """
    Return `result` unchanged, printing progress on the host when due.

    A message is sent via `host_callback.id_tap` only when the iteration
    number is an exact multiple of `print_rate`; otherwise `result` is passed
    straight through.

    Usage: `carry = progress_bar((iter_num + 1, num_samples, print_rate), carry)`
    Pass in `iter_num + 1` so that counting starts at 1 and ends at `num_samples`.
    """
    current, total, rate = arg

    def _tap_and_return(_):
        # With `result=` given, id_tap hands back `result` itself, so the
        # returned value carries the tap.
        return host_callback.id_tap(_print_consumer, (current, total), result=result)

    def _return_only(_):
        return result

    should_print = current % rate == 0
    return lax.cond(should_print, _tap_and_return, _return_only, None)
def progress_bar_scan(num_samples, print_rate=None):
    """
    Decorator that adds a progress bar to `body_fun` used in `lax.scan`.

    Note that `body_fun` must be looping over `jnp.arange(num_samples)`.
    This means that `iter_num` is the current iteration number.

    Args:
        num_samples: total number of scan iterations.
        print_rate: print every `print_rate` iterations. Defaults to about ten
            updates over the run (`num_samples // 10`), clamped to at least 1
            so that short loops (`num_samples < 10`) never compute a modulo by
            zero inside `progress_bar`.

    Raises:
        ValueError: if an explicitly given `print_rate` is smaller than 1.
    """
    if print_rate is None:
        # Integer division (no float round-trip) and a floor of 1: the old
        # `int(num_samples/10)` yielded 0 for num_samples < 10.
        print_rate = max(int(num_samples) // 10, 1)
    elif print_rate < 1:
        raise ValueError(f"print_rate must be a positive integer, got {print_rate}")

    def _progress_bar_scan(func):
        def wrapper_progress_bar(carry, iter_num):
            # progress_bar returns iter_num unchanged; threading it through
            # attaches the (conditional) print to this scan step.
            iter_num = progress_bar((iter_num + 1, num_samples, print_rate), iter_num)
            return func(carry, iter_num)
        return wrapper_progress_bar
    return _progress_bar_scan
# ========
# define Gaussian log-posterior
@jit
def log_posterior(x):
    """Standard-Gaussian log-density of `x` up to an additive constant: -(x . x) / 2."""
    sq_norm = jnp.dot(x, x)
    return -0.5 * sq_norm


# JIT-compiled gradient of `log_posterior`; it is the `grad_log_post` argument
# given to `ula_sampler_pbar` when the sampler is run.
grad_log_post = jit(grad(log_posterior))
# ========
# define ULA sampler
@partial(jit, static_argnums=(2,))
def ula_kernel(key, param, grad_log_post, dt):
    """One step of the Unadjusted Langevin Algorithm (ULA).

    Moves `param` by ``dt * grad_log_post(param) + sqrt(2 * dt) * noise``,
    where `noise` is standard normal with the shape of `param`.

    Args:
        key: PRNG key; it is split, one half drives the noise and the other
            is returned for the next step.
        param: current position.
        grad_log_post: gradient of the log-posterior. Static under `jit`, so
            it must be hashable.
        dt: step size.

    Returns:
        ``(key, param)``: the advanced key and the updated position.
    """
    next_key, noise_key = random.split(key)
    noise = random.normal(key=noise_key, shape=param.shape)
    drift = dt * grad_log_post(param)
    diffusion = jnp.sqrt(2 * dt) * noise
    return next_key, param + drift + diffusion
@partial(jit, static_argnums=(1, 2))
def ula_sampler_pbar(key, grad_log_post, num_samples, dt, x_0):
    """ULA sampler with progress bar.

    Runs `num_samples` ULA steps from `x_0` inside `lax.scan` and returns the
    stacked positions (one per step). `grad_log_post` and `num_samples` are
    static arguments of `jit`.
    """
    # Under `jit`, these Python print() calls run while the function is being
    # traced, not on every execution of the compiled code.
    print("Compiling..")

    def ula_step(carry, iter_num):
        del iter_num  # consumed only by the progress-bar wrapper
        rng, position = carry
        rng, position = ula_kernel(rng, position, grad_log_post, dt)
        return (rng, position), position

    step_with_pbar = progress_bar_scan(num_samples)(ula_step)
    _, samples = lax.scan(step_with_pbar, (key, x_0), jnp.arange(num_samples))
    print("Running:")
    return samples
# ========
# run sampler
# 100,000 ULA steps of size 0.01, starting at the origin in 1,000 dimensions.
key = random.PRNGKey(0)
num_samples = 100_000
dt = 1e-2
# JAX dispatches asynchronously: blocking on an element derived from the
# samples waits for the sampler to finish. The full result is not kept.
_ = ula_sampler_pbar(
    key, grad_log_post, num_samples, dt, jnp.zeros(1000)
)[0][0].block_until_ready()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment