Last active
June 3, 2024 17:23
-
-
Save banditkings/5c6f48ff160a59c1890e1f3022d32397 to your computer and use it in GitHub Desktop.
Syntax for random sampling from distributions with NumPy, SciPy, JAX, and NumPyro
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# --- NUMPY ---
# Legacy global-state API: np.random.seed() reseeds the module-level
# RandomState that the np.random.* functions draw from (scipy.stats uses the
# same state by default, see below). A Bernoulli(p) draw is simply a
# Binomial(n=1, p) draw.
# For new code NumPy recommends rng = np.random.default_rng(42) followed by
# rng.binomial(...); that Generator uses a different bit generator (PCG64),
# so its values will generally NOT match the array shown here.
import numpy as np

np.random.seed(42)
numpy_samples = np.random.binomial(n=1, p=0.5, size=(10,))
numpy_samples
# array([0, 1, 1, 1, 0, 0, 0, 1, 1, 1])
# --- SCIPY ---
import numpy as np
from scipy import stats

# With random_state=None (the default), scipy.stats draws from NumPy's global
# RandomState, so seeding np.random makes this reproduce the NumPy example.
# Passing random_state=42 instead seeds a private RandomState and leaves the
# global state untouched.
np.random.seed(42)
scipy_samples = stats.bernoulli.rvs(p=0.5, size=(10,))
scipy_samples
# array([0, 1, 1, 1, 0, 0, 0, 1, 1, 1])
# --- JAX ---
from jax import random

# JAX has no global RNG state: every draw takes an explicit PRNG key, and the
# same key always produces the same values. The values differ from NumPy's
# because JAX uses its own PRNG (threefry by default).
key = random.PRNGKey(42)
jax_samples = random.bernoulli(key=key, p=0.5, shape=(10,))
jax_samples
# Array([False, True, True, False, False, True, True, True, False, False], dtype=bool)
# NOTE: jax.random.bernoulli returns booleans, not 0/1 integers.
# NOTE(review): the exact values shown may differ across JAX versions / PRNG
# configuration; re-run to confirm on your installed version.
# --- NUMPYRO ---
import numpyro.distributions as dist
from jax import random

# numpyro distributions are sampled with an explicit JAX PRNG key.
key = random.PRNGKey(42)
d = dist.Bernoulli(probs=0.5)
numpyro_samples = d.sample(key=key, sample_shape=(10,))
numpyro_samples
# Array([0, 1, 1, 0, 0, 1, 1, 1, 0, 0], dtype=int32)
# With the same key, the outputs shown here follow the same pattern as the
# jax.random.bernoulli result above, but as integers instead of booleans.
# --- NUMPYRO with effect handler ---
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp

# numpyro.handlers.seed supplies the PRNG key to numpyro.sample() calls made
# inside the context, so no key is passed explicitly here.
with numpyro.handlers.seed(rng_seed=42):
    # A (10,)-shaped probs gives a batch of 10 independent Bernoulli(0.5) draws.
    p = jnp.full(shape=(10,), fill_value=0.5)
    x = numpyro.sample('x', dist.Bernoulli(probs=p))
x
# Array([1, 0, 1, 0, 0, 0, 1, 0, 0, 1], dtype=int32)
# NOTE(review): these differ from the explicit PRNGKey(42) example above,
# presumably because the handler derives (splits) a fresh key from the seed
# before each sample site rather than using PRNGKey(42) directly — confirm.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment