Created April 20, 2021 09:54
-
-
Save ricardoV94/262254a37f5efe6f45e851646e038e37 to your computer and use it in GitHub Desktop.
Mul Add logp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import aesara\n", | |
"import aesara.tensor as at\n", | |
"from aesara.graph.basic import graph_inputs\n", | |
"from aesara.graph.fg import FunctionGraph\n", | |
"from aesara.scalar.basic import Add, Constant, Mul\n", | |
"from aesara.tensor.random.op import RandomVariable\n", | |
"from aesara.tensor.var import TensorVariable\n", | |
"\n", | |
"import pymc3 as pm\n", | |
"from pymc3.distributions.logp import _logp, logpt" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"outputs": [], | |
"source": [ | |
"@_logp.register(TensorVariable)\n", | |
"def logp_tensorvariable(op, *args, **kwargs):\n", | |
" if hasattr(op.owner, 'op') and hasattr(op.owner.op, 'scalar_op'):\n", | |
" return _logp(op.owner.op.scalar_op, *args, **kwargs)\n", | |
"\n", | |
"@_logp.register(at.elemwise.Elemwise)\n", | |
"def logp_elemwise(op, *args, **kwargs):\n", | |
" if hasattr(op, 'scalar_op'):\n", | |
" return _logp(op.scalar_op, *args, **kwargs)\n", | |
"\n", | |
"@_logp.register(Mul)\n", | |
"def mul_logp(op, var, rvs_to_values, mul_left, mul_right, **kwargs):\n", | |
"\n", | |
" mul_inputs = [mul_left, mul_right]\n", | |
" if len(mul_inputs) != 2:\n", | |
" raise ValueError(f'Expected 2 inputs but got: {len(mul_inputs)}')\n", | |
"\n", | |
" rv = [\n", | |
" inp for inp in mul_inputs\n", | |
" if inp.owner and isinstance(inp.owner.op, RandomVariable)\n", | |
" and not hasattr(inp.tag, \"value_var\")\n", | |
" ]\n", | |
" scale = [\n", | |
" inp for inp in mul_inputs\n", | |
" if not (inp.owner and isinstance(inp.owner.op, RandomVariable))\n", | |
" or hasattr(inp.tag, \"value_var\")\n", | |
" ]\n", | |
"\n", | |
" if len(rv) != 1:\n", | |
" raise NotImplementedError(\n", | |
" f\"Logp for multiplication requires one unregistered RandomVariable but got {len(rv)}\"\n", | |
" )\n", | |
"\n", | |
" rv = rv[0]\n", | |
" scale = scale[0]\n", | |
" scale = rvs_to_values.get(rv, getattr(scale.tag, \"value_var\", scale))\n", | |
"\n", | |
" new_rvs_to_values = rvs_to_values.copy()\n", | |
" new_rvs_to_values[rv] = rv\n", | |
"\n", | |
" logp = logpt(rv, new_rvs_to_values, **kwargs)\n", | |
" fgraph = FunctionGraph(\n", | |
" [i for i in graph_inputs((logp,)) if not isinstance(i, Constant)],\n", | |
" [logp],\n", | |
" clone=False,\n", | |
" )\n", | |
" fgraph.add_input(scale)\n", | |
" fgraph.replace(rv, rv / scale)\n", | |
" logp = fgraph.outputs[0] - at.log(at.abs_(scale))\n", | |
" return logp\n", | |
"\n", | |
"\n", | |
"@_logp.register(Add)\n", | |
"def add_logp(op, var, rvs_to_values, add_left, add_right, **kwargs):\n", | |
"\n", | |
" add_inputs = [add_left, add_right]\n", | |
" if len(add_inputs) != 2:\n", | |
" raise ValueError(f'Expected 2 inputs but got: {len(add_inputs)}')\n", | |
"\n", | |
" rv = [\n", | |
" inp for inp in add_inputs\n", | |
" if inp.owner and isinstance(inp.owner.op, RandomVariable)\n", | |
" and not hasattr(inp.tag, \"value_var\")\n", | |
" ]\n", | |
"\n", | |
" loc = [\n", | |
" inp for inp in add_inputs\n", | |
" if not (inp.owner and isinstance(inp.owner.op, RandomVariable))\n", | |
" or hasattr(inp.tag, \"value_var\")\n", | |
" ]\n", | |
"\n", | |
" if len(rv) != 1:\n", | |
" raise NotImplementedError(\n", | |
" f\"Logp for addition requires one unregistered RandomVariable but got {len(rv)}\"\n", | |
" )\n", | |
"\n", | |
" rv = rv[0]\n", | |
" loc = loc[0]\n", | |
" loc = rvs_to_values.get(rv, getattr(loc.tag, \"value_var\", loc))\n", | |
"\n", | |
" new_rvs_to_values = rvs_to_values.copy()\n", | |
" new_rvs_to_values[rv] = rv\n", | |
"\n", | |
" logp = logpt(rv, new_rvs_to_values, **kwargs)\n", | |
" fgraph = FunctionGraph(\n", | |
" [i for i in graph_inputs((logp,)) if not isinstance(i, Constant)],\n", | |
" [logp],\n", | |
" clone=False,\n", | |
" )\n", | |
" fgraph.add_input(loc)\n", | |
" fgraph.replace(rv, rv - loc)\n", | |
" logp = fgraph.outputs[0]\n", | |
" return logp" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/ricardo/Documents/Projects/pymc3/pymc3/aesaraf.py:334: UserWarning: No value variable found for x; the random variable will not be replaced.\n", | |
" warnings.warn(\n", | |
"/home/ricardo/Documents/Projects/pymc3-venv/lib/python3.8/site-packages/aesara/graph/fg.py:525: UserWarning: Variable y cannot be replaced; it isn't in the FunctionGraph\n", | |
" warnings.warn(\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": "[array(-1.61208572), array(-1.73708572), array(-2.11208572)]" | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Build y = x * 2 from an unregistered Normal x and compute its logp through\n", | |
"# the _logp dispatch (Elemwise -> scalar Mul -> mul_logp).\n", | |
"with pm.Model(check_bounds=False) as m:\n", | |
"    # Alternative: a random scale instead of the constant 2.\n", | |
"    # scale = pm.Normal('scale', 0, 10)\n", | |
"    scale = 2\n", | |
"    x = pm.Normal.dist(mu=0, sigma=1); x.name = 'x'\n", | |
"    y = x * scale; y.name = 'y'\n", | |
"\n", | |
"# y.owner.op is the Elemwise op of the product; y.owner.inputs are x and scale.\n", | |
"logp = _logp(y.owner.op, y, {y:y}, *y.owner.inputs)\n", | |
"# Variant for the random scale above (needs the pm.Normal line uncommented):\n", | |
"# f1 = aesara.function([x, scale.tag.value_var], logp)\n", | |
"# [f1(i, 2) for i in range(3)]\n", | |
"# mul_logp replaced x by x / scale in the logp graph, so the input x plays the\n", | |
"# role of the value of y here.\n", | |
"f1 = aesara.function([x], logp)\n", | |
"[f1(i) for i in range(3)]" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "y" | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Same construction as in the previous cell, but the derived variable y is\n", | |
"# registered with the model (presumably so m.logp can be evaluated at a value\n", | |
"# of y). Note register_rv is called on m after the model context has exited.\n", | |
"with pm.Model(check_bounds=False) as m:\n", | |
"    # scale = pm.Normal('scale', 0, 10)\n", | |
"    scale = 2\n", | |
"    x = pm.Normal.dist(mu=0, sigma=1); x.name = 'x'\n", | |
"    y = x * scale; y.name = 'y'\n", | |
"\n", | |
"m.register_rv(y, 'y')" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/ricardo/Documents/Projects/pymc3/pymc3/aesaraf.py:334: UserWarning: No value variable found for x; the random variable will not be replaced.\n", | |
" warnings.warn(\n", | |
"/home/ricardo/Documents/Projects/pymc3-venv/lib/python3.8/site-packages/aesara/graph/fg.py:525: UserWarning: Variable y cannot be replaced; it isn't in the FunctionGraph\n", | |
" warnings.warn(\n", | |
"/home/ricardo/Documents/Projects/pymc3/pymc3/aesaraf.py:334: UserWarning: No value variable found for x; the random variable will not be replaced.\n", | |
" warnings.warn(\n", | |
"/home/ricardo/Documents/Projects/pymc3-venv/lib/python3.8/site-packages/aesara/graph/fg.py:525: UserWarning: Variable y cannot be replaced; it isn't in the FunctionGraph\n", | |
" warnings.warn(\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": "(array(-1.6697918), array(-1.63290134))" | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# The same point is evaluated twice and the recorded results differ. Presumably\n", | |
"# y's value is not substituted (see the 'Variable y cannot be replaced' warning)\n", | |
"# and x is left as a random draw, so each call samples a new x. TODO confirm.\n", | |
"m.logp({'y': 0}), m.logp({'y': 0})" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Elemwise{sub,no_inplace} [id A] '' \n", | |
" |Elemwise{mul,no_inplace} [id B] '__logp_x' \n", | |
" | |Elemwise{switch,no_inplace} [id C] '' \n", | |
" | | |Elemwise{mul,no_inplace} [id D] '' \n", | |
" | | | |TensorConstant{1} [id E]\n", | |
" | | | |Elemwise{mul,no_inplace} [id F] '' \n", | |
" | | | |TensorConstant{1} [id G]\n", | |
" | | | |Elemwise{gt,no_inplace} [id H] '' \n", | |
" | | | |Elemwise{mul,no_inplace} [id I] '' \n", | |
" | | | | |TensorConstant{1.0} [id J]\n", | |
" | | | | |TensorConstant{1.0} [id K]\n", | |
" | | | |TensorConstant{0} [id L]\n", | |
" | | |Elemwise{true_div,no_inplace} [id M] '' \n", | |
" | | | |Elemwise{add,no_inplace} [id N] '' \n", | |
" | | | | |Elemwise{mul,no_inplace} [id O] '' \n", | |
" | | | | | |Elemwise{neg,no_inplace} [id P] '' \n", | |
" | | | | | | |Elemwise{mul,no_inplace} [id Q] '' \n", | |
" | | | | | | |TensorConstant{1.0} [id R]\n", | |
" | | | | | | |Elemwise{pow,no_inplace} [id S] '' \n", | |
" | | | | | | |TensorConstant{1.0} [id K]\n", | |
" | | | | | | |TensorConstant{-2.0} [id T]\n", | |
" | | | | | |Elemwise{pow,no_inplace} [id U] '' \n", | |
" | | | | | |Elemwise{sub,no_inplace} [id V] '' \n", | |
" | | | | | | |Elemwise{true_div,no_inplace} [id W] '' \n", | |
" | | | | | | | |normal_rv.1 [id X] 'x' \n", | |
" | | | | | | | | |RandomStateSharedVariable(<RandomState(MT19937) at 0x7F70426FD340>) [id Y]\n", | |
" | | | | | | | | |TensorConstant{[]} [id Z]\n", | |
" | | | | | | | | |TensorConstant{11} [id BA]\n", | |
" | | | | | | | | |TensorConstant{0} [id BB]\n", | |
" | | | | | | | | |TensorConstant{1.0} [id K]\n", | |
" | | | | | | | |TensorConstant{2} [id BC]\n", | |
" | | | | | | |TensorConstant{0} [id BB]\n", | |
" | | | | | |TensorConstant{2} [id BD]\n", | |
" | | | | |Elemwise{log,no_inplace} [id BE] '' \n", | |
" | | | | |Elemwise{true_div,no_inplace} [id BF] '' \n", | |
" | | | | |Elemwise{true_div,no_inplace} [id BG] '' \n", | |
" | | | | | |Elemwise{mul,no_inplace} [id Q] '' \n", | |
" | | | | | |TensorConstant{3.141592653589793} [id BH]\n", | |
" | | | | |TensorConstant{2.0} [id BI]\n", | |
" | | | |TensorConstant{2.0} [id BJ]\n", | |
" | | |TensorConstant{-inf} [id BK]\n", | |
" | |TensorConstant{1.0} [id BL]\n", | |
" |Elemwise{log,no_inplace} [id BM] '' \n", | |
" |Elemwise{abs_,no_inplace} [id BN] '' \n", | |
" |TensorConstant{2} [id BC]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Show the logp graph from the first example: the Normal logp evaluated at\n", | |
"# x / 2, minus log(abs(2)).\n", | |
"aesara.dprint(logp)" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"outputs": [], | |
"source": [], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.