Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Created April 20, 2021 09:54
Show Gist options
  • Save ricardoV94/262254a37f5efe6f45e851646e038e37 to your computer and use it in GitHub Desktop.
Save ricardoV94/262254a37f5efe6f45e851646e038e37 to your computer and use it in GitHub Desktop.
Mul Add logp
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import aesara\n",
"import aesara.tensor as at\n",
"from aesara.graph.basic import graph_inputs\n",
"from aesara.graph.fg import FunctionGraph\n",
"from aesara.scalar.basic import Add, Constant, Mul\n",
"from aesara.tensor.random.op import RandomVariable\n",
"from aesara.tensor.var import TensorVariable\n",
"\n",
"import pymc3 as pm\n",
"from pymc3.distributions.logp import _logp, logpt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"@_logp.register(TensorVariable)\n",
"def logp_tensorvariable(op, *args, **kwargs):\n",
"    \"\"\"Forward `_logp` for a TensorVariable to its owner's scalar op.\n",
"\n",
"    Despite its name, `op` is the TensorVariable being dispatched on. When\n",
"    its owner node's op exposes a `scalar_op` (e.g. an Elemwise), the call is\n",
"    re-dispatched on that scalar op with the same remaining arguments.\n",
"    Otherwise None is returned.\n",
"    \"\"\"\n",
"    if hasattr(op.owner, 'op') and hasattr(op.owner.op, 'scalar_op'):\n",
"        return _logp(op.owner.op.scalar_op, *args, **kwargs)\n",
"    # No scalar op to dispatch on: make the None result explicit.\n",
"    return None\n",
"\n",
"@_logp.register(at.elemwise.Elemwise)\n",
"def logp_elemwise(op, *args, **kwargs):\n",
"    \"\"\"Forward `_logp` for an Elemwise op to the scalar op it wraps.\n",
"\n",
"    The remaining positional and keyword arguments are passed through\n",
"    unchanged. Returns None if `op` has no `scalar_op` attribute.\n",
"    \"\"\"\n",
"    if hasattr(op, 'scalar_op'):\n",
"        return _logp(op.scalar_op, *args, **kwargs)\n",
"    return None\n",
"\n",
"@_logp.register(Mul)\n",
"def mul_logp(op, var, rvs_to_values, mul_left, mul_right, **kwargs):\n",
"    \"\"\"Log-density of a product ``rv * scale`` by change of variables.\n",
"\n",
"    Exactly one operand must be a RandomVariable with no value variable\n",
"    attached (no ``tag.value_var``); the other operand is the scale. For\n",
"    y = x * scale the density satisfies\n",
"    log p_y(v) = log p_x(v / scale) - log|scale|.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    op : Mul\n",
"        Scalar multiplication op being dispatched on (unused).\n",
"    var : TensorVariable\n",
"        Output of the multiplication (unused).\n",
"    rvs_to_values : dict\n",
"        Mapping from random variables to their value variables.\n",
"    mul_left, mul_right : TensorVariable\n",
"        The two operands. The fixed signature guarantees exactly two, so\n",
"        no separate arity check is needed.\n",
"    **kwargs\n",
"        Forwarded to ``logpt``.\n",
"\n",
"    Returns\n",
"    -------\n",
"    TensorVariable\n",
"        Log-density graph in which the RandomVariable operand stands in for\n",
"        the value of the product (it is replaced by ``rv / scale``).\n",
"\n",
"    Raises\n",
"    ------\n",
"    NotImplementedError\n",
"        If the number of unregistered RandomVariable operands is not one.\n",
"    \"\"\"\n",
"\n",
"    def _is_unregistered_rv(inp):\n",
"        # Output of a RandomVariable node that has no value variable yet.\n",
"        return bool(\n",
"            inp.owner\n",
"            and isinstance(inp.owner.op, RandomVariable)\n",
"            and not hasattr(inp.tag, \"value_var\")\n",
"        )\n",
"\n",
"    mul_inputs = [mul_left, mul_right]\n",
"    rvs = [inp for inp in mul_inputs if _is_unregistered_rv(inp)]\n",
"    scales = [inp for inp in mul_inputs if not _is_unregistered_rv(inp)]\n",
"\n",
"    if len(rvs) != 1:\n",
"        raise NotImplementedError(\n",
"            f\"Logp for multiplication requires one unregistered RandomVariable but got {len(rvs)}\"\n",
"        )\n",
"\n",
"    rv = rvs[0]\n",
"    scale = scales[0]\n",
"    # If the scale is itself a registered random variable, use its value\n",
"    # variable. The lookup is keyed on the scale (keying it on `rv` would\n",
"    # substitute the rv's own value for the scale).\n",
"    scale = rvs_to_values.get(scale, getattr(scale.tag, \"value_var\", scale))\n",
"\n",
"    # Build the logp of `rv` with `rv` standing in for its own value...\n",
"    new_rvs_to_values = rvs_to_values.copy()\n",
"    new_rvs_to_values[rv] = rv\n",
"\n",
"    logp = logpt(rv, new_rvs_to_values, **kwargs)\n",
"    fgraph = FunctionGraph(\n",
"        [i for i in graph_inputs((logp,)) if not isinstance(i, Constant)],\n",
"        [logp],\n",
"        clone=False,\n",
"    )\n",
"    fgraph.add_input(scale)\n",
"    # ...then evaluate it at the back-transformed value x = y / scale.\n",
"    fgraph.replace(rv, rv / scale)\n",
"    # Log-Jacobian of the inverse map y -> y / scale is -log|scale|.\n",
"    return fgraph.outputs[0] - at.log(at.abs_(scale))\n",
"\n",
"\n",
"@_logp.register(Add)\n",
"def add_logp(op, var, rvs_to_values, add_left, add_right, **kwargs):\n",
"    \"\"\"Log-density of a sum ``rv + loc`` by change of variables.\n",
"\n",
"    Exactly one operand must be a RandomVariable with no value variable\n",
"    attached (no ``tag.value_var``); the other operand is the location\n",
"    shift. For y = x + loc the density satisfies\n",
"    log p_y(v) = log p_x(v - loc); a shift has unit Jacobian, so no\n",
"    correction term is added.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    op : Add\n",
"        Scalar addition op being dispatched on (unused).\n",
"    var : TensorVariable\n",
"        Output of the addition (unused).\n",
"    rvs_to_values : dict\n",
"        Mapping from random variables to their value variables.\n",
"    add_left, add_right : TensorVariable\n",
"        The two operands. The fixed signature guarantees exactly two, so\n",
"        no separate arity check is needed.\n",
"    **kwargs\n",
"        Forwarded to ``logpt``.\n",
"\n",
"    Returns\n",
"    -------\n",
"    TensorVariable\n",
"        Log-density graph in which the RandomVariable operand stands in for\n",
"        the value of the sum (it is replaced by ``rv - loc``).\n",
"\n",
"    Raises\n",
"    ------\n",
"    NotImplementedError\n",
"        If the number of unregistered RandomVariable operands is not one.\n",
"    \"\"\"\n",
"\n",
"    def _is_unregistered_rv(inp):\n",
"        # Output of a RandomVariable node that has no value variable yet.\n",
"        return bool(\n",
"            inp.owner\n",
"            and isinstance(inp.owner.op, RandomVariable)\n",
"            and not hasattr(inp.tag, \"value_var\")\n",
"        )\n",
"\n",
"    add_inputs = [add_left, add_right]\n",
"    rvs = [inp for inp in add_inputs if _is_unregistered_rv(inp)]\n",
"    locs = [inp for inp in add_inputs if not _is_unregistered_rv(inp)]\n",
"\n",
"    if len(rvs) != 1:\n",
"        raise NotImplementedError(\n",
"            f\"Logp for addition requires one unregistered RandomVariable but got {len(rvs)}\"\n",
"        )\n",
"\n",
"    rv = rvs[0]\n",
"    loc = locs[0]\n",
"    # If the shift is itself a registered random variable, use its value\n",
"    # variable. The lookup is keyed on the shift (keying it on `rv` would\n",
"    # substitute the rv's own value for the shift).\n",
"    loc = rvs_to_values.get(loc, getattr(loc.tag, \"value_var\", loc))\n",
"\n",
"    # Build the logp of `rv` with `rv` standing in for its own value...\n",
"    new_rvs_to_values = rvs_to_values.copy()\n",
"    new_rvs_to_values[rv] = rv\n",
"\n",
"    logp = logpt(rv, new_rvs_to_values, **kwargs)\n",
"    fgraph = FunctionGraph(\n",
"        [i for i in graph_inputs((logp,)) if not isinstance(i, Constant)],\n",
"        [logp],\n",
"        clone=False,\n",
"    )\n",
"    fgraph.add_input(loc)\n",
"    # ...then evaluate it at the back-transformed value x = y - loc.\n",
"    fgraph.replace(rv, rv - loc)\n",
"    return fgraph.outputs[0]"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ricardo/Documents/Projects/pymc3/pymc3/aesaraf.py:334: UserWarning: No value variable found for x; the random variable will not be replaced.\n",
" warnings.warn(\n",
"/home/ricardo/Documents/Projects/pymc3-venv/lib/python3.8/site-packages/aesara/graph/fg.py:525: UserWarning: Variable y cannot be replaced; it isn't in the FunctionGraph\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": "[array(-1.61208572), array(-1.73708572), array(-2.11208572)]"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Build y = x * scale from a Normal x created with .dist (so it has no value\n",
"# variable) and a constant scale, then call `_logp` on y's owner op directly\n",
"# with y mapped to itself.\n",
"with pm.Model(check_bounds=False) as m:\n",
" # Alternative experiment: make the scale itself a model random variable.\n",
" # scale = pm.Normal('scale', 0, 10)\n",
" scale = 2\n",
" x = pm.Normal.dist(mu=0, sigma=1); x.name = 'x'\n",
" y = x * scale; y.name = 'y'\n",
"\n",
"logp = _logp(y.owner.op, y, {y:y}, *y.owner.inputs)\n",
"# Variant for a random scale (needs the commented-out pm.Normal above):\n",
"# f1 = aesara.function([x, scale.tag.value_var], logp)\n",
"# [f1(i, 2) for i in range(3)]\n",
"# In the logp graph x is divided by scale, so x acts as the value of y and\n",
"# f1(i) evaluates the logp of y at i.\n",
"f1 = aesara.function([x], logp)\n",
"[f1(i) for i in range(3)]"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "y"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same construction as the previous cell, but register the derived y with the\n",
"# model so that m.logp can evaluate it by name.\n",
"with pm.Model(check_bounds=False) as m:\n",
" # scale = pm.Normal('scale', 0, 10)\n",
" scale = 2\n",
" x = pm.Normal.dist(mu=0, sigma=1); x.name = 'x'\n",
" y = x * scale; y.name = 'y'\n",
"\n",
"m.register_rv(y, 'y')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ricardo/Documents/Projects/pymc3/pymc3/aesaraf.py:334: UserWarning: No value variable found for x; the random variable will not be replaced.\n",
" warnings.warn(\n",
"/home/ricardo/Documents/Projects/pymc3-venv/lib/python3.8/site-packages/aesara/graph/fg.py:525: UserWarning: Variable y cannot be replaced; it isn't in the FunctionGraph\n",
" warnings.warn(\n",
"/home/ricardo/Documents/Projects/pymc3/pymc3/aesaraf.py:334: UserWarning: No value variable found for x; the random variable will not be replaced.\n",
" warnings.warn(\n",
"/home/ricardo/Documents/Projects/pymc3-venv/lib/python3.8/site-packages/aesara/graph/fg.py:525: UserWarning: Variable y cannot be replaced; it isn't in the FunctionGraph\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": "(array(-1.6697918), array(-1.63290134))"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Evaluated twice with the same input: x has no value variable and is not\n",
"# replaced (see the warnings), so the two calls return different values.\n",
"m.logp({'y': 0}), m.logp({'y': 0})"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Elemwise{sub,no_inplace} [id A] '' \n",
" |Elemwise{mul,no_inplace} [id B] '__logp_x' \n",
" | |Elemwise{switch,no_inplace} [id C] '' \n",
" | | |Elemwise{mul,no_inplace} [id D] '' \n",
" | | | |TensorConstant{1} [id E]\n",
" | | | |Elemwise{mul,no_inplace} [id F] '' \n",
" | | | |TensorConstant{1} [id G]\n",
" | | | |Elemwise{gt,no_inplace} [id H] '' \n",
" | | | |Elemwise{mul,no_inplace} [id I] '' \n",
" | | | | |TensorConstant{1.0} [id J]\n",
" | | | | |TensorConstant{1.0} [id K]\n",
" | | | |TensorConstant{0} [id L]\n",
" | | |Elemwise{true_div,no_inplace} [id M] '' \n",
" | | | |Elemwise{add,no_inplace} [id N] '' \n",
" | | | | |Elemwise{mul,no_inplace} [id O] '' \n",
" | | | | | |Elemwise{neg,no_inplace} [id P] '' \n",
" | | | | | | |Elemwise{mul,no_inplace} [id Q] '' \n",
" | | | | | | |TensorConstant{1.0} [id R]\n",
" | | | | | | |Elemwise{pow,no_inplace} [id S] '' \n",
" | | | | | | |TensorConstant{1.0} [id K]\n",
" | | | | | | |TensorConstant{-2.0} [id T]\n",
" | | | | | |Elemwise{pow,no_inplace} [id U] '' \n",
" | | | | | |Elemwise{sub,no_inplace} [id V] '' \n",
" | | | | | | |Elemwise{true_div,no_inplace} [id W] '' \n",
" | | | | | | | |normal_rv.1 [id X] 'x' \n",
" | | | | | | | | |RandomStateSharedVariable(<RandomState(MT19937) at 0x7F70426FD340>) [id Y]\n",
" | | | | | | | | |TensorConstant{[]} [id Z]\n",
" | | | | | | | | |TensorConstant{11} [id BA]\n",
" | | | | | | | | |TensorConstant{0} [id BB]\n",
" | | | | | | | | |TensorConstant{1.0} [id K]\n",
" | | | | | | | |TensorConstant{2} [id BC]\n",
" | | | | | | |TensorConstant{0} [id BB]\n",
" | | | | | |TensorConstant{2} [id BD]\n",
" | | | | |Elemwise{log,no_inplace} [id BE] '' \n",
" | | | | |Elemwise{true_div,no_inplace} [id BF] '' \n",
" | | | | |Elemwise{true_div,no_inplace} [id BG] '' \n",
" | | | | | |Elemwise{mul,no_inplace} [id Q] '' \n",
" | | | | | |TensorConstant{3.141592653589793} [id BH]\n",
" | | | | |TensorConstant{2.0} [id BI]\n",
" | | | |TensorConstant{2.0} [id BJ]\n",
" | | |TensorConstant{-inf} [id BK]\n",
" | |TensorConstant{1.0} [id BL]\n",
" |Elemwise{log,no_inplace} [id BM] '' \n",
" |Elemwise{abs_,no_inplace} [id BN] '' \n",
" |TensorConstant{2} [id BC]\n"
]
}
],
"source": [
"# Print the graph of `logp` built by the `_logp(y.owner.op, ...)` cell; it shows\n",
"# x replaced by x / 2 and the subtracted log(abs(2)) Jacobian term.\n",
"aesara.dprint(logp)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment