Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Last active November 8, 2022 18:37
Show Gist options
  • Save ricardoV94/0cf8fd0f69a09d7eff0a5b41cb111965 to your computer and use it in GitHub Desktop.
Save ricardoV94/0cf8fd0f69a09d7eff0a5b41cb111965 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "7205a740",
"metadata": {},
"source": [
"# Marginalizing discrete RVs"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6a238994",
"metadata": {},
"outputs": [],
"source": [
"import aeppl\n",
"import aesara\n",
"import aesara.tensor as at\n",
"from aesara.graph import FunctionGraph\n",
"from aesara.compile.builders import OpFromGraph\n",
"import numpy as np\n",
"\n",
"import pymc as pm"
]
},
{
"cell_type": "markdown",
"id": "e05a3823",
"metadata": {},
"source": [
"## Marginalizing a single RV"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cf25b739",
"metadata": {},
"outputs": [],
"source": [
"with pm.Model() as m:\n",
" p = pm.Dirichlet(\"p\", [1, 1])\n",
" x = pm.Categorical(\"x\", p=p)\n",
" y = pm.Normal(\"y\", pm.math.stack([-1, 1])[x], 1, observed=1) "
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3f94e5a6",
"metadata": {},
"outputs": [],
"source": [
"p_vv = m.rvs_to_values[p]\n",
"x_vv = m.rvs_to_values[x]\n",
"logp = m.logp()\n",
"logp_op = OpFromGraph([p_vv, x_vv], [logp], inline=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d1a4e4f7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OpFromGraph{inline=True}.0"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logp_op(p_vv, x_vv)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6f626312",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'p_simplex__': array([0.]), 'x': array(0)}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m.initial_point()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "872648d2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array(-4.30523289), array(-2.30523289))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logp_op(np.array([0]), 0).eval(), logp_op(np.array([0]), 1).eval()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7367bbe3",
"metadata": {},
"outputs": [],
"source": [
"x_domain = range(2) # Possible values of the categorical\n",
"marginal_logp = at.logsumexp([logp_op(p_vv, x_vv_const) for x_vv_const in x_domain])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7df0efd4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(-2.17830488)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"marginal_logp.eval({p_vv: np.array([0])})"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "35e97418",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(-2.17830488)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"with pm.Model() as m_ref:\n",
" p = pm.Dirichlet(\"p\", [1, 1])\n",
" y = pm.NormalMixture(\"y\", w=p, mu=[-1, 1], sigma=1, observed=1) \n",
"m_ref.compile_logp()({\"p_simplex__\": np.array([0])})"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c4846826",
"metadata": {},
"outputs": [],
"source": [
"f = aesara.function([p_vv], marginal_logp)\n",
"# aesara.dprint(f)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c5f15010",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"52"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(f.maker.fgraph.apply_nodes)"
]
},
{
"cell_type": "markdown",
"id": "37ef8d2b",
"metadata": {},
"source": [
"## Marginalize multiple RVs"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "81d63419",
"metadata": {},
"outputs": [],
"source": [
"def explicit_mixture(name, categorical_idx, components):\n",
" return pm.Normal(name, pm.math.stack(components)[categorical_idx], 1)\n",
" \n",
"with pm.Model() as m:\n",
" p1 = pm.Dirichlet(\"p1\", [1, 1])\n",
" mix_comp1 = pm.Categorical(\"mix_comp1\", p=p1) \n",
" y1 = explicit_mixture(\"y1\", mix_comp1, [-1, 1])\n",
" \n",
" p2 = pm.Dirichlet(\"p2\", [1, 1])\n",
" mix_comp2 = pm.Categorical(\"mix_comp2\", p=p2) \n",
" y2 = explicit_mixture(\"y2\", mix_comp2, [-2, 2])\n",
" \n",
" p3 = pm.Dirichlet(\"p3\", [1, 1])\n",
" mix_comp3 = pm.Categorical(\"mix_comp3\", p=p3)\n",
" y3 = explicit_mixture(\"y3\", mix_comp3, [y1, y2])\n",
" \n",
" pm.Normal(\"llike\", y3, 1, observed=9)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "bf922bc3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mix_comp3\n",
"mix_comp2\n",
"mix_comp1\n"
]
}
],
"source": [
"logp_graph = m.logp()\n",
"rvs = list(m.free_RVs)\n",
"marginalize_rvs = {mix_comp1, mix_comp2, mix_comp3}\n",
"fg = FunctionGraph(outputs=rvs, clone=False)\n",
"order = fg.toposort()\n",
"for rv in sorted(marginalize_rvs, key=lambda x: order.index(x.owner), reverse=True):\n",
" print(rv)\n",
" rvs.remove(rv)\n",
" vv = m.rvs_to_values[rv]\n",
" vvs = [m.rvs_to_values[rv] for rv in rvs]\n",
" logp_op = OpFromGraph([vv, *vvs], [logp_graph], inline=True)\n",
" rv_domain = range(2) # Hard-coded\n",
" logp_graph = at.logsumexp([logp_op(vv_const, *vvs) for vv_const in rv_domain])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5c59cafa",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'p1_simplex__': array([0.]),\n",
" 'y1': array(-1.),\n",
" 'p2_simplex__': array([0.]),\n",
" 'y2': array(-2.),\n",
" 'p3_simplex__': array([0.]),\n",
" 'y3': array(-1.)}"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ip = m.initial_point()\n",
"ip.pop(\"mix_comp3\", None)\n",
"ip.pop(\"mix_comp2\", None)\n",
"ip.pop(\"mix_comp1\", None)\n",
"ip"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8a2f1de8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(-57.23329681)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = m.compile_fn(logp_graph)\n",
"f(ip)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "2c8ba2ac",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"193"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(f.f.maker.fgraph.apply_nodes)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c5ef7a5c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(-57.23329681)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"with pm.Model() as m_ref:\n",
" p1 = pm.Dirichlet(\"p1\", [1, 1])\n",
" y1 = pm.NormalMixture(\"y1\", p1, [-1, 1])\n",
" \n",
" p2 = pm.Dirichlet(\"p2\", [1, 1])\n",
" y2 = pm.NormalMixture(\"y2\", p2, [-2, 2])\n",
" \n",
" p3 = pm.Dirichlet(\"p3\", [1, 1])\n",
" y3 = pm.NormalMixture(\"y3\", p3, [y1, y2])\n",
" \n",
" pm.Normal(\"llike\", y3, 1, observed=9)\n",
" \n",
"m_ref.compile_logp()(ip)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98072f1c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "pymc",
"language": "python",
"name": "pymc"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment