Skip to content

Instantly share code, notes, and snippets.

@fehiepsi
Last active April 6, 2019 04:31
Show Gist options
  • Star (0): You must be signed in to star a gist.
  • Fork (0): You must be signed in to fork a gist.
  • Save fehiepsi/4585e45445e6087e446133a435477386 to your computer and use it in GitHub Desktop.
Save fehiepsi/4585e45445e6087e446133a435477386 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/fehiepsi/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/lib/xla_bridge.py:144: UserWarning: No GPU/TPU found, falling back to CPU.\n",
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
]
}
],
"source": [
"import pytest\n",
"from numpy.testing import assert_allclose\n",
"\n",
"import jax.numpy as np\n",
"import jax.random as random\n",
"from jax import jit, lax\n",
"from jax.scipy.special import expit\n",
"\n",
"import numpyro.distributions as dist\n",
"from numpyro.distributions.util import validation_disabled\n",
"from numpyro.handlers import sample\n",
"from numpyro.hmc_util import initialize_model\n",
"from numpyro.mcmc import hmc_kernel\n",
"from numpyro.util import tscan"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"N, dim = 3000, 3\n",
"warmup_steps, num_samples = 1000, 8000\n",
"data = random.normal(random.PRNGKey(0), (N, dim))\n",
"true_coefs = np.arange(1., dim + 1.)\n",
"probs = expit(np.sum(true_coefs * data, axis=-1))\n",
"labels = dist.bernoulli(probs).rvs(random_state=random.PRNGKey(0))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 8.62 s, sys: 35.9 ms, total: 8.66 s\n",
"Wall time: 8.65 s\n"
]
}
],
"source": [
"%%time\n",
"with validation_disabled():\n",
" algo = \"HMC\"\n",
" def model(labels):\n",
" coefs = sample('coefs', dist.norm(np.zeros(dim), np.ones(dim)))\n",
" logits = np.sum(coefs * data, axis=-1)\n",
" return sample('obs', dist.bernoulli(logits, is_logits=True), obs=labels)\n",
"\n",
" init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n",
" init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=algo)\n",
" hmc_state = init_kernel(init_params,\n",
" step_size=0.1,\n",
" num_steps=15,\n",
" num_warmup_steps=warmup_steps)\n",
" sample_kernel = jit(sample_kernel)\n",
" hmc_states = tscan(lambda state, i: sample_kernel(state),\n",
" hmc_state, np.arange(num_samples))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 7.66 s, sys: 12 ms, total: 7.68 s\n",
"Wall time: 7.67 s\n"
]
}
],
"source": [
"%%time\n",
"with validation_disabled():\n",
" algo = \"HMC\"\n",
" def model(labels):\n",
" coefs = sample('coefs', dist.norm(np.zeros(dim), np.ones(dim)))\n",
" logits = np.sum(coefs * data, axis=-1)\n",
" return sample('obs', dist.bernoulli(logits, is_logits=True), obs=labels)\n",
"\n",
" init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n",
" init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=algo)\n",
" hmc_state = init_kernel(init_params,\n",
" step_size=0.1,\n",
" num_steps=15,\n",
" num_warmup_steps=warmup_steps)\n",
" sample_kernel = jit(sample_kernel)\n",
" hmc_states = lax.scan(lambda state, i: sample_kernel(state),\n",
" hmc_state, np.arange(num_samples))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 22.7 s, sys: 71.9 ms, total: 22.7 s\n",
"Wall time: 22.7 s\n"
]
}
],
"source": [
"%%time\n",
"with validation_disabled():\n",
" algo = \"NUTS\"\n",
" def model(labels):\n",
" coefs = sample('coefs', dist.norm(np.zeros(dim), np.ones(dim)))\n",
" logits = np.sum(coefs * data, axis=-1)\n",
" return sample('obs', dist.bernoulli(logits, is_logits=True), obs=labels)\n",
"\n",
" init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n",
" init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=algo)\n",
" hmc_state = init_kernel(init_params,\n",
" step_size=0.1,\n",
" num_steps=15,\n",
" num_warmup_steps=warmup_steps)\n",
" sample_kernel = jit(sample_kernel)\n",
" hmc_states = lax.scan(lambda state, i: sample_kernel(state),\n",
" hmc_state, np.arange(num_samples))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 21.8 s, sys: 48 ms, total: 21.9 s\n",
"Wall time: 21.9 s\n"
]
}
],
"source": [
"%%time\n",
"with validation_disabled():\n",
" algo = \"NUTS\"\n",
" def model(labels):\n",
" coefs = sample('coefs', dist.norm(np.zeros(dim), np.ones(dim)))\n",
" logits = np.sum(coefs * data, axis=-1)\n",
" return sample('obs', dist.bernoulli(logits, is_logits=True), obs=labels)\n",
"\n",
" init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n",
" init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=algo)\n",
" hmc_state = init_kernel(init_params,\n",
" step_size=0.1,\n",
" num_steps=15,\n",
" num_warmup_steps=warmup_steps)\n",
" sample_kernel = jit(sample_kernel)\n",
" hmc_states = tscan(lambda state, i: sample_kernel(state),\n",
" hmc_state, np.arange(num_samples))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 7.68 s, sys: 16 ms, total: 7.7 s\n",
"Wall time: 7.69 s\n"
]
}
],
"source": [
"%%time\n",
"with validation_disabled():\n",
" algo = \"HMC\"\n",
" def model(labels):\n",
" coefs = sample('coefs', dist.norm(np.zeros(dim), np.ones(dim)))\n",
" logits = np.sum(coefs * data, axis=-1)\n",
" return sample('obs', dist.bernoulli(logits, is_logits=True), obs=labels)\n",
"\n",
" init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n",
" init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=algo)\n",
" hmc_state = init_kernel(init_params,\n",
" step_size=0.1,\n",
" num_steps=15,\n",
" num_warmup_steps=warmup_steps)\n",
" sample_kernel = jit(sample_kernel)\n",
" hmc_states = tscan(lambda state, i: sample_kernel(state),\n",
" hmc_state, np.arange(num_samples))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 21.9 s, sys: 36 ms, total: 22 s\n",
"Wall time: 21.9 s\n"
]
}
],
"source": [
"%%time\n",
"with validation_disabled():\n",
" algo = \"NUTS\"\n",
" def model(labels):\n",
" coefs = sample('coefs', dist.norm(np.zeros(dim), np.ones(dim)))\n",
" logits = np.sum(coefs * data, axis=-1)\n",
" return sample('obs', dist.bernoulli(logits, is_logits=True), obs=labels)\n",
"\n",
" init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n",
" init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=algo)\n",
" hmc_state = init_kernel(init_params,\n",
" step_size=0.1,\n",
" num_steps=15,\n",
" num_warmup_steps=warmup_steps)\n",
" sample_kernel = jit(sample_kernel)\n",
" hmc_states = tscan(lambda state, i: sample_kernel(state),\n",
" hmc_state, np.arange(num_samples))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 22.6 s, sys: 28 ms, total: 22.6 s\n",
"Wall time: 22.6 s\n"
]
}
],
"source": [
"%%time\n",
"with validation_disabled():\n",
" algo = \"NUTS\"\n",
" def model(labels):\n",
" coefs = sample('coefs', dist.norm(np.zeros(dim), np.ones(dim)))\n",
" logits = np.sum(coefs * data, axis=-1)\n",
" return sample('obs', dist.bernoulli(logits, is_logits=True), obs=labels)\n",
"\n",
" init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n",
" init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=algo)\n",
" hmc_state = init_kernel(init_params,\n",
" step_size=0.1,\n",
" num_steps=15,\n",
" num_warmup_steps=warmup_steps)\n",
" sample_kernel = jit(sample_kernel)\n",
" hmc_states = lax.scan(lambda state, i: sample_kernel(state),\n",
" hmc_state, np.arange(num_samples))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment