Last active
February 7, 2021 19:15
-
-
Save jeremiecoullon/4ae89676e650370936200ec04a4e3bef to your computer and use it in GitHub Desktop.
Add a basic progress bar to a JAX `scan` or `fori_loop`. This code is from the blog post: https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools

import jax.numpy as jnp
# NOTE(review): `jax.partial` and `jax.experimental.host_callback` are not
# available in later JAX releases; there, use `functools.partial` and
# `jax.debug.callback` instead.
from jax import jit, vmap, grad, random, lax, partial
from jax.experimental import host_callback
# ======== | |
# define progress bar | |
def _print_consumer(arg, transform):
    """Host-side tap that prints the current iteration out of the total.

    `arg` is the ``(iter_num, num_samples)`` pair tapped from the device;
    `transform` is supplied by ``host_callback.id_tap`` and is not used here.
    Numbers are printed with thousands separators.
    """
    current, total = arg
    message = f"Iteration {current:,} / {total:,}"
    print(message)
@jit
def progress_bar(arg, result):
    """Print a progress line from inside a jitted scan/loop, then return `result`.

    A line is printed only when ``iter_num`` is a multiple of ``print_rate``,
    so ``print_rate`` should be a positive integer.

    Usage: ``carry = progress_bar((iter_num + 1, num_samples, print_rate), carry)``.
    Pass ``iter_num + 1`` so the displayed count starts at 1 and ends at
    ``num_samples``.

    Args:
        arg: ``(iter_num, num_samples, print_rate)`` triple.
        result: Value handed back unchanged. Routing it through ``id_tap``
            makes the print part of the traced data flow.
    """
    iter_num, num_samples, print_rate = arg

    def _tap(_):
        # Ship (iter_num, num_samples) to the host printer; id_tap hands
        # back `result` itself.
        return host_callback.id_tap(
            _print_consumer, (iter_num, num_samples), result=result
        )

    def _passthrough(_):
        return result

    on_print_step = iter_num % print_rate == 0
    return lax.cond(on_print_step, _tap, _passthrough, operand=None)
def progress_bar_scan(num_samples, print_rate=None):
    """Decorator factory that adds a progress bar to a ``lax.scan`` body.

    The decorated ``body_fun(carry, iter_num)`` must be scanned over
    ``jnp.arange(num_samples)`` so that ``iter_num`` is the current
    (0-based) iteration number. A progress line is printed every
    ``print_rate`` iterations.

    Args:
        num_samples: Total number of scan iterations (shown in the message).
        print_rate: Print every ``print_rate`` iterations. Defaults to about
            ten updates over the run (``num_samples // 10``), clamped to at
            least 1 so that short runs (``num_samples < 10``) still work.

    Returns:
        A decorator that wraps ``body_fun`` and keeps the ``(carry, x)``
        signature that ``lax.scan`` expects.

    Raises:
        ValueError: If ``print_rate`` is given and is smaller than 1.
    """
    if print_rate is None:
        # Clamp to 1: for num_samples < 10, num_samples // 10 is 0 and
        # `iter_num % print_rate` would be a modulo by zero.
        print_rate = max(1, num_samples // 10)
    elif print_rate < 1:
        raise ValueError(f"print_rate must be >= 1, got {print_rate}")

    def _progress_bar_scan(func):
        @functools.wraps(func)
        def wrapper_progress_bar(carry, iter_num):
            # Route `iter_num` through the progress bar instead of calling it
            # only for its side effect, so the print stays in the traced data
            # flow. The `+ 1` makes the displayed count 1-based.
            iter_num = progress_bar((iter_num + 1, num_samples, print_rate), iter_num)
            return func(carry, iter_num)

        return wrapper_progress_bar

    return _progress_bar_scan
# ======== | |
# define Gaussian log-posterior | |
@jit
def log_posterior(x):
    """Log-density of a standard Gaussian, up to an additive constant.

    Returns ``-0.5 * x·x``, computed with ``jnp.dot``.
    """
    squared_norm = jnp.dot(x, x)
    return -0.5 * squared_norm


# Gradient of `log_posterior` with respect to `x` (analytically ``-x``), jit-compiled.
grad_log_post = jit(grad(log_posterior))
# ======== | |
# define ULA sampler | |
@partial(jit, static_argnums=(2,))
def ula_kernel(key, param, grad_log_post, dt):
    """Take one step of the unadjusted Langevin algorithm (ULA).

    Computes ``param + dt * grad_log_post(param) + sqrt(2 * dt) * xi`` with
    ``xi`` standard-normal noise shaped like ``param``.

    ``grad_log_post`` is a static argument, so it must be hashable, and
    passing a different function triggers recompilation.

    Returns:
        ``(key, param)``: the advanced PRNG key for the next step and the
        updated parameter.
    """
    next_key, noise_key = random.split(key)
    drift = dt * grad_log_post(param)
    diffusion = jnp.sqrt(2 * dt) * random.normal(key=noise_key, shape=param.shape)
    return next_key, param + drift + diffusion
@partial(jit, static_argnums=(1, 2))
def ula_sampler_pbar(key, grad_log_post, num_samples, dt, x_0):
    """Run `num_samples` ULA steps from `x_0` with a progress bar.

    ``grad_log_post`` and ``num_samples`` are static: they must be hashable,
    and a new value triggers recompilation. Returns the chain, with one state
    per step stacked along the leading axis (``lax.scan`` stacks the
    per-step outputs).

    Both ``print`` calls below are plain Python prints inside a jitted
    function. They therefore run while tracing (first call or
    recompilation), not on every execution.
    """
    print("Compiling..")

    def ula_step(state, _iter_num):
        # One ULA transition; the new position is also this step's sample.
        rng, position = state
        rng, position = ula_kernel(rng, position, grad_log_post, dt)
        return (rng, position), position

    scan_body = progress_bar_scan(num_samples)(ula_step)
    init_state = (key, x_0)
    _, samples = lax.scan(scan_body, init_state, jnp.arange(num_samples))
    print("Running:")
    return samples
# ======== | |
# run sampler | |
key = random.PRNGKey(0)
num_samples = 100_000
dt = 0.01
# Index into the result and block on it so the script waits for the sampler
# to finish (JAX dispatches work asynchronously). The full samples array is
# deliberately not kept.
_ = ula_sampler_pbar(
    key,
    grad_log_post,
    num_samples,
    dt,
    jnp.zeros(1000),
)[0][0].block_until_ready()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.