Skip to content

Instantly share code, notes, and snippets.

@ahartikainen
Last active October 16, 2020 19:28
Show Gist options
  • Save ahartikainen/8713171d259718cf737d8a483500e0c2 to your computer and use it in GitHub Desktop.
Save ahartikainen/8713171d259718cf737d8a483500e0c2 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluating `log_prob` and `grad_log_prob` with PyStan 2\n",
"\n",
"Draw a handful of samples from a tiny model, map one draw back to the unconstrained space with `fit.unconstrain_pars`, and check that `fit.log_prob` reproduces the sampler's `lp__` for that draw. Finally, compare `fit.grad_log_prob` with and without the Jacobian adjustment."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pystan\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# a is constrained to be positive (lower=0); B is an unconstrained 3x4 matrix\n",
"# whose entries each get an independent normal(0, 2) prior.\n",
"stan_code = \"\"\"\n",
"parameters {\n",
" real<lower=0> a;\n",
" matrix[3,4] B;\n",
"}\n",
"model {\n",
" a ~ normal(0,1);\n",
" for (n in 1:3) {\n",
" for (m in 1:4) {\n",
" B[n,m] ~ normal(0,2);\n",
" }\n",
" }\n",
"}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_a71ba528c20fc622bc4c49e3064eafab NOW.\n"
]
}
],
"source": [
"model = pystan.StanModel(model_code=stan_code)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# We only need a few points to evaluate log_prob at, not a converged posterior:\n",
"# one iteration, no warmup and no adaptation, so each chain starts from zero on the\n",
"# unconstrained scale (init=0) and saves a single draw. seed=1 makes the draws\n",
"# reproducible; HMC diagnostics are skipped because one draw per chain is not informative.\n",
"fit = model.sampling(iter=1, warmup=0, init=0, seed=1, control={\"adapt_engaged\": False}, check_hmc_diagnostics=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Extract the draws\n",
"\n",
"The raw `extract` result is kept under its own name (`raw_sample`), so re-running the `lp__` cell after the filtering cell still works."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# raw extract result, including sampler quantities such as lp__\n",
"raw_sample = fit.extract(permuted=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"lp = raw_sample[\"lp__\"]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# model parameters only: drop every name ending in \"__\" (e.g. lp__).\n",
"# Building a new dict from raw_sample (instead of overwriting it) keeps this cell\n",
"# and the one above safe to re-run.\n",
"sample = {key: values for key, values in raw_sample.items() if not key.endswith(\"__\")}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'a': array([0.3855873 , 0.98054701, 0.49286373, 1.11726411]),\n",
" 'B': array([[[ 0.31295903, 1.59176432, 0.25751408, -0.54202137],\n",
" [-0.28700276, 0.00954776, 0.80052233, 0.21879722],\n",
" [ 0.07881552, -0.20168992, 1.10617236, 2.33979838]],\n",
" \n",
" [[ 4.1422781 , 3.68825208, -0.17252071, -1.68395211],\n",
" [-2.86501378, 0.33404798, -1.44207571, -1.80443805],\n",
" [-1.10181232, 0.2255809 , -0.37776586, 0.5961288 ]],\n",
" \n",
" [[-0.22384381, 0.3285261 , 0.27917157, 2.15246902],\n",
" [ 1.5110607 , -1.25719988, 0.80260026, -0.40612884],\n",
" [ 1.11420704, 0.18069843, -1.24951964, -0.30942338]],\n",
" \n",
" [[ 0.27401254, -1.66957992, -0.26767151, -0.44180366],\n",
" [ 0.25835625, -0.16191208, 0.95659342, 1.78234786],\n",
" [ 0.26586376, 0.91323483, 0.18991228, 0.13852963]]])}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sample"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-2.34083794, -6.63107518, -2.38813117, -1.54752603])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lp"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a (4,)\n",
"B (4, 3, 4)\n"
]
}
],
"source": [
"# leading axis indexes the draws\n",
"for key, values in sample.items():\n",
"    print(key, values.shape)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'a': 0.38558730390174756,\n",
" 'B': array([[ 0.31295903, 1.59176432, 0.25751408, -0.54202137],\n",
" [-0.28700276, 0.00954776, 0.80052233, 0.21879722],\n",
" [ 0.07881552, -0.20168992, 1.10617236, 2.33979838]])}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# pick a single draw: index 0 along the leading (draw) axis of every parameter\n",
"example_dict = {name: draws[0] for name, draws in sample.items()}\n",
"example_dict"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a ()\n",
"B (3, 4)\n"
]
}
],
"source": [
"# per-draw shapes: the draw axis is gone\n",
"for name, value in example_dict.items():\n",
"    print(name, value.shape)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[], [3, 4]]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# declared dimensions of each parameter, matching the per-draw shapes above\n",
"fit.par_dims"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-0.95298764, 0.31295903, -0.28700276, 0.07881552, 1.59176432,\n",
" 0.00954776, -0.20168992, 0.25751408, 0.80052233, 1.10617236,\n",
" -0.54202137, 0.21879722, 2.33979838])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# flatten the draw onto the unconstrained scale: a becomes log(a),\n",
"# the entries of B are copied unchanged\n",
"unconstrained = fit.unconstrain_pars(example_dict)\n",
"unconstrained"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['a',\n",
" 'B.1.1',\n",
" 'B.2.1',\n",
" 'B.3.1',\n",
" 'B.1.2',\n",
" 'B.2.2',\n",
" 'B.3.2',\n",
" 'B.1.3',\n",
" 'B.2.3',\n",
" 'B.3.3',\n",
" 'B.1.4',\n",
" 'B.2.4',\n",
" 'B.3.4']"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# element order of the flat unconstrained vector (B in column-major order);\n",
"# unconstrain_pars already produces this order, so no manual reshaping is needed\n",
"fit.unconstrained_param_names()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Calculate log_prob\n",
"\n",
"Evaluated at the unconstrained version of a draw, `fit.log_prob` should reproduce the sampler's `lp__` for that same draw."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-2.3408379411644353"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fit.log_prob(unconstrained)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-2.3408379411644353"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lp[0]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# compare with a tolerance: exact == on floats breaks on harmless round-off\n",
"np.isclose(fit.log_prob(unconstrained), lp[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gradient of log_prob\n",
"\n",
"With `adjust_transform=False` the gradient omits the Jacobian term of the `a -> log(a)` transform; with `adjust_transform=True` it includes it."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-0.14867757, -0.07823976, 0.07175069, -0.01970388, -0.39794108,\n",
" -0.00238694, 0.05042248, -0.06437852, -0.20013058, -0.27654309,\n",
" 0.13550534, -0.0546993 , -0.5849496 ])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# without the Jacobian adjustment\n",
"fit.grad_log_prob(unconstrained, adjust_transform=False)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.85132243, -0.07823976, 0.07175069, -0.01970388, -0.39794108,\n",
" -0.00238694, 0.05042248, -0.06437852, -0.20013058, -0.27654309,\n",
" 0.13550534, -0.0546993 , -0.5849496 ])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# the Jacobian adjustment: with a = exp(u) the log-Jacobian is u, whose derivative\n",
"# adds 1 to the gradient component of a; B is unconstrained, so its components are unchanged\n",
"fit.grad_log_prob(unconstrained, adjust_transform=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment