Skip to content

Instantly share code, notes, and snippets.

@amifalk
amifalk / batch_svi.py
Created March 13, 2024 15:46
Batch SVI
import jax
import jax.random as random
from numpyro.infer import SVI
def batch_svi_setup(model, guide, optim, loss,
model_args,
*,
rng_key,
svi_init_params=None,
n_init=1,
@amifalk
amifalk / mcmc_lite.py
Last active March 9, 2024 00:38
Functional NumPyro MCMC Wrapper
from functools import partial
from operator import attrgetter
import jax
import jax.numpy as jnp
import jax.random as random
@partial(jax.jit, static_argnames=['field_names'])
def collect_fields(state, field_names: tuple):
"""