Skip to content

Instantly share code, notes, and snippets.

@fehiepsi
Created July 4, 2019 04:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fehiepsi/3827fbaeafc25a920ee98337e0057888 to your computer and use it in GitHub Desktop.
Save fehiepsi/3827fbaeafc25a920ee98337e0057888 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import argparse\n",
"import logging\n",
"\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"J = 8\n",
"y = torch.tensor([28, 8, -3, 7, -1, 1, 18, 12]).type(torch.Tensor)\n",
"sigma = torch.tensor([15, 10, 16, 11, 9, 11, 10, 18]).type(torch.Tensor)\n",
"\n",
"import pyro\n",
"import pyro.distributions as dist\n",
"import pyro.poutine as poutine\n",
"from pyro.infer.mcmc import MCMC, NUTS\n",
"\n",
"logging.basicConfig(format='%(message)s', level=logging.INFO)\n",
"pyro.enable_validation(True)\n",
"pyro.set_rng_seed(0)\n",
"\n",
"\n",
"def model(sigma):\n",
" eta = pyro.sample('eta', dist.Normal(torch.zeros(J), torch.ones(J)))\n",
" mu = pyro.sample('mu', dist.Normal(torch.zeros(1), 10 * torch.ones(1)))\n",
" tau = pyro.sample('tau', dist.HalfCauchy(scale=25 * torch.ones(1)))\n",
"\n",
" theta = mu + tau * eta\n",
"\n",
" return pyro.sample(\"obs\", dist.Normal(theta, sigma))\n",
"\n",
"\n",
"def conditioned_model(model, sigma, y):\n",
" return poutine.condition(model, data={\"obs\": y})(sigma)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sample: 100%|██████████| 101/101 [00:00<00:00, 107.93it/s, step size=1.00e+00, acc. rate=0.000, diverging=5]\n",
" mean std 25% 50% 75%\n",
"mu -7.789111 0.0 -7.789111 -7.789111 -7.789111\n",
"tau 28.963778 0.0 28.963770 28.963770 28.963770\n",
"eta[0] 0.201818 0.0 0.201818 0.201818 0.201818\n",
"eta[1] -0.178471 0.0 -0.178471 -0.178471 -0.178471\n",
"eta[2] -0.088910 0.0 -0.088910 -0.088910 -0.088910\n",
"eta[3] -0.414864 0.0 -0.414864 -0.414864 -0.414864\n",
"eta[4] -1.128324 0.0 -1.128324 -1.128324 -1.128324\n",
"eta[5] 0.959426 0.0 0.959426 0.959426 0.959426\n",
"eta[6] 1.358734 0.0 1.358734 1.358734 1.358734\n",
"eta[7] -0.099521 0.0 -0.099521 -0.099521 -0.099521\n"
]
}
],
"source": [
"nuts_kernel = NUTS(conditioned_model, step_size=1, adapt_step_size=False, jit_compile=True)\n",
"posterior = MCMC(nuts_kernel,\n",
" num_samples=100,\n",
" warmup_steps=1,\n",
" num_chains=1).run(model, sigma, y)\n",
"marginal = posterior.marginal(sites=[\"mu\", \"tau\", \"eta\"])\n",
"marginal = torch.cat(list(marginal.support(flatten=True).values()), dim=-1).cpu().numpy()\n",
"params = ['mu', 'tau', 'eta[0]', 'eta[1]', 'eta[2]', 'eta[3]', 'eta[4]', 'eta[5]', 'eta[6]', 'eta[7]']\n",
"df = pd.DataFrame(marginal, columns=params).transpose()\n",
"df_summary = df.apply(pd.Series.describe, axis=1)[[\"mean\", \"std\", \"25%\", \"50%\", \"75%\"]]\n",
"logging.info(df_summary)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sample: 100%|██████████| 101/101 [00:05<00:00, 18.74it/s, step size=1.00e-01, acc. rate=1.000]\n",
" mean std 25% 50% 75%\n",
"mu 1.748928 3.651412 -0.899974 1.331997 3.412311\n",
"tau 6.333198 4.236098 3.254654 5.521552 8.114382\n",
"eta[0] 0.676266 1.137671 -0.142722 0.653935 1.477357\n",
"eta[1] 0.111057 0.831101 -0.492473 0.179173 0.735119\n",
"eta[2] -0.210698 0.985626 -1.054112 -0.083195 0.460420\n",
"eta[3] 0.287343 0.895648 -0.332588 0.351833 0.928157\n",
"eta[4] -0.185150 0.810296 -0.813105 -0.192511 0.364677\n",
"eta[5] -0.048596 0.983409 -0.681045 -0.099627 0.433513\n",
"eta[6] 0.603709 0.834643 0.053813 0.538607 1.272161\n",
"eta[7] 0.308368 0.841326 -0.170804 0.391864 0.859839\n"
]
}
],
"source": [
"nuts_kernel = NUTS(conditioned_model, step_size=0.1, adapt_step_size=False, jit_compile=True)\n",
"posterior = MCMC(nuts_kernel,\n",
" num_samples=100,\n",
" warmup_steps=1,\n",
" num_chains=1).run(model, sigma, y)\n",
"marginal = posterior.marginal(sites=[\"mu\", \"tau\", \"eta\"])\n",
"marginal = torch.cat(list(marginal.support(flatten=True).values()), dim=-1).cpu().numpy()\n",
"params = ['mu', 'tau', 'eta[0]', 'eta[1]', 'eta[2]', 'eta[3]', 'eta[4]', 'eta[5]', 'eta[6]', 'eta[7]']\n",
"df = pd.DataFrame(marginal, columns=params).transpose()\n",
"df_summary = df.apply(pd.Series.describe, axis=1)[[\"mean\", \"std\", \"25%\", \"50%\", \"75%\"]]\n",
"logging.info(df_summary)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment