Export PyStan fit object to xarray
@ahartikainen, created May 29, 2018 14:03

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"import pystan\n",
"import numpy as np\n",
"import pandas as pd\n",
"import re\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'2.17.1.0'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pystan.__version__"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"schools_code = \"\"\"\n",
"data {\n",
" int<lower=0> J;\n",
" real y[J];\n",
" real<lower=0> sigma[J];\n",
"}\n",
"\n",
"parameters {\n",
" real mu;\n",
" real<lower=0> tau;\n",
" real theta_tilde[J];\n",
"}\n",
"\n",
"transformed parameters {\n",
" real theta[J];\n",
" for (j in 1:J)\n",
" theta[j] = mu + tau * theta_tilde[j];\n",
"}\n",
"\n",
"model {\n",
" mu ~ normal(0, 5);\n",
" tau ~ cauchy(0, 5);\n",
" theta_tilde ~ normal(0, 1);\n",
" y ~ normal(theta, sigma);\n",
"}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Data of the Eight Schools Model\n",
"schools_dat = dict(\n",
" J = 8,\n",
" y = np.array([28., 8., -3., 7., -1., 1., 18., 12.]),\n",
" sigma = np.array([15., 10., 16., 11., 9., 11., 10., 18.]),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 40 ms\n"
]
}
],
"source": [
"%%time\n",
"path = \"./eight_schools_nc_fit.pickle\"\n",
"if os.path.exists(path):\n",
" with open(path, \"rb\") as f:\n",
" model, fit = pickle.load(f)\n",
"else:\n",
" model = pystan.StanModel(model_code=schools_code)\n",
" fit = model.sampling(data=schools_dat, iter=1000, chains=1)\n",
" with open(path, \"wb\") as f:\n",
" pickle.dump((model, fit), f, protocol=pickle.HIGHEST_PROTOCOL)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Inference for Stan model: anon_model_3f8f9e8bb354ab461436bb51d935571d.\n",
"4 chains, each with iter=1000; warmup=500; thin=1; \n",
"post-warmup draws per chain=500, total post-warmup draws=2000.\n",
"\n",
" mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat\n",
"mu 4.32 0.07 3.3 -2.13 2.09 4.32 6.43 10.92 2227 1.0\n",
"tau 3.6 0.09 3.15 0.13 1.29 2.81 5.05 11.69 1139 1.0\n",
"theta_tilde[1] 0.32 0.02 0.99 -1.73 -0.34 0.32 1.04 2.18 1865 1.0\n",
"theta_tilde[2] 0.1 0.02 0.91 -1.67 -0.49 0.12 0.7 1.89 2326 1.0\n",
"theta_tilde[3] -0.06 0.02 0.95 -1.96 -0.72 -0.06 0.59 1.84 2522 1.0\n",
"theta_tilde[4] 0.07 0.02 0.93 -1.7 -0.55 0.06 0.72 1.88 1720 1.0\n",
"theta_tilde[5] -0.17 0.02 0.93 -2.02 -0.77 -0.18 0.41 1.66 2125 1.0\n",
"theta_tilde[6] -0.1 0.02 0.95 -1.97 -0.76 -0.1 0.54 1.72 2489 1.0\n",
"theta_tilde[7] 0.35 0.02 0.96 -1.5 -0.3 0.35 1.0 2.2 1741 1.0\n",
"theta_tilde[8] 0.07 0.02 0.97 -1.83 -0.6 0.09 0.71 1.94 2447 1.0\n",
"theta[1] 6.18 0.15 5.61 -3.12 2.53 5.56 9.02 18.92 1430 1.0\n",
"theta[2] 4.78 0.1 4.64 -4.06 2.0 4.76 7.46 14.2 2019 1.0\n",
"theta[3] 3.83 0.12 5.27 -7.82 1.16 4.09 7.0 13.6 1998 1.0\n",
"theta[4] 4.72 0.11 4.81 -4.65 1.73 4.44 7.73 14.35 1772 1.0\n",
"theta[5] 3.47 0.1 4.78 -6.8 0.91 3.78 6.4 12.2 2288 1.0\n",
"theta[6] 3.8 0.1 4.81 -6.68 0.95 3.94 6.89 12.53 2377 1.0\n",
"theta[7] 6.25 0.13 5.2 -2.89 2.78 5.77 9.11 17.59 1666 1.0\n",
"theta[8] 4.73 0.11 5.1 -5.57 1.69 4.76 7.49 15.3 2183 1.0\n",
"lp__ -6.87 0.09 2.27 -12.09 -8.26 -6.58 -5.23 -3.3 705 1.0\n",
"\n",
"Samples were drawn using NUTS at Tue May 29 16:36:52 2018.\n",
"For each parameter, n_eff is a crude measure of effective sample size,\n",
"and Rhat is the potential scale reduction factor on split chains (at \n",
"convergence, Rhat=1)."
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fit"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# xarray"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import xarray as xr\n",
"from xarray import DataArray, Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# extract to xarray"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['samples', 'chains', 'iter', 'warmup', 'thin', 'n_save', 'warmup2', 'permutation', 'pars_oi', 'dims_oi', 'fnames_oi', 'n_flatnames'])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fit.sim.keys()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'algorithm': 'NUTS',\n",
" 'append_samples': False,\n",
" 'chain_id': 0,\n",
" 'ctrl': {'sampling': {'adapt_delta': 0.8,\n",
" 'adapt_engaged': True,\n",
" 'adapt_gamma': 0.05,\n",
" 'adapt_init_buffer': 75,\n",
" 'adapt_kappa': 0.75,\n",
" 'adapt_t0': 10.0,\n",
" 'adapt_term_buffer': 50,\n",
" 'adapt_window': 25,\n",
" 'algorithm': <sampling_algo_t.NUTS: 1>,\n",
" 'iter': 1000,\n",
" 'iter_save': 1000,\n",
" 'iter_save_wo_warmup': 500,\n",
" 'max_treedepth': 10,\n",
" 'metric': <sampling_metric_t.DIAG_E: 2>,\n",
" 'refresh': 100,\n",
" 'save_warmup': True,\n",
" 'stepsize': 1.0,\n",
" 'stepsize_jitter': 0.0,\n",
" 'thin': 1,\n",
" 'warmup': 500}},\n",
" 'diagnostic_file': b'',\n",
" 'diagnostic_file_flag': False,\n",
" 'enable_random_init': True,\n",
" 'init': b'random',\n",
" 'init_radius': 2.0,\n",
" 'iter': 1000,\n",
" 'method': <stan_args_method_t.SAMPLING: 1>,\n",
" 'random_seed': 196103384,\n",
" 'sample_file': b'',\n",
" 'sample_file_flag': False,\n",
" 'seed': 196103384,\n",
" 'thin': 1,\n",
" 'warmup': 500}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fit.stan_args[0] # one for each chain"
]
},
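{
"cell_type": "markdown",
"metadata": {},
"source": [
"The per-chain `stan_args`, together with `fit.get_seed()` and `fit.date`, carry the run-level metadata. A minimal sketch of collecting a few of these values into a plain dict (which could later be attached as `Dataset` attributes); the choice of keys here is illustrative:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sketch: gather run-level metadata from the fit object\n",
"# (keys chosen for illustration; stan_args has one entry per chain)\n",
"run_attrs = {\n",
" 'num_chains': fit.sim['chains'],\n",
" 'seed': fit.get_seed(),\n",
" 'created_at': str(fit.date),\n",
" 'adapt_delta': fit.stan_args[0]['ctrl']['sampling']['adapt_delta'],\n",
" 'max_treedepth': fit.stan_args[0]['ctrl']['sampling']['max_treedepth'],\n",
"}\n",
"run_attrs"
]
},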
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def pars_to_xarray(fit, pars=None, infer_dtypes=True):\n",
" # TODO: check that fit has samples\n",
" if pars is None:\n",
" pars = fit.model_pars\n",
" \n",
" if infer_dtypes:\n",
" pattern = r\"int(?:\\[.*\\])*\\s*(.)(?:\\s*[=;]|(?:\\s*<-))\"\n",
" generated_quantities = fit.get_stancode().split(\"generated quantities\")[-1]\n",
" dtypes = re.findall(pattern, generated_quantities)\n",
" dtypes = {item : 'int' for item in dtypes if item in pars}\n",
" else:\n",
" dtypes = dict()\n",
" \n",
" warmup = fit.sim['warmup']\n",
" \n",
" data_vars = {}\n",
" data_vars_warmup = {}\n",
" \n",
" chains = fit.sim['chains']\n",
" # TODO: ADD COORDINATES\n",
" \n",
" for i, (key, values) in enumerate(fit.extract(pars, dtypes=dtypes, permuted=False, inc_warmup=True).items()):\n",
" if chains == 1:\n",
" values = np.expand_dims(values, axis=1)\n",
" if len(values.shape) == 2:\n",
" dims = ('draw', 'chain')\n",
" else:\n",
" dims = ('draw', 'chain', *[\"{key}_axis{j}\".format(key=key, j=j) for j in range(1, len(values.shape)-1)])\n",
"\n",
" data_vars[key] = DataArray(data=values[warmup:],\n",
" dims=dims, \n",
" name=key)\n",
"\n",
" data_vars_warmup[key] = DataArray(data=values[:warmup],\n",
" dims=dims, \n",
" name=\"{key}_warmup\".format(key=key))\n",
" \n",
" data_set = Dataset(data_vars=data_vars, attrs={'warmup' : False})\n",
" data_set_warmup = Dataset(data_vars=data_vars_warmup, attrs={'warmup' : True})\n",
" return data_set, data_set_warmup"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"data_set, data_set_warmup = pars_to_xarray(fit)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (chain: 4, draw: 500, theta_axis1: 8, theta_tilde_axis1: 8)\n",
"Dimensions without coordinates: chain, draw, theta_axis1, theta_tilde_axis1\n",
"Data variables:\n",
" mu (draw, chain) float64 5.897 4.075 3.758 3.18 5.637 4.397 ...\n",
" tau (draw, chain) float64 4.811 12.44 0.6475 0.9743 6.566 16.24 ...\n",
" theta_tilde (draw, chain, theta_tilde_axis1) float64 2.403 0.775 ...\n",
" theta (draw, chain, theta_axis1) float64 17.46 9.626 1.175 9.645 ...\n",
"Attributes:\n",
" warmup: False"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_set"
]
},
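{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once the draws live in a `Dataset`, the usual xarray reductions apply. A small usage sketch: posterior means of `theta` over the `draw` and `chain` dimensions, and the draws of a single chain via positional indexing (the dimensions carry no coordinates, so `isel` is used rather than `sel`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# posterior mean of theta for each school, averaged over chains and draws\n",
"theta_mean = data_set['theta'].mean(dim=('draw', 'chain'))\n",
"# draws from the first chain only (positional, since 'chain' has no coordinate)\n",
"first_chain = data_set.isel(chain=0)\n",
"theta_mean"
]
},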
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# sampler params"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def sampler_params_to_xarray(fit, params=None):\n",
" if params is None:\n",
" params = list(fit.get_sampler_params()[0].keys()) + ['lp__']\n",
" \n",
" warmup = fit.sim['warmup']\n",
" sampler_params_vars = {}\n",
" sampler_params_vars_warmup = {}\n",
" for chain, sparams in enumerate(fit.get_sampler_params(), 1):\n",
" for key, param in sparams.items():\n",
" if key not in params:\n",
" continue\n",
" if key in ('treedepth__', 'n_leapfrog__'):\n",
" param = param.astype(int)\n",
" elif key == 'divergent__':\n",
" param = param.astype(bool)\n",
" if key not in sampler_params_vars:\n",
" sampler_params_vars[key] = []\n",
" if key not in sampler_params_vars_warmup:\n",
" sampler_params_vars_warmup[key] = []\n",
"\n",
" sampler_params_vars[key].append(DataArray(data=param[warmup:, None],\n",
" dims=('draw', 'chain'), \n",
" name=key))\n",
"\n",
" sampler_params_vars_warmup[key].append(DataArray(data=param[:warmup, None],\n",
" dims=('draw', 'chain'), \n",
" name=key))\n",
" \n",
" for par, values in sampler_params_vars.items():\n",
" sampler_params_vars[par] = xr.concat(values, dim='chain')\n",
"\n",
" for par, values in sampler_params_vars_warmup.items():\n",
" sampler_params_vars_warmup[par] = xr.concat(values, dim='chain')\n",
"\n",
" chains = fit.sim['chains']\n",
" if 'lp__' in params:\n",
" lp = fit.extract(pars='lp__', permuted=False, inc_warmup=True)\n",
" for i, (key, values) in enumerate(lp.items()):\n",
" if chains == 1:\n",
" values = np.expand_dims(values, axis=1)\n",
" sampler_params_vars[key] = DataArray(values[warmup:], dims=('draw', 'chain'))\n",
" sampler_params_vars_warmup[key] = DataArray(values[:warmup], dims=('draw', 'chain'))\n",
"\n",
" sampler_params_dataset = Dataset(data_vars=sampler_params_vars, attrs={'warmup' : False})\n",
" sampler_params_dataset_warmup = Dataset(data_vars=sampler_params_vars_warmup, attrs={'warmup' : True})\n",
" \n",
" return sampler_params_dataset, sampler_params_dataset_warmup"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"sampler_params_dataset, sampler_params_dataset_warmup = sampler_params_to_xarray(fit)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (chain: 4, draw: 500)\n",
"Dimensions without coordinates: chain, draw\n",
"Data variables:\n",
" accept_stat__ (draw, chain) float64 0.8442 0.9884 0.9788 0.8205 0.9857 ...\n",
" stepsize__ (draw, chain) float64 0.4289 0.4329 0.4037 0.4457 0.4289 ...\n",
" treedepth__ (draw, chain) int32 3 3 3 4 3 3 3 3 2 3 3 3 3 3 3 3 3 3 3 ...\n",
" n_leapfrog__ (draw, chain) int32 7 15 15 31 7 7 15 7 3 7 7 7 7 15 7 7 ...\n",
" divergent__ (draw, chain) bool False False False False False False ...\n",
" energy__ (draw, chain) float64 9.991 6.031 11.18 11.77 9.151 8.522 ...\n",
" lp__ (draw, chain) float64 -7.054 -2.283 -7.953 -8.994 -6.556 ...\n",
"Attributes:\n",
" warmup: False"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sampler_params_dataset"
]
},
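{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick diagnostic sketch on the sampler-parameter `Dataset`: count post-warmup divergences and take the mean tree depth per chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# number of divergent transitions per chain (post-warmup draws only)\n",
"n_divergent = sampler_params_dataset['divergent__'].sum(dim='draw')\n",
"# average NUTS tree depth per chain\n",
"mean_treedepth = sampler_params_dataset['treedepth__'].mean(dim='draw')\n",
"n_divergent, mean_treedepth"
]
},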
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"datetime.datetime(2018, 5, 29, 16, 36, 52, 86000)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"date = fit.date\n",
"date"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# inits"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def inits_to_xarray(fit, pars=None, infer_dtypes=True):\n",
" if pars is None:\n",
" pars = fit.model_pars\n",
" \n",
" if infer_dtypes:\n",
" pattern = r\"int(?:\\[.*\\])*\\s*(.)(?:\\s*[=;]|(?:\\s*<-))\"\n",
" generated_quantities = fit.get_stancode().split(\"generated quantities\")[-1]\n",
" dtypes = re.findall(pattern, generated_quantities)\n",
" dtypes = {item : 'int' for item in dtypes if item in pars}\n",
" else:\n",
" dtypes = dict()\n",
" \n",
" chains = fit.sim['chains']\n",
" inits = {}\n",
" \n",
" for chain, init in enumerate(fit.inits, 1):\n",
" for key, values in init.items():\n",
" if key in dtypes:\n",
" values = values.astype(dtypes[key])\n",
" if key not in inits:\n",
" inits[key] = []\n",
" if len(values.shape) < 1:\n",
" values = np.expand_dims(values, -1)\n",
" else:\n",
" values = np.expand_dims(values, 0)\n",
" values = np.expand_dims(values, 1)\n",
" if len(values.shape) == 2:\n",
" dims = ('draw', 'chain')\n",
" else:\n",
" dims = ('draw', 'chain', *[\"{key}_axis{j}\".format(key=key, j=j) for j in range(1, len(values.shape)-1)])\n",
" inits[key].append(DataArray(data=values,\n",
" dims=dims, \n",
" name=key))\n",
" for par, values in inits.items():\n",
" inits[par] = xr.concat(values, dim='chain')\n",
"\n",
" inits = Dataset(data_vars=inits)\n",
" return inits"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (chain: 4, draw: 1, theta_axis1: 8, theta_tilde_axis1: 8)\n",
"Dimensions without coordinates: chain, draw, theta_axis1, theta_tilde_axis1\n",
"Data variables:\n",
" mu (draw, chain) float64 -1.655 -1.773 -1.097 1.494\n",
" tau (draw, chain) float64 0.6139 1.401 0.3722 4.794\n",
" theta_tilde (draw, chain, theta_tilde_axis1) float64 -1.323 -0.9794 ...\n",
" theta (draw, chain, theta_axis1) float64 -2.467 -2.256 -1.539 ..."
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inits = inits_to_xarray(fit)\n",
"inits"
]
},
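{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the init `Dataset` shares dimension names with the posterior `Dataset`, the two line up directly; for example, the per-chain initial value of `mu` next to the per-chain posterior mean (a usage sketch):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# drop the length-1 draw dimension from the inits and compare with posterior means\n",
"mu_init = inits['mu'].squeeze('draw')\n",
"mu_posterior_mean = data_set['mu'].mean(dim='draw')\n",
"mu_init, mu_posterior_mean"
]
},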
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"196103384"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seed = fit.get_seed()\n",
"seed"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fit.mode"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# summary"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def summary_to_xarray(fit, pars=None):\n",
" if pars is None:\n",
" pars = fit.model_pars\n",
" \n",
" # TODO: skip pandas DataFrame step?\n",
" \n",
" summary = fit.summary()\n",
" summary_dataset = pd.DataFrame(summary['summary'], \n",
" index=summary['summary_rownames'],\n",
" columns=summary['summary_colnames'],\n",
" ).T.to_xarray()\n",
" \n",
" c_summary = {}\n",
" shape = summary['c_summary'].shape\n",
" if len(shape) == 2:\n",
" shape = np.expand_dims(shape, -1)\n",
" for i in range(shape[-1]):\n",
" c_summary[i] = pd.DataFrame(summary['c_summary'][:, :, i],\n",
" index=summary['c_summary_rownames'],\n",
" columns=summary['c_summary_colnames'],\n",
" ).T.to_xarray()\n",
"\n",
" c_summary_dataset = xr.concat(c_summary.values(), dim='chain')\n",
" \n",
" return summary_dataset, c_summary_dataset"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"summary_dataset, c_summary_dataset = summary_to_xarray(fit)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (chain: 4, index: 7)\n",
"Coordinates:\n",
" * index (index) object 'mean' 'sd' '2.5%' '25%' '50%' '75%' '97.5%'\n",
"Dimensions without coordinates: chain\n",
"Data variables:\n",
" mu (chain, index) float64 4.288 4.24 -2.245 10.89 6.466 ...\n",
" tau (chain, index) float64 3.941 3.326 0.1688 14.67 4.918 ...\n",
" theta_tilde[1] (chain, index) float64 0.339 0.2704 -1.777 2.117 0.9804 ...\n",
" theta_tilde[2] (chain, index) float64 0.02595 0.1136 -1.792 1.743 ...\n",
" theta_tilde[3] (chain, index) float64 -0.03296 -0.0544 -1.86 1.842 ...\n",
" theta_tilde[4] (chain, index) float64 0.08483 0.05905 -1.697 1.968 ...\n",
" theta_tilde[5] (chain, index) float64 -0.1978 -0.1861 -1.897 1.621 ...\n",
" theta_tilde[6] (chain, index) float64 -0.09948 -0.1074 -1.9 1.841 ...\n",
" theta_tilde[7] (chain, index) float64 0.4003 0.3152 -1.411 2.062 0.9367 ...\n",
" theta_tilde[8] (chain, index) float64 0.06687 0.06111 -1.869 1.814 ...\n",
" theta[1] (chain, index) float64 6.584 5.703 -3.185 22.27 8.69 ...\n",
" theta[2] (chain, index) float64 4.497 4.659 -5.909 14.26 7.329 ...\n",
" theta[3] (chain, index) float64 3.899 3.597 -8.291 13.93 7.171 ...\n",
" theta[4] (chain, index) float64 5.065 4.497 -4.252 16.89 7.428 ...\n",
" theta[5] (chain, index) float64 3.291 3.179 -8.917 12.07 6.872 ...\n",
" theta[6] (chain, index) float64 3.802 3.807 -7.491 13.53 6.944 ...\n",
" theta[7] (chain, index) float64 6.7 5.935 -2.707 20.18 8.501 ...\n",
" theta[8] (chain, index) float64 4.876 4.761 -6.37 16.02 7.523 ...\n",
" lp__ (chain, index) float64 -6.796 -6.99 -11.84 -3.071 -5.198 ..."
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c_summary_dataset"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (index: 10)\n",
"Coordinates:\n",
" * index (index) object 'mean' 'se_mean' 'sd' '2.5%' '25%' '50%' ...\n",
"Data variables:\n",
" mu (index) float64 4.317 0.07 3.304 -2.127 2.092 4.318 ...\n",
" tau (index) float64 3.6 0.09325 3.147 0.1331 1.288 2.806 ...\n",
" theta_tilde[1] (index) float64 0.3153 0.02302 0.9943 -1.727 -0.3406 ...\n",
" theta_tilde[2] (index) float64 0.09653 0.01879 0.9062 -1.669 -0.4875 ...\n",
" theta_tilde[3] (index) float64 -0.05765 0.019 0.9542 -1.956 -0.723 ...\n",
" theta_tilde[4] (index) float64 0.07237 0.02249 0.9329 -1.697 -0.5544 ...\n",
" theta_tilde[5] (index) float64 -0.171 0.02014 0.9282 -2.019 -0.7685 ...\n",
" theta_tilde[6] (index) float64 -0.1029 0.01905 0.9502 -1.972 -0.7565 ...\n",
" theta_tilde[7] (index) float64 0.3542 0.0231 0.9637 -1.496 -0.3006 ...\n",
" theta_tilde[8] (index) float64 0.066 0.01956 0.9676 -1.834 -0.5956 ...\n",
" theta[1] (index) float64 6.177 0.1483 5.609 -3.115 2.533 5.56 ...\n",
" theta[2] (index) float64 4.778 0.1034 4.645 -4.062 2.002 4.756 ...\n",
" theta[3] (index) float64 3.827 0.1179 5.271 -7.823 1.16 4.094 ...\n",
" theta[4] (index) float64 4.717 0.1142 4.809 -4.654 1.732 4.444 ...\n",
" theta[5] (index) float64 3.47 0.1 4.784 -6.803 0.9069 3.777 6.405 ...\n",
" theta[6] (index) float64 3.796 0.0987 4.812 -6.678 0.9498 3.941 ...\n",
" theta[7] (index) float64 6.252 0.1274 5.199 -2.891 2.783 5.775 ...\n",
" theta[8] (index) float64 4.726 0.1092 5.102 -5.571 1.689 4.76 ...\n",
" lp__ (index) float64 -6.868 0.08559 2.273 -12.09 -8.263 ..."
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summary_dataset"
]
}
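,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, a sketch of persisting the results: each `Dataset` can be written to disk with `to_netcdf` (the file names below are illustrative)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# write the Datasets to netCDF files; file names are illustrative\n",
"data_set.to_netcdf('eight_schools_posterior.nc')\n",
"data_set_warmup.to_netcdf('eight_schools_warmup.nc')\n",
"sampler_params_dataset.to_netcdf('eight_schools_sample_stats.nc')\n",
"inits.to_netcdf('eight_schools_inits.nc')"
]
}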
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}