Skip to content

Instantly share code, notes, and snippets.

@pipme
Last active January 3, 2022 21:16
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save pipme/ad9385954b774c6c016fe0d4d798047e to your computer and use it in GitHub Desktop.
A minimal example of computing the log probability and its gradient with PyStan 3
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import stan\n",
"import numpy as np\n",
"import nest_asyncio\n",
"\n",
"nest_asyncio.apply()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"stan_code = \"\"\"\n",
"parameters {\n",
" real<lower=0> a;\n",
" matrix[3,4] B;\n",
"}\n",
"model {\n",
" a ~ normal(0,1);\n",
" for (n in 1:3) {\n",
" for (m in 1:4) {\n",
" B[n,m] ~ normal(0,2);\n",
" }\n",
" }\n",
"}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Building...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"Building: found in cache, done."
]
}
],
"source": [
"model = stan.build(stan_code)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling: 0%\n",
"Sampling: 20% (1001/5005)\n",
"Sampling: 40% (2002/5005)\n",
"Sampling: 60% (3003/5005)\n",
"Sampling: 80% (4004/5005)\n",
"Sampling: 100% (5005/5005)\n",
"Sampling: 100% (5005/5005), done.\n",
"Messages received during sampling:\n",
" Gradient evaluation took 2.1e-05 seconds\n",
" 1000 transitions using 10 leapfrog steps per transition would take 0.21 seconds.\n",
" Adjust your expectations accordingly!\n",
" Gradient evaluation took 1.5e-05 seconds\n",
" 1000 transitions using 10 leapfrog steps per transition would take 0.15 seconds.\n",
" Adjust your expectations accordingly!\n",
" Gradient evaluation took 1.5e-05 seconds\n",
" 1000 transitions using 10 leapfrog steps per transition would take 0.15 seconds.\n",
" Adjust your expectations accordingly!\n",
" Gradient evaluation took 1.3e-05 seconds\n",
" 1000 transitions using 10 leapfrog steps per transition would take 0.13 seconds.\n",
" Adjust your expectations accordingly!\n",
" Gradient evaluation took 1.3e-05 seconds\n",
" 1000 transitions using 10 leapfrog steps per transition would take 0.13 seconds.\n",
" Adjust your expectations accordingly!\n"
]
}
],
"source": [
"fit = model.sample(num_chains=5, num_samples=1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"sample = {key: values for key, values in fit.items()}\n",
"lp = fit.get(\"lp__\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-8.2162622 , -4.09946552, -8.44246307, -5.32088487, -5.51490078]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lp"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a (1, 5)\n",
"B (3, 4, 5)\n"
]
}
],
"source": [
"for key, values in sample.items():\n",
" print(key, values.shape)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'a': array([1.46704541]),\n",
" 'B': array([[ 0.3010759 , 2.62175333, 1.5205309 , 2.14743349],\n",
" [-2.51929445, -4.67990587, 1.13856508, -0.00859784],\n",
" [ 2.09569678, -0.12189783, 3.51391871, -0.0171373 ]])}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# get one draw\n",
"example_dict = {key: values[...,0] for key, values in sample.items()}\n",
"example_dict"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a (1,)\n",
"B (3, 4)\n"
]
}
],
"source": [
"for key, values in example_dict.items():\n",
" print(key, values.shape)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([], [3, 4])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fit.dims"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"for key, values in example_dict.items():\n",
" example_dict[key] = values.tolist()\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.38325045034117117,\n",
" 0.30107590208131263,\n",
" -2.5192944523230185,\n",
" 2.095696783595601,\n",
" 2.621753325495818,\n",
" -4.679905867389392,\n",
" -0.12189783314871028,\n",
" 1.5205308989488182,\n",
" 1.138565079115105,\n",
" 3.513918714001353,\n",
" 2.1474334859675004,\n",
" -0.00859783914498971,\n",
" -0.01713729830761479]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"unconstrained = model.unconstrain_pars(example_dict)\n",
"unconstrained"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('a',\n",
" 'B.1.1',\n",
" 'B.2.1',\n",
" 'B.3.1',\n",
" 'B.1.2',\n",
" 'B.2.2',\n",
" 'B.3.2',\n",
" 'B.1.3',\n",
" 'B.2.3',\n",
" 'B.3.3',\n",
" 'B.1.4',\n",
" 'B.2.4',\n",
" 'B.3.4')"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.constrained_param_names"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-8.216262201831006"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# default is adjust_transform=True\n",
"model.log_prob(unconstrained)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-8.599512652172177"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.log_prob(unconstrained, adjust_transform=False)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-8.216262201831006"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lp[0,0]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.log_prob(unconstrained) == lp[0,0]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[-1.1522222235128177,\n",
" -0.07526897552032816,\n",
" 0.6298236130807546,\n",
" -0.5239241958989003,\n",
" -0.6554383313739545,\n",
" 1.169976466847348,\n",
" 0.03047445828717757,\n",
" -0.38013272473720455,\n",
" -0.28464126977877624,\n",
" -0.8784796785003383,\n",
" -0.5368583714918751,\n",
" 0.0021494597862474277,\n",
" 0.0042843245769036975]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.grad_log_prob(unconstrained)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compare with the [PyStan 2 example](https://gist.github.com/ahartikainen/8713171d259718cf737d8a483500e0c2) by manually constructing an example dict"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# set the same values\n",
"example_dict = {'a': 0.38558730390174756,\n",
" 'B': [[ 0.31295903, 1.59176432, 0.25751408, -0.54202137],\n",
" [-0.28700276, 0.00954776, 0.80052233, 0.21879722],\n",
" [ 0.07881552, -0.20168992, 1.10617236, 2.33979838]]}\n",
"\n",
"unconstrained = model.unconstrain_pars(example_dict)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-2.340837939218495"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.log_prob(unconstrained)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.8513224310697813,\n",
" -0.0782397575,\n",
" 0.07175069,\n",
" -0.01970388,\n",
" -0.39794108,\n",
" -0.00238694,\n",
" 0.05042248,\n",
" -0.06437852,\n",
" -0.2001305825,\n",
" -0.27654309,\n",
" 0.1355053425,\n",
" -0.054699305,\n",
" -0.584949595]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# corresponds to adjust_transform=True (the default)\n",
"model.grad_log_prob(unconstrained)"
]
}
],
"metadata": {
"interpreter": {
"hash": "b039f26390276cbd09cb06142d71cb8300fdbbe84edb1a3cb7e98842b5a1c229"
},
"kernelspec": {
"display_name": "Python 3.8.11 64-bit ('py38_rosetta': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment