@rlouf
Created December 10, 2020 08:49
2 ways to batch inference
import jax

""" 'Loop over vmaps'

This is what most PPLs do. We synchronize chains at the step level: a step
ends when all `num_chains` chains have completed it. We iterate for
`num_samples` steps.
"""

def one_step(states, rng_key):
    """Advances all chains by one step.

    `jax.lax.scan` calls its body as `f(carry, x)`, so the carry (the chain
    states) must come first and the per-step key second.
    """
    chain_keys = jax.random.split(rng_key, num_chains)
    new_states, _ = jax.vmap(kernel)(chain_keys, states)
    return new_states, states

step_keys = jax.random.split(rng_key, num_samples)
_, states = jax.lax.scan(one_step, initial_states, step_keys)
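A minimal runnable sketch of this pattern, assuming a hypothetical Gaussian random-walk `kernel` (the snippet above leaves the kernel, `num_chains`, `num_samples`, and the initial states unspecified):

```python
import jax
import jax.numpy as jnp

num_chains, num_samples = 4, 100

def kernel(rng_key, state):
    # Hypothetical transition kernel: one Gaussian random-walk step.
    # Returns (new_state, info), matching the kernel interface above.
    return state + jax.random.normal(rng_key, state.shape), None

def one_step(states, rng_key):
    # Advance every chain by one step; scan passes the carry first.
    chain_keys = jax.random.split(rng_key, num_chains)
    new_states, _ = jax.vmap(kernel)(chain_keys, states)
    return new_states, states

rng_key = jax.random.PRNGKey(0)
initial_states = jnp.zeros((num_chains, 3))
step_keys = jax.random.split(rng_key, num_samples)
_, states = jax.lax.scan(one_step, initial_states, step_keys)
print(states.shape)  # (100, 4, 3): steps are the leading axis
```

Because `scan` stacks the per-step outputs, the samples come out with shape `(num_samples, num_chains, ...)`.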
""" 'vmap the loops'
We run `num_chains` independent chains. This ends when all chains
have run `num_steps`.
"""
def chain(rng_key, initial_state):
"""Moves one chain `num_steps` time."""
def one_step(state, key):
new_state, _ = kernel(key, state)
return new_state, state
step_keys = jax.random.split(rng_key, num_samples)
_, states = jax.lax.scan(one_step, initial_state, step_keys)
return states
chain_keys = jax.random.split(rng_key, num_chains)
run = jax.vmap(chain)(keys, initial_states)
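The same sketch for this pattern, again assuming the hypothetical random-walk kernel (not part of the original gist):

```python
import jax
import jax.numpy as jnp

num_chains, num_samples = 4, 100

def kernel(rng_key, state):
    # Hypothetical Gaussian random-walk kernel, as in the sketch above.
    return state + jax.random.normal(rng_key, state.shape), None

def chain(rng_key, initial_state):
    # Run one full chain of `num_samples` steps with its own keys.
    def one_step(state, key):
        new_state, _ = kernel(key, state)
        return new_state, state
    step_keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, step_keys)
    return states

rng_key = jax.random.PRNGKey(0)
initial_states = jnp.zeros((num_chains, 3))
chain_keys = jax.random.split(rng_key, num_chains)
states = jax.vmap(chain)(chain_keys, initial_states)
print(states.shape)  # (4, 100, 3): chains are the leading axis
```

Note the transposed layout relative to the first pattern: here `vmap` is outermost, so chains form the leading axis and steps the second.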