Skip to content

Instantly share code, notes, and snippets.

@fehiepsi
Created October 9, 2019 02:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fehiepsi/fe5fc248c3bfd0b1337b0b053f1245dc to your computer and use it in GitHub Desktop.
Save fehiepsi/fe5fc248c3bfd0b1337b0b053f1245dc to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Simulate a synthetic linear-regression dataset: y = alpha + x @ beta + eps.\n",
"N = 2500  # number of observations (rows of x)\n",
"P = 8  # number of predictors (columns of x)\n",
"\n",
"# Seed NumPy's global RNG so the dataset is reproducible. All draws below come\n",
"# from that one stream, so reordering these lines changes the simulated data.\n",
"np.random.seed(1)\n",
"# Ground-truth parameters used to generate y.\n",
"alpha_true = np.random.normal(42, 10)  # intercept ~ Normal(mean 42, sd 10)\n",
"beta_true = np.random.normal(0, 10, size=P)  # P slopes ~ Normal(0, 10)\n",
"sigma_true = np.random.exponential()  # noise sd ~ Exponential(scale=1)\n",
"\n",
"eps = np.random.normal(0, sigma_true, size=N)  # Gaussian observation noise, shape (N,)\n",
"x = np.random.normal(size=(N, P))  # standard-normal design matrix, shape (N, P)\n",
"y = alpha_true + x @ beta_true + eps"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Stan program: Bayesian linear regression y ~ normal(alpha + x * beta, sigma).\n",
"# Priors: alpha ~ normal(0, 100), beta ~ normal(0, 10), and sigma ~ normal(0, 10)\n",
"# restricted to sigma > 0 by its <lower = 0.0> constraint (a half-normal prior).\n",
"# The `data` block names (N, P, x, y) are the keys that must be supplied as data.\n",
"stan_code = \"\"\"\n",
"data {\n",
" int<lower = 0> N;\n",
" int<lower = 0> P;\n",
" matrix[N, P] x;\n",
" vector[N] y;\n",
"}\n",
"\n",
"parameters {\n",
" real alpha;\n",
" vector[P] beta;\n",
" real<lower = 0.0> sigma;\n",
"}\n",
"\n",
"model {\n",
" alpha ~ normal(0.0, 100.0);\n",
" beta ~ normal(0.0, 10.0);\n",
" sigma ~ normal(0.0, 10.0);\n",
" y ~ normal(alpha + x * beta, sigma);\n",
"}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_9668ba3a03de4e2c24a1b04fa6c99bd7 NOW.\n"
]
}
],
"source": [
"import pystan\n",
"\n",
"# Values for the Stan `data` block; the keys must match its declarations.\n",
"stan_data = dict(N=N, P=P, x=x, y=y)\n",
"\n",
"# Translate the Stan program to C++ and compile it. This is the slow step,\n",
"# so re-run this cell only when stan_code changes.\n",
"stan_model = pystan.StanModel(model_code=stan_code)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Inference for Stan model: anon_model_9668ba3a03de4e2c24a1b04fa6c99bd7.\n",
"1 chains, each with iter=2000; warmup=1000; thin=1; \n",
"post-warmup draws per chain=1000, total post-warmup draws=1000.\n",
"\n",
" mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat\n",
"alpha 58.24 1.2e-5 5.6e-4 58.24 58.24 58.24 58.24 58.25 2022 1.0\n",
"beta[1] -6.12 1.2e-5 5.4e-4 -6.12 -6.12 -6.12 -6.12 -6.12 1931 1.0\n",
"beta[2] -5.28 1.2e-5 5.5e-4 -5.28 -5.28 -5.28 -5.28 -5.28 2026 1.0\n",
"beta[3] -10.73 1.2e-5 5.6e-4 -10.73 -10.73 -10.73 -10.73 -10.73 2309 1.0\n",
"beta[4] 8.65 1.2e-5 5.4e-4 8.65 8.65 8.65 8.65 8.66 1978 1.0\n",
"beta[5] -23.02 1.4e-5 5.4e-4 -23.02 -23.02 -23.02 -23.02 -23.01 1512 1.0\n",
"beta[6] 17.45 9.6e-6 5.3e-4 17.45 17.45 17.45 17.45 17.45 3047 1.0\n",
"beta[7] -7.61 1.2e-5 5.6e-4 -7.61 -7.61 -7.61 -7.61 -7.61 2099 1.0\n",
"beta[8] 3.19 1.0e-5 5.6e-4 3.19 3.19 3.19 3.19 3.19 3027 1.0\n",
"sigma 0.03 3.7e-5 4.2e-4 0.03 0.03 0.03 0.03 0.03 127 1.0\n",
"lp__ 7742.9 0.12 2.27 7737.8 7741.6 7743.1 7744.6 7746.4 362 1.0\n",
"\n",
"Samples were drawn using NUTS at Tue Oct 8 22:41:21 2019.\n",
"For each parameter, n_eff is a crude measure of effective sample size,\n",
"and Rhat is the potential scale reduction factor on split chains (at \n",
"convergence, Rhat=1).\n"
]
}
],
"source": [
"# Sample with NUTS: a single chain of 2000 iterations, the first 1000 of which\n",
"# are warmup and are not kept as draws.\n",
"stan_fit = stan_model.sampling(\n",
"    data=stan_data,\n",
"    iter=2000,\n",
"    warmup=1000,\n",
"    chains=1,\n",
"    seed=4,\n",
")\n",
"\n",
"# Post-warmup draws as a dict keyed by parameter name. permuted=True is PyStan's\n",
"# default: the draws are shuffled (and chains merged), so their sampling order and\n",
"# autocorrelation are NOT preserved. Do not use them for ESS/R-hat diagnostics.\n",
"samples = stan_fit.extract(permuted=True)\n",
"\n",
"print(stan_fit.stansummary())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import arviz as az\n",
"\n",
"# Convert the fit with ArviZ's PyStan converter instead of wrapping `samples`.\n",
"# `stan_fit.extract()` permutes the draws by default, which destroys their\n",
"# autocorrelation and inflates ESS / shrinks MCSE (compare with the n_eff column\n",
"# of stansummary above). The converter keeps the draws in (chain, draw) order.\n",
"# With only one chain, ArviZ cannot compute R-hat and reports NaN.\n",
"az.summary(az.from_pystan(posterior=stan_fit))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment