Skip to content

Instantly share code, notes, and snippets.

@ahartikainen
Last active October 14, 2020 17:14
Show Gist options
  • Save ahartikainen/0eb924cea21409600ced23881c156dc5 to your computer and use it in GitHub Desktop.
Save ahartikainen/0eb924cea21409600ced23881c156dc5 to your computer and use it in GitHub Desktop.
Prior predictive with widgets
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import panel as pn\n",
"pn.extension()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from cmdstanpy import CmdStanModel\n",
"from cmdstanpy.utils import cxx_toolchain_path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import arviz as az"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import platform\n",
"import re"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if platform.system() == \"Windows\":\n",
"    # cmdstanpy helper for the Windows C++ toolchain (presumably puts it on PATH).\n",
"    cxx_toolchain_path()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Prior predictive program: draws alpha, beta and sigma from their priors\n",
"# and simulates y_sim at every x.\n",
"stan_model = \"\"\"\n",
"data {\n",
"    int<lower=0> N;\n",
"    vector[N] x;\n",
"\n",
"    // prior hyperparameters\n",
"    real alpha_mu;\n",
"    real<lower=0> alpha_sd;\n",
"    real beta_mu;\n",
"    real<lower=0> beta_sd;\n",
"    real<lower=0> sigma_nu;  // degrees of freedom of student_t_rng\n",
"    real sigma_mu;\n",
"    real<lower=0> sigma_sd;  // scale of student_t_rng\n",
"}\n",
"generated quantities {\n",
"    real alpha = normal_rng(alpha_mu, alpha_sd);\n",
"    real beta = normal_rng(beta_mu, beta_sd);\n",
"    // sigma ~ student_t(sigma_nu, sigma_mu, sigma_sd) restricted to sigma > 0,\n",
"    // drawn by rejection with at most 100 attempts.\n",
"    real sigma;\n",
"    for (i in 1:100) {\n",
"        sigma = student_t_rng(sigma_nu, sigma_mu, sigma_sd);\n",
"        if (sigma > 0) {\n",
"            break;\n",
"        }\n",
"    }\n",
"    if (sigma <= 0) {\n",
"        reject(\"no positive sigma in 100 draws; sigma_mu = \", sigma_mu, \", sigma_sd = \", sigma_sd);\n",
"    }\n",
"    real y_sim[N] = normal_rng(alpha + beta * x, sigma);\n",
"}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# CmdStanModel compiles from a file, so write the program out first.\n",
"with open(\"./prior_predictive.stan\", \"w\") as f:\n",
"    f.write(stan_model + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"# Compile the Stan program written to ./prior_predictive.stan.\n",
"model = CmdStanModel(stan_file=\"./prior_predictive.stan\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Synthetic observations: 14 sorted x values in [0, 10) and y on the line\n",
"# 2.34 * x + 14.34 plus standard normal noise.\n",
"x = np.sort(10 * np.random.rand(14))\n",
"y = 2.34 * x + np.random.randn(14) + 14.34"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from bokeh.plotting import figure\n",
"from bokeh.layouts import gridplot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def make_plot(idata, n=100):\n",
"    \"\"\"Plot prior predictive draws, prior regression lines and the observed data.\n",
"\n",
"    Up to ``n`` randomly chosen draws of the first chain are shown: orange dots\n",
"    are ``y_sim`` at ``constant_data.x``, black lines are ``alpha + beta * x``\n",
"    for the same draws, and red dots are the observed ``(x_obs, y_obs)`` pairs.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    idata : arviz.InferenceData\n",
"        Must contain ``prior`` (alpha, beta), ``prior_predictive`` (y_sim),\n",
"        ``constant_data`` (x, x_obs) and ``observed_data`` (y_obs).\n",
"    n : int, default 100\n",
"        Maximum number of draws to plot; capped at the number of draws available.\n",
"\n",
"    Returns\n",
"    -------\n",
"    bokeh figure\n",
"    \"\"\"\n",
"    p = figure(width=700, height=400, toolbar_location=\"above\")\n",
"\n",
"    n_draws = idata.prior_predictive.sizes[\"draw\"]\n",
"    n = min(n, n_draws)\n",
"    # Sorted positional draw indices, sampled without replacement. Selecting a\n",
"    # single chain keeps the arrays below independent of the chain count.\n",
"    random_sample = np.sort(np.random.choice(n_draws, size=n, replace=False))\n",
"    selection = {\"chain\": 0, \"draw\": random_sample}\n",
"\n",
"    x_data = idata.constant_data.x.values\n",
"    y_data = idata.prior_predictive.y_sim.isel(selection).values  # one row per draw\n",
"    alpha = idata.prior.alpha.isel(selection).values\n",
"    beta = idata.prior.beta.isel(selection).values\n",
"    mu_lines = alpha + beta * x_data[:, None]  # one column per draw\n",
"\n",
"    # A single glyph per layer instead of one renderer per draw.\n",
"    p.circle(np.tile(x_data, n), y_data.ravel(), fill_color=\"orange\", fill_alpha=0.5, line_color=None)\n",
"    p.multi_line(xs=[x_data] * n, ys=list(mu_lines.T), line_color=\"black\", line_alpha=0.3)\n",
"\n",
"    p.circle(\n",
"        x=idata.constant_data.x_obs.values,\n",
"        y=idata.observed_data.y_obs.values,\n",
"        fill_color=\"red\",\n",
"        fill_alpha=0.9,\n",
"        line_color=None,\n",
"        size=10,\n",
"    )\n",
"    return p"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def analyze(\n",
"    model,\n",
"    model_code,\n",
"    y_obs,\n",
"    x_obs,\n",
"    x,\n",
"    widgets,\n",
"):\n",
"    \"\"\"Build an interactive prior predictive dashboard.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    model : CmdStanModel\n",
"        Compiled model; re-sampled with ``fixed_param=True`` on every update.\n",
"    model_code : str\n",
"        Stan source of ``model``, shown next to the plots.\n",
"    y_obs, x_obs : array-like\n",
"        Observed data, drawn on top of the prior predictive draws.\n",
"    x : array-like\n",
"        Points at which ``y_sim`` is simulated (passed to Stan as ``x``).\n",
"    widgets : panel layout or widget\n",
"        Displayed above the summary table. Every member whose ``name`` is one\n",
"        of alpha_mu, alpha_sd, beta_mu, beta_sd, sigma_nu, sigma_mu or sigma_sd\n",
"        drives that hyperparameter; hyperparameters without a widget keep the\n",
"        default of ``analyze_plots``.\n",
"\n",
"    Returns\n",
"    -------\n",
"    callable\n",
"        ``pn.depends``-decorated function that re-runs the sampler and returns\n",
"        a ``pn.Column`` with the summary, plots and model code.\n",
"    \"\"\"\n",
"    prior_names = (\n",
"        \"alpha_mu\", \"alpha_sd\", \"beta_mu\", \"beta_sd\", \"sigma_nu\", \"sigma_mu\", \"sigma_sd\",\n",
"    )\n",
"    # Bind to the widgets that were passed in, not to notebook globals.\n",
"    members = getattr(widgets, \"objects\", [widgets])\n",
"    controls = {w.name: w for w in members if getattr(w, \"name\", None) in prior_names}\n",
"\n",
"    # The model code does not change between updates, so render it once.\n",
"    # Line breaks are kept verbatim inside a fenced code block.\n",
"    model_code_pane = pn.pane.Markdown(\"```stan\\n\" + model_code.rstrip(\"\\n\") + \"\\n```\")\n",
"\n",
"    @pn.depends(**controls)\n",
"    def analyze_plots(\n",
"        alpha_mu=0,\n",
"        alpha_sd=1,\n",
"        beta_mu=0,\n",
"        beta_sd=1,\n",
"        sigma_nu=3,\n",
"        sigma_mu=0,\n",
"        sigma_sd=1,\n",
"    ):\n",
"        # Text-input widget values may arrive as strings, hence float().\n",
"        stan_data = dict(\n",
"            N=len(x),\n",
"            x=x,\n",
"            alpha_mu=float(alpha_mu),\n",
"            alpha_sd=float(alpha_sd),\n",
"            beta_mu=float(beta_mu),\n",
"            beta_sd=float(beta_sd),\n",
"            sigma_nu=float(sigma_nu),\n",
"            sigma_mu=float(sigma_mu),\n",
"            sigma_sd=float(sigma_sd),\n",
"        )\n",
"        fit = model.sample(data=stan_data, iter_sampling=500, fixed_param=True)\n",
"        idata = az.from_cmdstanpy(\n",
"            prior=fit,\n",
"            prior_predictive=\"y_sim\",\n",
"            observed_data={\"y_obs\": y_obs},\n",
"            constant_data={\"x_obs\": x_obs, **stan_data},\n",
"        )\n",
"\n",
"        var_names = [\"alpha\", \"beta\", \"sigma\"]\n",
"        p_regression = make_plot(idata)\n",
"        p_pair = gridplot(\n",
"            az.plot_pair(\n",
"                idata.prior,\n",
"                var_names=var_names,\n",
"                backend=\"bokeh\",\n",
"                backend_kwargs={\"width\": 220, \"height\": 220},\n",
"                show=False,\n",
"            ).tolist()\n",
"        )\n",
"        p_trace = gridplot(\n",
"            az.plot_trace(idata.prior, var_names=var_names, backend=\"bokeh\", show=False).tolist()\n",
"        )\n",
"        summary_p = az.summary(idata.prior, var_names=var_names, kind=\"stats\")\n",
"\n",
"        return pn.Column(\n",
"            pn.Row(pn.Column(widgets, summary_p, p_regression), model_code_pane),\n",
"            p_trace,\n",
"            p_pair,\n",
"            width=800,\n",
"        )\n",
"\n",
"    return analyze_plots"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Dense grid on which the prior predictive is simulated.\n",
"x_pred = np.linspace(0, 10, num=100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One text box per prior hyperparameter of the Stan data block (values are strings).\n",
"(alpha_mu, alpha_sd, beta_mu, beta_sd, sigma_nu, sigma_mu, sigma_sd) = (\n",
"    pn.widgets.TextInput(name=name, value=value, width=80)\n",
"    for name, value in [\n",
"        (\"alpha_mu\", \"0\"),\n",
"        (\"alpha_sd\", \"1\"),\n",
"        (\"beta_mu\", \"0\"),\n",
"        (\"beta_sd\", \"1\"),\n",
"        (\"sigma_nu\", \"3\"),\n",
"        (\"sigma_mu\", \"0\"),\n",
"        (\"sigma_sd\", \"1\"),\n",
"    ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis = pn.panel(\n",
"    analyze(\n",
"        model,\n",
"        model_code=stan_model,\n",
"        y_obs=y,\n",
"        x_obs=x,\n",
"        x=x_pred,  # simulate the prior predictive on the dense grid\n",
"        widgets=pn.Row(alpha_mu, alpha_sd, beta_mu, beta_sd, sigma_nu, sigma_mu, sigma_sd),\n",
"    )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Static snapshot; resources=\"inline\" embeds the JS/CSS in the file. It presumably\n",
"# cannot re-run Stan, so widget edits there will not update the plots.\n",
"analysis.save(\"prior_predictive_panel\", resources=\"inline\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Launch the live dashboard; changing a widget re-runs the sampler.\n",
"analysis.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment