Created
March 13, 2024 15:46
-
-
Save amifalk/eb377a243b046105dc00beda79441b22 to your computer and use it in GitHub Desktop.
Batch SVI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax | |
import jax.random as random | |
from numpyro.infer import SVI | |
def _normalize_batch_axes(batch_args, n_args):
    """Expand ``batch_args`` into a validated per-argument ``in_axes`` tuple.

    An int is broadcast to every model argument; any other value (tuple, list, ...)
    is converted to a tuple and must have exactly one entry per model argument.

    Raises:
        ValueError: if the number of entries differs from ``n_args`` or if no model
            argument would actually be batched (every entry is ``None``).
    """
    if isinstance(batch_args, int):
        axes = (batch_args,) * n_args
    else:
        axes = tuple(batch_args)
    if len(axes) != n_args:
        raise ValueError(
            f"batch_args has {len(axes)} entries but model_args has {n_args}; "
            "pass one in_axes entry (int or None) per model argument."
        )
    if all(axis is None for axis in axes):
        raise ValueError(
            "batch_args must batch over at least one model argument "
            "(every in_axes entry is None)."
        )
    return axes


def batch_svi_setup(model, guide, optim, loss,
                    model_args,
                    *,
                    rng_key,
                    svi_init_params=None,
                    n_init=1,
                    batch_args=None):
    """
    Set up SVI for batched inference over model arguments and/or different parameter initializations.

    To minimize the complexity of this implementation, the model and guide may not have keyword arguments.

    Args:
        model: Python callable with Pyro primitives for the model.
        guide: Python callable with Pyro primitives for the guide.
        optim: An instance of :class:`~numpyro.optim._NumpyroOptim`, a ``jax.example_libraries.optimizers.Optimizer``
            or an Optax ``GradientTransformation``. If you pass an Optax optimizer it will automatically be wrapped
            using :func:`numpyro.optim.optax_to_numpyro`.
        loss: ELBO loss, i.e. negative Evidence Lower Bound, to minimize.
        model_args: tuple of positional arguments for the model/guide.
        rng_key (jax.random.PRNGKey): random number generator seed to be used for param initialization.
            When ``n_init > 1`` it is split into ``n_init`` keys, one per initialization.
        svi_init_params (Callable, optional): Optional callable that initializes the parameters for guide sites.
            This is necessary to get different initial parameters for multiple initializations. If specified, the
            function signature must match `fn(rng_key, *args)` and must return a dictionary mapping site names in
            the guide to initial values. Defaults to None.
        n_init (int, optional): The number of batched parameter initializations to initialize SVI with.
            Values <= 1 perform a single, unbatched initialization. Defaults to 1.
        batch_args (int | tuple | list, optional): Optional `in_axes` of model_args to vmap over. An int applies
            the same axis to every model argument; a tuple/list gives one entry (int or None) per model argument.
            If None, will not batch over model_args. Defaults to None.

    Returns:
        tuple of (batched initial `SVIState`, batched_get_params_fn, batched_update_fn).
        When both batching modes are active, the data-batch axis is the outermost leading
        axis of the state and the initialization axis comes next.
        ``batched_update_fn`` takes ``(state, *model_args)``, with model_args batched the same
        way as at setup.

    Raises:
        ValueError: if ``batch_args`` does not provide exactly one axis per model argument,
            or batches over none of them.

    **Example**

    ```python
    import jax
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import Trace_ELBO
    from numpyro.infer.autoguide import AutoDelta

    def model(obs):
        a = numpyro.sample('mu', dist.Normal(0, 1))
        with numpyro.plate('n_obs', 100):
            numpyro.sample('obs', dist.Normal(a, 1), obs=obs)

    obs = dist.Normal(loc=jnp.array([3., 6., 9.]), scale=1).sample(jax.random.PRNGKey(0), (100,)).T
    guide = AutoDelta(model)
    optim = numpyro.optim.Adam(step_size=.01)

    def init_params(rng_key, *args):
        return {'mu_auto_loc': dist.Normal(0, 1).sample(rng_key)}

    state, get_params, update = batch_svi_setup(model, guide, optim, Trace_ELBO(),
                                                model_args=(obs,),
                                                rng_key=jax.random.PRNGKey(1),
                                                svi_init_params=init_params,
                                                n_init=2,
                                                batch_args=(0,))
    update = jax.jit(update)  # compile once, outside the loop
    for i in range(2000):
        state, loss = update(state, obs)

    print(get_params(state))
    ```
    """
    # Validate the batch spec before building anything, so shape mistakes surface with a
    # clear message instead of deep inside jax.vmap.
    batch_axes = None if batch_args is None else _normalize_batch_axes(batch_args, len(model_args))

    svi = SVI(model, guide, optim, loss)

    if svi_init_params is None:
        # Returning None means svi.init receives init_params=None, i.e. no user-supplied values.
        def init_params(rng_key, *args):
            return None
    else:
        init_params = svi_init_params

    init = svi.init
    get_params = svi.get_params
    update = svi.update

    if n_init > 1:
        # Inner vmap: one independent initialization per split key; model_args are shared.
        rng_key = random.split(rng_key, n_init)
        shared_args = (None,) * len(model_args)
        init_params = jax.vmap(init_params, in_axes=(0,) + shared_args)
        init = jax.vmap(init, in_axes=(0,) + shared_args)
        get_params = jax.vmap(get_params)
        update = jax.vmap(update, in_axes=(0,) + shared_args)

    if batch_axes is not None:
        # Outer vmap: slice model_args along batch_axes; the same rng_key (or split keys)
        # is reused for every data slice.
        init_params = jax.vmap(init_params, in_axes=(None,) + batch_axes)
        init = jax.vmap(init, in_axes=(None,) + batch_axes)
        get_params = jax.vmap(get_params)
        update = jax.vmap(update, in_axes=(0,) + batch_axes)

    # rng_key is only used to init params through the default route or
    # with the init_params arg so it won't actually be used twice here.
    # init_params is passed as a keyword: jax.vmap maps keyword arguments over axis 0,
    # which matches the leading batch axes produced by the vmapped init_params above.
    state = init(rng_key, *model_args, init_params=jax.jit(init_params)(rng_key, *model_args))
    return state, get_params, update
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment