Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Last active November 8, 2022 18:37
Show Gist options
  • Save ricardoV94/0cf8fd0f69a09d7eff0a5b41cb111965 to your computer and use it in GitHub Desktop.
Save ricardoV94/0cf8fd0f69a09d7eff0a5b41cb111965 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"from collections import defaultdict\n",
"\n",
"import aesara\n",
"import aesara.tensor as at\n",
"import numpy as np\n",
"\n",
"from aeppl import factorized_joint_logprob, joint_logprob"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"def marginalize_factorized_logp_dict(logp_dict, marginalize):\n",
"    if not marginalize:\n",
"        return logp_dict\n",
"    else:\n",
"        marginalize = marginalize.copy()\n",
"        marginalized_vv, constant_values = marginalize.popitem()\n",
"        marginalized_logp_dict = defaultdict(lambda: at.constant(-np.inf))\n",
"\n",
"        for constant_value in constant_values:\n",
"            constant_value = at.constant(\n",
"                constant_value,\n",
"                dtype=marginalized_vv.dtype,\n",
"                name=f\"{marginalized_vv}={constant_value}\",\n",
"            )\n",
"            new_logp_dict = {\n",
"                vv: logp_expr for vv, logp_expr in zip(\n",
"                    logp_dict.keys(),\n",
"                    aesara.graph.clone_replace(\n",
"                        list(logp_dict.values()),\n",
"                        replace={marginalized_vv: constant_value},\n",
"                    )\n",
"                )\n",
"            }\n",
"            marginalized_var_constant_logp = new_logp_dict.pop(marginalized_vv)\n",
"            new_logp_dict = marginalize_factorized_logp_dict(new_logp_dict, marginalize)\n",
"\n",
"            for value_var, logp_expr in new_logp_dict.items():\n",
"                marginalized_logp_dict[value_var] = at.logsumexp((\n",
"                    marginalized_logp_dict[value_var],\n",
"                    logp_expr + marginalized_var_constant_logp\n",
"                ))\n",
"\n",
"        return marginalized_logp_dict"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"x = at.random.categorical([0.1, 0.2, 0.3, 0.4], name=\"x\")\n",
"y = at.random.categorical([0.1, 0.3, 0.6], name=\"y\")\n",
"z = at.random.normal(x, y + 1, name=\"z\")\n",
"x_vv = x.clone()\n",
"y_vv = y.clone()\n",
"z_vv = z.clone()\n",
"\n",
"ref_logp = joint_logprob({x: x_vv, y: y_vv, z: z_vv}, sum=True)\n",
"ref_logp_fn = aesara.function([x_vv, y_vv, z_vv], ref_logp)\n",
"\n",
"logp_dict = factorized_joint_logprob({x:x_vv, y:y_vv, z: z_vv})\n",
"logp_dict = marginalize_factorized_logp_dict(logp_dict, marginalize={x_vv: range(4), y_vv: range(3)})"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "array(-1.97233263)"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logp_dict[z_vv].eval({z_vv: 1})"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "-1.9723326261602925"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.log(np.sum([np.exp(ref_logp_fn(x=x_vv, y=y_vv, z=1)) for x_vv in range(4) for y_vv in range(3)]))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment