Last active
January 3, 2022 21:16
Star
You must be signed in to star a gist
A minimal example for computing log probability with pystan3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import stan\n", | |
"import numpy as np\n", | |
"import nest_asyncio\n", | |
"\n", | |
"nest_asyncio.apply()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"stan_code = \"\"\"\n", | |
"parameters {\n", | |
" real<lower=0> a;\n", | |
" matrix[3,4] B;\n", | |
"}\n", | |
"model {\n", | |
" a ~ normal(0,1);\n", | |
" for (n in 1:3) {\n", | |
" for (m in 1:4) {\n", | |
" B[n,m] ~ normal(0,2);\n", | |
" }\n", | |
" }\n", | |
"}\n", | |
"\"\"\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Building...\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"Building: found in cache, done." | |
] | |
} | |
], | |
"source": [ | |
"model = stan.build(stan_code)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Sampling: 0%\n", | |
"Sampling: 20% (1001/5005)\n", | |
"Sampling: 40% (2002/5005)\n", | |
"Sampling: 60% (3003/5005)\n", | |
"Sampling: 80% (4004/5005)\n", | |
"Sampling: 100% (5005/5005)\n", | |
"Sampling: 100% (5005/5005), done.\n", | |
"Messages received during sampling:\n", | |
" Gradient evaluation took 2.1e-05 seconds\n", | |
" 1000 transitions using 10 leapfrog steps per transition would take 0.21 seconds.\n", | |
" Adjust your expectations accordingly!\n", | |
" Gradient evaluation took 1.5e-05 seconds\n", | |
" 1000 transitions using 10 leapfrog steps per transition would take 0.15 seconds.\n", | |
" Adjust your expectations accordingly!\n", | |
" Gradient evaluation took 1.5e-05 seconds\n", | |
" 1000 transitions using 10 leapfrog steps per transition would take 0.15 seconds.\n", | |
" Adjust your expectations accordingly!\n", | |
" Gradient evaluation took 1.3e-05 seconds\n", | |
" 1000 transitions using 10 leapfrog steps per transition would take 0.13 seconds.\n", | |
" Adjust your expectations accordingly!\n", | |
" Gradient evaluation took 1.3e-05 seconds\n", | |
" 1000 transitions using 10 leapfrog steps per transition would take 0.13 seconds.\n", | |
" Adjust your expectations accordingly!\n" | |
] | |
} | |
], | |
"source": [ | |
"fit = model.sample(num_chains=5, num_samples=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# pull out lp__ and copy the fit's (name, array) items into a plain dict\n", | |
"lp = fit.get(\"lp__\")\n", | |
"sample = dict(fit.items())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[-8.2162622 , -4.09946552, -8.44246307, -5.32088487, -5.51490078]])" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"lp" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"a (1, 5)\n", | |
"B (3, 4, 5)\n" | |
] | |
} | |
], | |
"source": [ | |
"for key, values in sample.items():\n", | |
" print(key, values.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'a': array([1.46704541]),\n", | |
" 'B': array([[ 0.3010759 , 2.62175333, 1.5205309 , 2.14743349],\n", | |
" [-2.51929445, -4.67990587, 1.13856508, -0.00859784],\n", | |
" [ 2.09569678, -0.12189783, 3.51391871, -0.0171373 ]])}" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# take the first draw: the last axis of each array indexes the draws\n", | |
"example_dict = {key: values[...,0] for key, values in sample.items()}\n", | |
"example_dict" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"a (1,)\n", | |
"B (3, 4)\n" | |
] | |
} | |
], | |
"source": [ | |
"for key, values in example_dict.items():\n", | |
" print(key, values.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"([], [3, 4])" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"fit.dims" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Convert each NumPy array to nested Python lists before the dict is passed to\n", | |
"# unconstrain_pars. Building a new dict through np.asarray keeps this cell safe to\n", | |
"# re-run (a list has no .tolist()) and avoids assigning into the dict while iterating it.\n", | |
"example_dict = {\n", | |
"    key: np.asarray(values).tolist() for key, values in example_dict.items()\n", | |
"}\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[0.38325045034117117,\n", | |
" 0.30107590208131263,\n", | |
" -2.5192944523230185,\n", | |
" 2.095696783595601,\n", | |
" 2.621753325495818,\n", | |
" -4.679905867389392,\n", | |
" -0.12189783314871028,\n", | |
" 1.5205308989488182,\n", | |
" 1.138565079115105,\n", | |
" 3.513918714001353,\n", | |
" 2.1474334859675004,\n", | |
" -0.00859783914498971,\n", | |
" -0.01713729830761479]" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"unconstrained = model.unconstrain_pars(example_dict)\n", | |
"unconstrained" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"('a',\n", | |
" 'B.1.1',\n", | |
" 'B.2.1',\n", | |
" 'B.3.1',\n", | |
" 'B.1.2',\n", | |
" 'B.2.2',\n", | |
" 'B.3.2',\n", | |
" 'B.1.3',\n", | |
" 'B.2.3',\n", | |
" 'B.3.3',\n", | |
" 'B.1.4',\n", | |
" 'B.2.4',\n", | |
" 'B.3.4')" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.constrained_param_names" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"-8.216262201831006" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# adjust_transform defaults to True (Jacobian adjustment applied), which reproduces the sampled lp__\n", | |
"model.log_prob(unconstrained)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"-8.599512652172177" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.log_prob(unconstrained, adjust_transform=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"-8.216262201831006" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"lp[0,0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# compare with a floating-point tolerance instead of exact equality;\n", | |
"# bool() keeps the displayed result a plain True/False\n", | |
"bool(np.isclose(model.log_prob(unconstrained), lp[0, 0]))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[-1.1522222235128177,\n", | |
" -0.07526897552032816,\n", | |
" 0.6298236130807546,\n", | |
" -0.5239241958989003,\n", | |
" -0.6554383313739545,\n", | |
" 1.169976466847348,\n", | |
" 0.03047445828717757,\n", | |
" -0.38013272473720455,\n", | |
" -0.28464126977877624,\n", | |
" -0.8784796785003383,\n", | |
" -0.5368583714918751,\n", | |
" 0.0021494597862474277,\n", | |
" 0.0042843245769036975]" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.grad_log_prob(unconstrained)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Compare with the [pystan2 example](https://gist.github.com/ahartikainen/8713171d259718cf737d8a483500e0c2) by manually constructing an example dict" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# use the same parameter values as the linked pystan2 example\n", | |
"example_dict = {'a': 0.38558730390174756,\n", | |
" 'B': [[ 0.31295903, 1.59176432, 0.25751408, -0.54202137],\n", | |
" [-0.28700276, 0.00954776, 0.80052233, 0.21879722],\n", | |
" [ 0.07881552, -0.20168992, 1.10617236, 2.33979838]]}\n", | |
"\n", | |
"unconstrained = model.unconstrain_pars(example_dict)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"-2.340837939218495" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.log_prob(unconstrained)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[0.8513224310697813,\n", | |
" -0.0782397575,\n", | |
" 0.07175069,\n", | |
" -0.01970388,\n", | |
" -0.39794108,\n", | |
" -0.00238694,\n", | |
" 0.05042248,\n", | |
" -0.06437852,\n", | |
" -0.2001305825,\n", | |
" -0.27654309,\n", | |
" 0.1355053425,\n", | |
" -0.054699305,\n", | |
" -0.584949595]" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# corresponds to adjust_transform=True\n", | |
"model.grad_log_prob(unconstrained)" | |
] | |
} | |
], | |
"metadata": { | |
"interpreter": { | |
"hash": "b039f26390276cbd09cb06142d71cb8300fdbbe84edb1a3cb7e98842b5a1c229" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3.8.11 64-bit ('py38_rosetta': conda)", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.11" | |
}, | |
"orig_nbformat": 4 | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment