Skip to content

Instantly share code, notes, and snippets.

@amifalk
Created March 13, 2024 15:46
Show Gist options
  • Save amifalk/eb377a243b046105dc00beda79441b22 to your computer and use it in GitHub Desktop.
Save amifalk/eb377a243b046105dc00beda79441b22 to your computer and use it in GitHub Desktop.
Batch SVI
import jax
import jax.random as random
from numpyro.infer import SVI
def batch_svi_setup(model, guide, optim, loss,
                    model_args,
                    *,
                    rng_key,
                    svi_init_params=None,
                    n_init=1,
                    batch_args=None):
    """
    Set up SVI for batched inference over model arguments and/or different parameter initializations.

    To minimize the complexity of this implementation, the model and guide may not have keyword arguments.

    Args:
        model: Python callable with Pyro primitives for the model.
        guide: Python callable with Pyro primitives for the guide.
        optim: An instance of :class:`~numpyro.optim._NumpyroOptim`, a ``jax.example_libraries.optimizers.Optimizer``
            or an Optax ``GradientTransformation``. If you pass an Optax optimizer it will automatically be wrapped
            using :func:`numpyro.optim.optax_to_numpyro`.
        loss: ELBO loss, i.e. negative Evidence Lower Bound, to minimize.
        model_args: tuple of positional arguments for the model/guide.
        rng_key (jax.random.PRNGKey): random number generator seed passed to ``SVI.init`` and, if given,
            to ``svi_init_params``. When ``svi_init_params`` is given, the key is split first so that the
            custom initializer and ``SVI.init`` never receive the same key.
        svi_init_params (Callable, optional): Optional callable that initializes the parameters for guide sites.
            This is necessary to get different initial parameters for multiple initializations. If specified,
            the function signature must match `fn(rng_key, *args)` and must return a dictionary mapping site
            names in the guide to initial values. Defaults to None.
        n_init (int, optional): The number of batched parameter initializations to initialize SVI with.
            Values ``<= 1`` mean a single, unbatched initialization. Defaults to 1.
        batch_args (int | tuple | list, optional): Optional `in_axes` of model_args to vmap over. An int applies
            the same axis to every element of model_args; a tuple/list gives one entry per element (use
            ``None`` for arguments that are shared across the batch). If None, will not batch over model_args.
            Defaults to None.

    Returns:
        tuple of (batched initial `SVIState`, batched_get_params_fn, batched_update_fn).
        ``batched_update_fn(state, *model_args)`` returns ``(state, loss)``. When both ``n_init > 1`` and
        ``batch_args`` are set, the leading axes of the state (and of the losses) are ``(batch, n_init)``.

    Raises:
        ValueError: if ``batch_args`` is a tuple/list whose length differs from ``len(model_args)``.

    **Example**

    ```python
    import jax
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import Trace_ELBO
    from numpyro.infer.autoguide import AutoDelta

    def model(obs):
        a = numpyro.sample('mu', dist.Normal(0, 1))
        with numpyro.plate('n_obs', 100):
            numpyro.sample('obs', dist.Normal(a, 1), obs=obs)

    obs = dist.Normal(loc=jnp.array([3., 6., 9.]), scale=1).sample(jax.random.PRNGKey(0), (100,)).T
    guide = AutoDelta(model)
    optim = numpyro.optim.Adam(step_size=.01)

    def init_params(rng_key, *args):
        return {'mu_auto_loc': dist.Normal(0, 1).sample(rng_key)}

    state, get_params, update = batch_svi_setup(model, guide, optim, Trace_ELBO(),
                                                model_args=(obs,),
                                                rng_key=jax.random.PRNGKey(1),
                                                svi_init_params=init_params,
                                                n_init=2,
                                                batch_args=(0,))
    update = jax.jit(update)  # compile once, outside the loop
    for i in range(2000):
        state, loss = update(state, obs)
    print(get_params(state))
    ```
    """
    svi = SVI(model, guide, optim, loss)

    if svi_init_params is None:
        def init_params(_rng_key, *_args):
            # No custom initializer: SVI.init receives init_params=None and
            # uses its default initialization. The key is ignored here.
            return None
        # The dummy initializer never consumes a key, so sharing it is safe and
        # keeps results identical to calling SVI.init with rng_key directly.
        init_key = params_key = rng_key
    else:
        init_params = svi_init_params
        # Give the custom initializer its own key rather than handing the same
        # key to both it and SVI.init (JAX keys should never be reused).
        init_key, params_key = random.split(rng_key)

    init = svi.init
    get_params = svi.get_params
    update = svi.update

    if n_init > 1:
        # Inner vmap: one independent SVI run per key; model_args are shared
        # across the n_init runs (in_axes None).
        init_key = random.split(init_key, n_init)
        params_key = random.split(params_key, n_init)
        shared_args = (None,) * len(model_args)
        init_params = jax.vmap(init_params, in_axes=(0,) + shared_args)
        init = jax.vmap(init, in_axes=(0,) + shared_args)
        get_params = jax.vmap(get_params)
        update = jax.vmap(update, in_axes=(0,) + shared_args)

    if batch_args is not None:
        if isinstance(batch_args, int):
            arg_axes = (batch_args,) * len(model_args)
        else:
            # Accept any sequence; vmap's in_axes must be a tuple.
            arg_axes = tuple(batch_args)
            if len(arg_axes) != len(model_args):
                raise ValueError(
                    f"batch_args has {len(arg_axes)} entries but model_args has "
                    f"{len(model_args)}; provide one in_axes entry per model argument."
                )
        # Outer vmap: the key(s) (including any inner n_init axis) are broadcast
        # to every batch element, while model_args are mapped per arg_axes.
        # The state carries the batch axis at position 0, hence in_axes 0 for update.
        init_params = jax.vmap(init_params, in_axes=(None,) + arg_axes)
        init = jax.vmap(init, in_axes=(None,) + arg_axes)
        get_params = jax.vmap(get_params)
        update = jax.vmap(update, in_axes=(0,) + arg_axes)

    # `init_params` is passed to SVI.init as a keyword argument; jax.vmap maps
    # keyword arguments over their leading axis, which lines up with the
    # (batch, n_init) leading axes produced by the vmapped initializer above.
    initial_params = jax.jit(init_params)(params_key, *model_args)
    state = init(init_key, *model_args, init_params=initial_params)
    return state, get_params, update
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment