Skip to content

Instantly share code, notes, and snippets.

@ahwillia
Last active February 20, 2024 21:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ahwillia/c35bfb2a09add7b2e5745ca7b424a3b1 to your computer and use it in GitHub Desktop.
Save ahwillia/c35bfb2a09add7b2e5745ca7b424a3b1 to your computer and use it in GitHub Desktop.
Elliptical Slice Sampler in JAX
"""
NOTE: This code has not been rigorously tested.
"""
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
from tqdm import trange
def elliptical_slice_update(x, log_density, sigmas, key):
    """
    Runs one update of elliptical slice sampling with
    respect to a mean-zero, diagonal Gaussian prior.

    An auxiliary prior sample ``nu`` and a slice threshold are drawn, then
    points on the ellipse ``x * cos(theta) + nu * sin(theta)`` are proposed.
    After every rejected proposal the angle bracket is shrunk towards
    ``theta = 0`` (which recovers the current state ``x``), until a
    proposal lands above the slice threshold.

    Parameters
    ----------
    x : array-like
        Current state of the MCMC chain, corresponds
        to parameters we are sampling.
    log_density : Callable
        Function that computes log density (up to an
        additive constant) that we'd like to sample from
        after multiplication with the prior. It is compared against a
        scalar threshold, so it should return a scalar.
    sigmas : array-like or float
        Standard deviations for each element of x. This
        specifies a prior distribution over x, which is
        mean-zero and diagonal covariance with
        elements given by (sigmas ** 2). Must be broadcastable to
        the shape of ``x``; a scalar applies the same standard
        deviation to every element.
    key : jax.random.PRNGKey
        Random number seed, used to generate proposal.

    Returns
    -------
    x_next : array-like
        Next state of the MCMC chain.

    Raises
    ------
    ValueError
        If ``sigmas`` cannot be broadcast to the shape of ``x``.

    Notes
    -----
    The loop only exits once a proposal satisfies
    ``log_density(proposal) > threshold``. If ``log_density(x)`` is
    ``-inf`` (the current state lies outside the support), the bracket may
    collapse onto ``x`` without ever finding such a point, so the chain
    should be started inside the support.
    """
    x = jnp.asarray(x)
    # Replaces the old `assert x.shape == sigmas.shape` (stripped under -O):
    # jnp.broadcast_to raises ValueError on incompatible shapes, and also
    # lets callers pass a single scalar standard deviation.
    sigmas = jnp.broadcast_to(jnp.asarray(sigmas), x.shape)

    k_nu, k_slice, k_theta, k_loop = jax.random.split(key, num=4)

    # Auxiliary draw from the prior N(0, diag(sigmas ** 2)); together with
    # x it defines the ellipse that all proposals lie on.
    nu = jax.random.normal(k_nu, shape=x.shape) * sigmas

    # Slice threshold: log_density(x) + log(u), with u ~ Uniform(0, 1).
    thres = log_density(x) + jnp.log(jax.random.uniform(k_slice))

    def propose(theta):
        # Point on the ellipse; theta = 0 gives back the current state x.
        return x * jnp.cos(theta) + nu * jnp.sin(theta)

    # Initial proposal; the bracket [theta - 2*pi, theta] spans a full
    # turn of the ellipse.
    theta = jax.random.uniform(k_theta, minval=0., maxval=(2 * jnp.pi))
    init_loop_state = (
        propose(theta),
        theta - 2 * jnp.pi,  # theta_min
        theta,               # theta_max
        theta,               # most recently proposed angle
        k_loop,
    )

    def while_cond_fun(loop_state):
        x_proposed = loop_state[0]
        # Keep shrinking while the proposal is not above the slice.
        return log_density(x_proposed) <= thres

    def while_body_fun(loop_state):
        _, theta_min, theta_max, theta0, key0 = loop_state
        # Shrink the bracket towards theta = 0: a rejected negative angle
        # becomes the new lower bound, a non-negative one the new upper bound.
        rejected_below = theta0 < 0
        theta_min1 = jnp.where(rejected_below, theta0, theta_min)
        theta_max1 = jnp.where(rejected_below, theta_max, theta0)
        # Split *before* sampling so that no key is consumed twice
        # (previously key0 was both sampled from and split).
        key1, subkey = jax.random.split(key0)
        theta1 = jax.random.uniform(
            subkey,
            minval=theta_min1,
            maxval=theta_max1,
        )
        return propose(theta1), theta_min1, theta_max1, theta1, key1

    final_loop_state = jax.lax.while_loop(
        while_cond_fun, while_body_fun, init_loop_state
    )
    return final_loop_state[0]
if __name__ == "__main__":
    # Demo: sample from the N(0, diag([0.25, 4.0] ** 2)) prior restricted
    # to the open box max(|u|) < 1.
    x = jnp.zeros(2)
    x_samples = [x]
    key = jax.random.PRNGKey(111)
    num_samples = 1000
    sigmas = jnp.array([0.25, 4.0])

    def log_density(u):
        # Log of a boolean indicator: 0 inside the box, -inf outside.
        return jnp.log(jnp.max(jnp.abs(u)) < 1)

    for _ in trange(num_samples):
        # Carry `key` forward and hand a fresh subkey to the update.
        # Passing `key` itself and then splitting it again on the next
        # iteration would reuse the same key twice.
        key, subkey = jax.random.split(key)
        x = elliptical_slice_update(x, log_density, sigmas, subkey)
        x_samples.append(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment