Skip to content

Instantly share code, notes, and snippets.

@fehiepsi
Created October 9, 2019 02:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fehiepsi/f03eeaa85b7d00a4d6d8214fea96606d to your computer and use it in GitHub Desktop.
Save fehiepsi/f03eeaa85b7d00a4d6d8214fea96606d to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import pyro.distributions as dist\n",
"\n",
"# Seed torch's global RNG (which pyro.distributions samples from) so the synthetic\n",
"# dataset below is identical after a kernel restart.\n",
"SEED = 0\n",
"torch.manual_seed(SEED)\n",
"\n",
"NUM_WARMUP = 1000\n",
"NUM_SAMPLES = 1000\n",
"# NOTE(review): NUM_CHAINS is not used by the sampling cell, which runs chains=1.\n",
"NUM_CHAINS = 3\n",
"N = 2500  # number of observations\n",
"P = 8  # number of predictors\n",
"\n",
"# Ground-truth parameters that generate the data.\n",
"alpha_true = dist.Normal(42.0, 10.0).sample()\n",
"beta_true = dist.Normal(torch.zeros(P), 10.0).sample()\n",
"sigma_true = dist.Exponential(1.0).sample()\n",
"\n",
"# Synthetic linear-regression data: y = alpha + x @ beta + Normal(0, sigma) noise.\n",
"eps = dist.Normal(0.0, sigma_true).sample([N])\n",
"x = torch.randn(N, P)\n",
"y = alpha_true + x @ beta_true + eps\n",
"\n",
"# Stan version of the same linear model; the <lower = 0.0> bound makes sigma's\n",
"# normal(0, 10) prior a half-normal.\n",
"stan_code = \"\"\"\n",
"data {\n",
" int<lower = 0> N;\n",
" int<lower = 0> P;\n",
" matrix[N, P] x;\n",
" vector[N] y;\n",
"}\n",
"\n",
"parameters {\n",
" real alpha;\n",
" vector[P] beta;\n",
" real<lower = 0.0> sigma;\n",
"}\n",
"\n",
"model {\n",
" alpha ~ normal(0.0, 100.0);\n",
" beta ~ normal(0.0, 10.0);\n",
" sigma ~ normal(0.0, 10.0);\n",
" y ~ normal(alpha + x * beta, sigma);\n",
"}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_9668ba3a03de4e2c24a1b04fa6c99bd7 NOW.\n"
]
}
],
"source": [
"import pystan\n",
"\n",
"# Inputs for the Stan `data` block; the torch tensors are handed over as NumPy arrays.\n",
"stan_data = dict(N=N, P=P, x=x.numpy(), y=y.numpy())\n",
"\n",
"# Compile the Stan program once; the compiled model is reused by the sampling cell.\n",
"stan_model = pystan.StanModel(model_code=stan_code)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Run NUTS for NUM_WARMUP warmup + NUM_SAMPLES post-warmup iterations.\n",
"# A single chain is used; the summary cells below treat the extracted draws as one chain.\n",
"stan_fit = stan_model.sampling(\n",
"    data=stan_data,\n",
"    iter=NUM_WARMUP + NUM_SAMPLES,  # Stan's `iter` includes the warmup iterations\n",
"    warmup=NUM_WARMUP,\n",
"    chains=1,\n",
")\n",
"# NOTE(review): extract() defaults to permuted=True, which merges chains and shuffles the\n",
"# draws; autocorrelation-based diagnostics (n_eff / ESS, R-hat) computed on `samples`\n",
"# may therefore disagree with Stan's own summary -- confirm against the pystan docs.\n",
"samples = stan_fit.extract()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Inference for Stan model: anon_model_9668ba3a03de4e2c24a1b04fa6c99bd7.\n",
"1 chains, each with iter=2000; warmup=1000; thin=1; \n",
"post-warmup draws per chain=1000, total post-warmup draws=1000.\n",
"\n",
" mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat\n",
"alpha 37.16 2.9e-4 0.01 37.13 37.15 37.16 37.17 37.19 2480 1.0\n",
"beta[1] 3.12 2.4e-4 0.02 3.1 3.11 3.12 3.14 3.15 3903 1.0\n",
"beta[2] 12.61 3.3e-4 0.01 12.59 12.6 12.61 12.62 12.64 1687 1.0\n",
"beta[3] -7.06 3.2e-4 0.02 -7.09 -7.07 -7.06 -7.05 -7.03 2262 1.0\n",
"beta[4] 10.05 2.9e-4 0.01 10.02 10.04 10.05 10.06 10.07 2415 1.0\n",
"beta[5] 3.12 2.9e-4 0.01 3.1 3.11 3.12 3.13 3.15 2488 1.0\n",
"beta[6] -2.93 3.4e-4 0.01 -2.96 -2.94 -2.93 -2.92 -2.91 1818 1.0\n",
"beta[7] -10.83 3.4e-4 0.01 -10.86 -10.84 -10.83 -10.82 -10.8 1902 1.0\n",
"beta[8] 2.28 3.2e-4 0.02 2.25 2.27 2.28 2.29 2.31 2472 1.0\n",
"sigma 0.73 2.6e-4 0.01 0.71 0.72 0.73 0.74 0.76 1679 1.0\n",
"lp__ -472.1 0.11 2.23 -477.2 -473.4 -471.8 -470.5 -468.6 401 1.0\n",
"\n",
"Samples were drawn using NUTS at Tue Oct 8 22:01:11 2019.\n",
"For each parameter, n_eff is a crude measure of effective sample size,\n",
"and Rhat is the potential scale reduction factor on split chains (at \n",
"convergence, Rhat=1).\n"
]
}
],
"source": [
"# stansummary() produces a text table; print it so the line breaks render\n",
"# instead of an escaped repr.\n",
"print(\n",
"    stan_fit.stansummary()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
" mean std median 5.0% 95.0% n_eff r_hat\n",
" alpha 37.16 0.01 37.16 37.14 37.19 936.91 1.00\n",
" beta[0] 3.12 0.02 3.12 3.10 3.15 873.23 1.00\n",
" beta[1] 12.61 0.01 12.61 12.59 12.63 941.15 1.00\n",
" beta[2] -7.06 0.02 -7.06 -7.09 -7.04 1061.48 1.00\n",
" beta[3] 10.05 0.01 10.05 10.02 10.07 747.26 1.00\n",
" beta[4] 3.12 0.01 3.12 3.10 3.15 893.47 1.00\n",
" beta[5] -2.93 0.01 -2.93 -2.96 -2.91 904.31 1.00\n",
" beta[6] -10.83 0.01 -10.83 -10.85 -10.81 913.99 1.00\n",
" beta[7] 2.28 0.02 2.28 2.26 2.31 1117.69 1.00\n",
" lp__ -472.15 2.23 -471.86 -475.28 -468.57 895.91 1.00\n",
" sigma 0.73 0.01 0.73 0.71 0.75 1029.56 1.00\n",
"\n",
"\n"
]
}
],
"source": [
"from numpyro.diagnostics import summary\n",
"\n",
"# `samples` comes from a single chain, so group_by_chain=False tells NumPyro that\n",
"# the leading axis of each array indexes draws, not chains.\n",
"summary(\n",
"    dict(samples),\n",
"    group_by_chain=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:arviz.stats.stats_utils:Shape validation failed: input_shape: (1, 1000), minimum_shape: (chains=2, draws=4)\n",
"WARNING:arviz.stats.stats_utils:Shape validation failed: input_shape: (1, 1000), minimum_shape: (chains=2, draws=4)\n",
"WARNING:arviz.stats.stats_utils:Shape validation failed: input_shape: (1, 1000), minimum_shape: (chains=2, draws=4)\n",
"WARNING:arviz.stats.stats_utils:Shape validation failed: input_shape: (1, 1000), minimum_shape: (chains=2, draws=4)\n",
"WARNING:arviz.stats.stats_utils:Shape validation failed: input_shape: (1, 1000), minimum_shape: (chains=2, draws=4)\n",
"WARNING:arviz.stats.stats_utils:Shape validation failed: input_shape: (1, 1000), minimum_shape: (chains=2, draws=4)\n",
"WARNING:arviz.stats.stats_utils:Shape validation failed: input_shape: (1, 1000), minimum_shape: (chains=2, draws=4)\n",
"WARNING:arviz.stats.stats_utils:Shape validation failed: input_shape: (1, 1000), minimum_shape: (chains=2, draws=4)\n",
"WARNING:arviz.stats.stats_utils:Shape validation failed: input_shape: (1, 1000), minimum_shape: (chains=2, draws=4)\n",
"WARNING:arviz.stats.stats_utils:Shape validation failed: input_shape: (1, 1000), minimum_shape: (chains=2, draws=4)\n",
"WARNING:arviz.stats.stats_utils:Shape validation failed: input_shape: (1, 1000), minimum_shape: (chains=2, draws=4)\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean</th>\n",
" <th>sd</th>\n",
" <th>hpd_3%</th>\n",
" <th>hpd_97%</th>\n",
" <th>mcse_mean</th>\n",
" <th>mcse_sd</th>\n",
" <th>ess_mean</th>\n",
" <th>ess_sd</th>\n",
" <th>ess_bulk</th>\n",
" <th>ess_tail</th>\n",
" <th>r_hat</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>alpha</td>\n",
" <td>37.161</td>\n",
" <td>0.015</td>\n",
" <td>37.130</td>\n",
" <td>37.186</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>919.0</td>\n",
" <td>919.0</td>\n",
" <td>914.0</td>\n",
" <td>952.0</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <td>beta[0]</td>\n",
" <td>3.124</td>\n",
" <td>0.015</td>\n",
" <td>3.095</td>\n",
" <td>3.151</td>\n",
" <td>0.001</td>\n",
" <td>0.000</td>\n",
" <td>881.0</td>\n",
" <td>881.0</td>\n",
" <td>888.0</td>\n",
" <td>975.0</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <td>beta[1]</td>\n",
" <td>12.614</td>\n",
" <td>0.014</td>\n",
" <td>12.587</td>\n",
" <td>12.639</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>944.0</td>\n",
" <td>944.0</td>\n",
" <td>946.0</td>\n",
" <td>857.0</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <td>beta[2]</td>\n",
" <td>-7.063</td>\n",
" <td>0.015</td>\n",
" <td>-7.093</td>\n",
" <td>-7.035</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>1066.0</td>\n",
" <td>1066.0</td>\n",
" <td>1071.0</td>\n",
" <td>1026.0</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <td>beta[3]</td>\n",
" <td>10.046</td>\n",
" <td>0.014</td>\n",
" <td>10.020</td>\n",
" <td>10.072</td>\n",
" <td>0.001</td>\n",
" <td>0.000</td>\n",
" <td>756.0</td>\n",
" <td>756.0</td>\n",
" <td>762.0</td>\n",
" <td>843.0</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <td>beta[4]</td>\n",
" <td>3.124</td>\n",
" <td>0.014</td>\n",
" <td>3.099</td>\n",
" <td>3.151</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>893.0</td>\n",
" <td>893.0</td>\n",
" <td>895.0</td>\n",
" <td>944.0</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <td>beta[5]</td>\n",
" <td>-2.934</td>\n",
" <td>0.014</td>\n",
" <td>-2.964</td>\n",
" <td>-2.910</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>930.0</td>\n",
" <td>930.0</td>\n",
" <td>925.0</td>\n",
" <td>876.0</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <td>beta[6]</td>\n",
" <td>-10.829</td>\n",
" <td>0.015</td>\n",
" <td>-10.857</td>\n",
" <td>-10.803</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>911.0</td>\n",
" <td>911.0</td>\n",
" <td>904.0</td>\n",
" <td>889.0</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <td>beta[7]</td>\n",
" <td>2.281</td>\n",
" <td>0.016</td>\n",
" <td>2.251</td>\n",
" <td>2.309</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>1118.0</td>\n",
" <td>1118.0</td>\n",
" <td>1116.0</td>\n",
" <td>749.0</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <td>sigma</td>\n",
" <td>0.732</td>\n",
" <td>0.011</td>\n",
" <td>0.712</td>\n",
" <td>0.752</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>1040.0</td>\n",
" <td>1040.0</td>\n",
" <td>1029.0</td>\n",
" <td>1023.0</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <td>lp__</td>\n",
" <td>-472.146</td>\n",
" <td>2.227</td>\n",
" <td>-476.219</td>\n",
" <td>-468.398</td>\n",
" <td>0.074</td>\n",
" <td>0.052</td>\n",
" <td>914.0</td>\n",
" <td>914.0</td>\n",
" <td>900.0</td>\n",
" <td>914.0</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mean sd hpd_3% hpd_97% mcse_mean mcse_sd ess_mean \\\n",
"alpha 37.161 0.015 37.130 37.186 0.000 0.000 919.0 \n",
"beta[0] 3.124 0.015 3.095 3.151 0.001 0.000 881.0 \n",
"beta[1] 12.614 0.014 12.587 12.639 0.000 0.000 944.0 \n",
"beta[2] -7.063 0.015 -7.093 -7.035 0.000 0.000 1066.0 \n",
"beta[3] 10.046 0.014 10.020 10.072 0.001 0.000 756.0 \n",
"beta[4] 3.124 0.014 3.099 3.151 0.000 0.000 893.0 \n",
"beta[5] -2.934 0.014 -2.964 -2.910 0.000 0.000 930.0 \n",
"beta[6] -10.829 0.015 -10.857 -10.803 0.000 0.000 911.0 \n",
"beta[7] 2.281 0.016 2.251 2.309 0.000 0.000 1118.0 \n",
"sigma 0.732 0.011 0.712 0.752 0.000 0.000 1040.0 \n",
"lp__ -472.146 2.227 -476.219 -468.398 0.074 0.052 914.0 \n",
"\n",
" ess_sd ess_bulk ess_tail r_hat \n",
"alpha 919.0 914.0 952.0 NaN \n",
"beta[0] 881.0 888.0 975.0 NaN \n",
"beta[1] 944.0 946.0 857.0 NaN \n",
"beta[2] 1066.0 1071.0 1026.0 NaN \n",
"beta[3] 756.0 762.0 843.0 NaN \n",
"beta[4] 893.0 895.0 944.0 NaN \n",
"beta[5] 930.0 925.0 876.0 NaN \n",
"beta[6] 911.0 904.0 889.0 NaN \n",
"beta[7] 1118.0 1116.0 749.0 NaN \n",
"sigma 1040.0 1029.0 1023.0 NaN \n",
"lp__ 914.0 900.0 914.0 NaN "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import arviz as az\n",
"\n",
"# ArviZ reads arrays as (chain, draw, ...), so prepend a length-1 chain axis to each array.\n",
"# With only one chain, r_hat cannot be computed: ArviZ logs shape-validation warnings\n",
"# and reports NaN for it.\n",
"az.summary({name: draws[None, ...] for name, draws in samples.items()})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment