Skip to content

Instantly share code, notes, and snippets.

@banditkings
Last active June 3, 2024 17:23
Show Gist options
  • Save banditkings/5c6f48ff160a59c1890e1f3022d32397 to your computer and use it in GitHub Desktop.
Save banditkings/5c6f48ff160a59c1890e1f3022d32397 to your computer and use it in GitHub Desktop.
Syntax for random sampling from distributions with NumPy, SciPy, JAX, and NumPyro
# --- NUMPY ---
import numpy as np

# Legacy global-state API: np.random.seed reseeds NumPy's global RandomState,
# which every np.random.<dist>() call then draws from.
# A binomial draw with n=1 is a Bernoulli draw.
np.random.seed(42)
# Bind the result so it is usable when this file runs as a script
# (a bare expression is only echoed in an interactive session).
numpy_samples = np.random.binomial(n=1, p=0.5, size=(10,))
# numpy_samples -> array([0, 1, 1, 1, 0, 0, 0, 1, 1, 1])
# Newer code usually prefers a local Generator, e.g.
# np.random.default_rng(42).binomial(...), which yields a different stream.
# --- SCIPY ---
import numpy as np
from scipy import stats

# When no `random_state` is passed, scipy.stats draws from NumPy's global
# RandomState, so seeding it with np.random.seed makes the draw reproducible.
np.random.seed(42)
# Bind the result so it is usable when this file runs as a script.
scipy_samples = stats.bernoulli.rvs(p=0.5, size=(10,))
# scipy_samples -> array([0, 1, 1, 1, 0, 0, 0, 1, 1, 1])
# Same seed, same values as the NumPy binomial(n=1) example.
# --- JAX ---
from jax import random

# jax.random takes an explicit PRNG key instead of reading a global seed;
# reusing the same key reproduces the same draw.
key = random.PRNGKey(42)
# Bind the result so it is usable when this file runs as a script.
jax_samples = random.bernoulli(key=key, p=0.5, shape=(10,))
# jax_samples -> Array([False, True, True, False, False, True, True, True, False, False], dtype=bool)
# NOTE: values recorded by the gist author; JAX does not guarantee identical
# jax.random outputs for a fixed key across releases (only the distribution).
# --- NUMPYRO ---
import numpyro.distributions as dist
from jax import random

# NumPyro distributions are sampled with an explicit JAX PRNG key.
key = random.PRNGKey(42)
d = dist.Bernoulli(probs=0.5)
# sample_shape=(10,) requests ten independent draws from the scalar distribution.
# Bind the result so it is usable when this file runs as a script.
numpyro_samples = d.sample(key=key, sample_shape=(10,))
# numpyro_samples -> Array([0, 1, 1, 0, 0, 1, 1, 1, 0, 0], dtype=int32)
# Same key and same pattern as the jax.random.bernoulli example, as integers.
# NOTE: values recorded by the gist author; jax.random outputs for a fixed
# key may differ across JAX releases.
# --- NUMPYRO with effect handler ---
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp

# numpyro.sample needs an RNG key; the seed handler supplies one derived from
# rng_seed (it splits a fresh key for each sample site), so the draws differ
# from calling d.sample with PRNGKey(42) directly.
with numpyro.handlers.seed(rng_seed=42):
    # Ten probabilities give the distribution a batch shape of (10,),
    # so this single sample site yields ten draws.
    p = jnp.full(shape=(10,), fill_value=0.5)
    x = numpyro.sample('x', dist.Bernoulli(probs=p))

# x -> Array([1, 0, 1, 0, 0, 0, 1, 0, 0, 1], dtype=int32)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment