Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Last active November 18, 2023 10:09
Show Gist options
  • Save ricardoV94/99c53fbb8b2e9a68e1b2c6c4d761eaf4 to your computer and use it in GitHub Desktop.
Save ricardoV94/99c53fbb8b2e9a68e1b2c6c4d761eaf4 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 17,
"id": "fc9b5316",
"metadata": {},
"outputs": [],
"source": [
"import pytensor\n",
"import pytensor.tensor as pt\n",
"from pytensor.graph.fg import FunctionGraph\n",
"from pytensor.graph.replace import vectorize_graph\n",
"from pytensor.graph.replace import _vectorize_node\n",
"from pytensor.graph.rewriting.basic import MergeOptimizer\n",
"\n",
"import numpy as np\n",
"import pymc as pm\n",
"\n",
"from pymc.distributions import transforms as tr\n",
"\n",
"from pymc.model.fgraph import (\n",
" fgraph_from_model,\n",
" model_from_fgraph,\n",
" ModelFreeRV,\n",
" ModelObservedRV,\n",
" ModelNamed,\n",
" model_free_rv,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4acb13f5",
"metadata": {},
"outputs": [],
"source": [
"@_vectorize_node.register(ModelNamed)\n",
"def vectorize_model_named(op, node, new_value):\n",
" name = node.outputs[0].name\n",
" new_value.name = name\n",
" new_model_named = op(new_value)\n",
" new_model_named.name = name \n",
" return new_model_named.owner\n",
"\n",
"@_vectorize_node.register(ModelObservedRV)\n",
"@_vectorize_node.register(ModelFreeRV)\n",
"def vectorize_model_free_rv(op, node, new_rv, new_value):\n",
" old_rv, old_value = node.inputs\n",
" \n",
"# print(node)\n",
" \n",
" # Check what changed\n",
" batch_dims = new_rv.ndim - old_rv.ndim\n",
" batch_shape = ()\n",
" if batch_dims:\n",
" batch_type_shape = new_rv.type.shape[:batch_dims]\n",
" batch_shape = new_rv.shape[:batch_dims]\n",
" \n",
" if new_value.ndim > old_value.ndim:\n",
" # Raise NotImplementedError for partial batching \n",
" if new_value.ndim - old_value.ndim != batch_dims:\n",
" raise NotImplementedError()\n",
" else:\n",
" # Batch values\n",
" if isinstance(op, ModelFreeRV):\n",
" new_type_shape = tuple(batch_type_shape) + new_value.type.shape\n",
" new_value = new_value.type.clone(shape=new_type_shape)(old_value.name)\n",
" else:\n",
" new_shape = tuple(batch_shape) + tuple(new_value.shape)\n",
" new_value = pt.broadcast_to(new_value, new_shape)\n",
" \n",
" new_rv.name = old_rv.name\n",
" out = op(new_rv, new_value)\n",
" out.name = old_rv.name\n",
" return out.owner"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "fafba8e8",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 3.0.0 (20220315.2325)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"450pt\" height=\"433pt\"\n",
" viewBox=\"0.00 0.00 450.21 432.86\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 428.86)\">\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-428.86 446.21,-428.86 446.21,4 -4,4\"/>\n",
"<g id=\"clust1\" class=\"cluster\">\n",
"<title>cluster10</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M97.71,-8C97.71,-8 321.71,-8 321.71,-8 327.71,-8 333.71,-14 333.71,-20 333.71,-20 333.71,-309.91 333.71,-309.91 333.71,-315.91 327.71,-321.91 321.71,-321.91 321.71,-321.91 97.71,-321.91 97.71,-321.91 91.71,-321.91 85.71,-315.91 85.71,-309.91 85.71,-309.91 85.71,-20 85.71,-20 85.71,-14 91.71,-8 97.71,-8\"/>\n",
"<text text-anchor=\"middle\" x=\"316.21\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n",
"</g>\n",
"<!-- y -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>y</title>\n",
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M313.21,-92C313.21,-92 222.21,-92 222.21,-92 216.21,-92 210.21,-86 210.21,-80 210.21,-80 210.21,-51 210.21,-51 210.21,-45 216.21,-39 222.21,-39 222.21,-39 313.21,-39 313.21,-39 319.21,-39 325.21,-45 325.21,-51 325.21,-51 325.21,-80 325.21,-80 325.21,-86 319.21,-92 313.21,-92\"/>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-76.8\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-61.8\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-46.8\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n",
"</g>\n",
"<!-- x -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>x</title>\n",
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M313.21,-302.93C313.21,-302.93 222.21,-302.93 222.21,-302.93 216.21,-302.93 210.21,-296.93 210.21,-290.93 210.21,-290.93 210.21,-261.93 210.21,-261.93 210.21,-255.93 216.21,-249.93 222.21,-249.93 222.21,-249.93 313.21,-249.93 313.21,-249.93 319.21,-249.93 325.21,-255.93 325.21,-261.93 325.21,-261.93 325.21,-290.93 325.21,-290.93 325.21,-296.93 319.21,-302.93 313.21,-302.93\"/>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">x</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n",
"</g>\n",
"<!-- obs -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>obs</title>\n",
"<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"267.71\" cy=\"-165.48\" rx=\"57.97\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-176.78\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-161.78\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-146.78\" font-family=\"Times,serif\" font-size=\"14.00\">Bernoulli</text>\n",
"</g>\n",
"<!-- x&#45;&gt;obs -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>x&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M267.71,-249.89C267.71,-238.98 267.71,-225.89 267.71,-213.35\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"271.21,-213 267.71,-203 264.21,-213 271.21,-213\"/>\n",
"</g>\n",
"<!-- b1 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>b1</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"142.71\" cy=\"-276.43\" rx=\"49.49\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">b1</text>\n",
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- b1&#45;&gt;obs -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>b1&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M174.58,-247.65C190.26,-233.99 209.36,-217.34 226.17,-202.68\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"228.47,-205.32 233.71,-196.11 223.87,-200.04 228.47,-205.32\"/>\n",
"</g>\n",
"<!-- obs&#45;&gt;y -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>obs&#45;&gt;y</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M267.71,-127.99C267.71,-119.58 267.71,-110.63 267.71,-102.25\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"271.21,-102.01 267.71,-92.01 264.21,-102.01 271.21,-102.01\"/>\n",
"</g>\n",
"<!-- b1_std -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>b1_std</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"70.71\" cy=\"-387.38\" rx=\"70.92\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-398.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_std</text>\n",
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-383.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-368.68\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n",
"</g>\n",
"<!-- b1_std&#45;&gt;b1 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>b1_std&#45;&gt;b1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M93.69,-351.61C100.57,-341.19 108.19,-329.66 115.33,-318.86\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"118.34,-320.66 120.93,-310.39 112.5,-316.8 118.34,-320.66\"/>\n",
"</g>\n",
"<!-- b1_mean -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>b1_mean</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"215.71\" cy=\"-387.38\" rx=\"56.64\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-398.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_mean</text>\n",
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-383.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-368.68\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- b1_mean&#45;&gt;b1 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>b1_mean&#45;&gt;b1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M193.21,-352.8C186,-342.04 177.93,-330 170.39,-318.75\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"173.23,-316.7 164.76,-310.34 167.42,-320.59 173.23,-316.7\"/>\n",
"</g>\n",
"<!-- b0 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>b0</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"392.71\" cy=\"-276.43\" rx=\"49.49\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">b0</text>\n",
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- b0&#45;&gt;obs -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>b0&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M360.84,-247.65C345.16,-233.99 326.06,-217.34 309.25,-202.68\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"311.55,-200.04 301.71,-196.11 306.95,-205.32 311.55,-200.04\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7ff4b7bbe620>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"with pm.Model() as target_model:\n",
" x = pm.ConstantData(\"x\", np.arange(10, dtype=float))\n",
" y = pm.ConstantData(\"y\", np.ones(10, dtype=int))\n",
" \n",
" b1_mean = pm.Normal(\"b1_mean\", 0, 1)\n",
" b1_std = pm.HalfNormal(\"b1_std\", 1)\n",
" b0 = pm.Normal(\"b0\")\n",
" b1 = pm.Normal(\"b1\", b1_mean, b1_std, shape=(10,))\n",
" \n",
" logit_p = b0 + b1 * x\n",
" pm.Bernoulli(\"obs\", logit_p=logit_p, observed=y)\n",
" \n",
"target_model.to_graphviz()"
]
},
{
"cell_type": "markdown",
"id": "cc09069c",
"metadata": {},
"source": [
"## From a core model with hyperpriors"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d843a303",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 3.0.0 (20220315.2325)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"357pt\" height=\"394pt\"\n",
" viewBox=\"0.00 0.00 357.00 393.86\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 389.86)\">\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-389.86 353,-389.86 353,4 -4,4\"/>\n",
"<!-- x -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>x</title>\n",
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M103,-263.93C103,-263.93 12,-263.93 12,-263.93 6,-263.93 0,-257.93 0,-251.93 0,-251.93 0,-222.93 0,-222.93 0,-216.93 6,-210.93 12,-210.93 12,-210.93 103,-210.93 103,-210.93 109,-210.93 115,-216.93 115,-222.93 115,-222.93 115,-251.93 115,-251.93 115,-257.93 109,-263.93 103,-263.93\"/>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-248.73\" font-family=\"Times,serif\" font-size=\"14.00\">x</text>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-233.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-218.73\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n",
"</g>\n",
"<!-- obs -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>obs</title>\n",
"<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"182.5\" cy=\"-126.48\" rx=\"57.97\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-137.78\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-122.78\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-107.78\" font-family=\"Times,serif\" font-size=\"14.00\">Bernoulli</text>\n",
"</g>\n",
"<!-- x&#45;&gt;obs -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>x&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M86.81,-210.89C102.83,-196.92 122.94,-179.39 140.6,-164\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"143.27,-166.31 148.51,-157.1 138.67,-161.04 143.27,-166.31\"/>\n",
"</g>\n",
"<!-- b1 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>b1</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"182.5\" cy=\"-237.43\" rx=\"49.49\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-248.73\" font-family=\"Times,serif\" font-size=\"14.00\">b1</text>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-233.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-218.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- b1&#45;&gt;obs -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>b1&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M182.5,-199.85C182.5,-191.67 182.5,-182.89 182.5,-174.37\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"186,-174.15 182.5,-164.15 179,-174.15 186,-174.15\"/>\n",
"</g>\n",
"<!-- y -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>y</title>\n",
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M228,-53C228,-53 137,-53 137,-53 131,-53 125,-47 125,-41 125,-41 125,-12 125,-12 125,-6 131,0 137,0 137,0 228,0 228,0 234,0 240,-6 240,-12 240,-12 240,-41 240,-41 240,-47 234,-53 228,-53\"/>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-37.8\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-7.8\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n",
"</g>\n",
"<!-- b0 -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>b0</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"299.5\" cy=\"-237.43\" rx=\"49.49\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"299.5\" y=\"-248.73\" font-family=\"Times,serif\" font-size=\"14.00\">b0</text>\n",
"<text text-anchor=\"middle\" x=\"299.5\" y=\"-233.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"299.5\" y=\"-218.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- b0&#45;&gt;obs -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>b0&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M268.76,-207.8C254.69,-194.7 237.84,-179.01 222.77,-164.98\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"224.78,-162.07 215.08,-157.82 220.01,-167.19 224.78,-162.07\"/>\n",
"</g>\n",
"<!-- b1_std -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>b1_std</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"110.5\" cy=\"-348.38\" rx=\"70.92\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"110.5\" y=\"-359.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_std</text>\n",
"<text text-anchor=\"middle\" x=\"110.5\" y=\"-344.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"110.5\" y=\"-329.68\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n",
"</g>\n",
"<!-- b1_std&#45;&gt;b1 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>b1_std&#45;&gt;b1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M133.48,-312.61C140.36,-302.19 147.98,-290.66 155.12,-279.86\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"158.13,-281.66 160.72,-271.39 152.29,-277.8 158.13,-281.66\"/>\n",
"</g>\n",
"<!-- b1_mean -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>b1_mean</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"255.5\" cy=\"-348.38\" rx=\"56.64\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"255.5\" y=\"-359.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_mean</text>\n",
"<text text-anchor=\"middle\" x=\"255.5\" y=\"-344.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"255.5\" y=\"-329.68\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- b1_mean&#45;&gt;b1 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>b1_mean&#45;&gt;b1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M233,-313.8C225.79,-303.04 217.72,-291 210.18,-279.75\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"213.02,-277.7 204.55,-271.34 207.21,-281.59 213.02,-277.7\"/>\n",
"</g>\n",
"<!-- obs&#45;&gt;y -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>obs&#45;&gt;y</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M182.5,-88.99C182.5,-80.58 182.5,-71.63 182.5,-63.25\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"186,-63.01 182.5,-53.01 179,-63.01 186,-63.01\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7ff4b77dfd00>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"with pm.Model() as core_model:\n",
" x = pm.ConstantData(\"x\", 0.0)\n",
" y = pm.ConstantData(\"y\", 1)\n",
" \n",
" b1_mean = pm.Normal(\"b1_mean\", 0, 1)\n",
" b1_std = pm.HalfNormal(\"b1_std\", 1)\n",
" \n",
" b0 = pm.Normal(\"b0\")\n",
" b1 = pm.Normal(\"b1\", b1_mean, b1_std, shape=())\n",
"\n",
" \n",
" logit_p = b0 + b1 * x\n",
" obs = pm.Bernoulli(\"obs\", logit_p=logit_p, observed=y)\n",
" \n",
"core_model.to_graphviz()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f7d1cd8c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ModelObservedRV{transform=None} [id A] 'obs' 15\n",
" ├─ bernoulli_rv{0, (0,), int64, False}.1 [id B] 'obs' 14\n",
" │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A8580>) [id C]\n",
" │ ├─ [] [id D]\n",
" │ ├─ 4 [id E]\n",
" │ └─ Sigmoid [id F] 13\n",
" │ └─ Add [id G] 12\n",
" │ ├─ ModelFreeRV{transform=None} [id H] 'b0' 11\n",
" │ │ ├─ normal_rv{0, (0, 0), floatX, False}.1 [id I] 'b0' 10\n",
" │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A9700>) [id J]\n",
" │ │ │ ├─ [] [id K]\n",
" │ │ │ ├─ 11 [id L]\n",
" │ │ │ ├─ 0 [id M]\n",
" │ │ │ └─ 1.0 [id N]\n",
" │ │ └─ b0 [id O]\n",
" │ └─ Mul [id P] 9\n",
" │ ├─ ModelFreeRV{transform=None} [id Q] 'b1' 8\n",
" │ │ ├─ normal_rv{0, (0, 0), floatX, False}.1 [id R] 'b1' 7\n",
" │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A92A0>) [id S]\n",
" │ │ │ ├─ [] [id T]\n",
" │ │ │ ├─ 11 [id U]\n",
" │ │ │ ├─ ModelFreeRV{transform=None} [id V] 'b1_mean' 6\n",
" │ │ │ │ ├─ normal_rv{0, (0, 0), floatX, False}.1 [id W] 'b1_mean' 5\n",
" │ │ │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A9A80>) [id X]\n",
" │ │ │ │ │ ├─ [] [id Y]\n",
" │ │ │ │ │ ├─ 11 [id Z]\n",
" │ │ │ │ │ ├─ 0 [id BA]\n",
" │ │ │ │ │ └─ 1.0 [id BB]\n",
" │ │ │ │ └─ b1_mean [id BC]\n",
" │ │ │ └─ ModelFreeRV{transform=<pymc.logprob.transforms.LogTransform object at 0x7ff4c45ba6b0>} [id BD] 'b1_std' 4\n",
" │ │ │ ├─ halfnormal_rv{0, (0, 0), floatX, False}.1 [id BE] 'b1_std' 3\n",
" │ │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A9540>) [id BF]\n",
" │ │ │ │ ├─ [] [id BG]\n",
" │ │ │ │ ├─ 11 [id BH]\n",
" │ │ │ │ ├─ 0.0 [id BI]\n",
" │ │ │ │ └─ 1.0 [id BJ]\n",
" │ │ │ └─ b1_std_log__ [id BK]\n",
" │ │ └─ b1 [id BL]\n",
" │ └─ ModelNamed [id BM] 'x' 2\n",
" │ └─ x{0.0} [id BN]\n",
" └─ Cast{int64} [id BO] 1\n",
" └─ ModelNamed [id BP] 'y' 0\n",
" └─ y{1.0} [id BQ]\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7ff53cef9c60>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"_, memo = fgraph_from_model(core_model, inlined_views=True)\n",
"# The vectorize dispatch for FreeRV creates a new, unique value variable every time it is called.\n",
"# The standard representation includes each FreeRV both as an output and wherever it is used in the graph,\n",
"# so vectorize would find it at least twice.\n",
"core_fgraph = FunctionGraph(outputs=[memo[obs]], clone=False)\n",
"pytensor.dprint(core_fgraph, print_type=False)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "84222351",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ModelObservedRV{transform=None} [id A] 'obs' 16\n",
" ├─ bernoulli_rv{0, (0,), int64, False}.1 [id B] 'obs' 15\n",
" │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A8580>) [id C]\n",
" │ ├─ [] [id D]\n",
" │ ├─ 4 [id E]\n",
" │ └─ Sigmoid [id F] 14\n",
" │ └─ Add [id G] 13\n",
" │ ├─ ExpandDims{axis=0} [id H] 12\n",
" │ │ └─ ModelFreeRV{transform=None} [id I] 'b0' 11\n",
" │ │ ├─ normal_rv{0, (0, 0), floatX, False}.1 [id J] 'b0' 10\n",
" │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A9700>) [id K]\n",
" │ │ │ ├─ [] [id D]\n",
" │ │ │ ├─ 11 [id L]\n",
" │ │ │ ├─ 0 [id M]\n",
" │ │ │ └─ 1.0 [id N]\n",
" │ │ └─ b0 [id O]\n",
" │ └─ Mul [id P] 9\n",
" │ ├─ ModelFreeRV{transform=None} [id Q] 'b1' 8\n",
" │ │ ├─ normal_rv{0, (0, 0), floatX, False}.1 [id R] 'b1' 7\n",
" │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A8820>) [id S]\n",
" │ │ │ ├─ [10] [id T]\n",
" │ │ │ ├─ 11 [id L]\n",
" │ │ │ ├─ ModelFreeRV{transform=None} [id U] 'b1_mean' 6\n",
" │ │ │ │ ├─ normal_rv{0, (0, 0), floatX, False}.1 [id V] 'b1_mean' 5\n",
" │ │ │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A9A80>) [id W]\n",
" │ │ │ │ │ ├─ [] [id D]\n",
" │ │ │ │ │ ├─ 11 [id L]\n",
" │ │ │ │ │ ├─ 0 [id M]\n",
" │ │ │ │ │ └─ 1.0 [id N]\n",
" │ │ │ │ └─ b1_mean [id X]\n",
" │ │ │ └─ ModelFreeRV{transform=<pymc.logprob.transforms.LogTransform object at 0x7ff4c45ba6b0>} [id Y] 'b1_std' 4\n",
" │ │ │ ├─ halfnormal_rv{0, (0, 0), floatX, False}.1 [id Z] 'b1_std' 3\n",
" │ │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A9540>) [id BA]\n",
" │ │ │ │ ├─ [] [id D]\n",
" │ │ │ │ ├─ 11 [id L]\n",
" │ │ │ │ ├─ 0.0 [id BB]\n",
" │ │ │ │ └─ 1.0 [id N]\n",
" │ │ │ └─ b1_std_log__ [id BC]\n",
" │ │ └─ b1 [id BD]\n",
" │ └─ ModelNamed [id BE] 'x' 2\n",
" │ └─ x{[0. 1. 2. ... 7. 8. 9.]} [id BF]\n",
" └─ Cast{int64} [id BG] 1\n",
" └─ ModelNamed [id BH] 'y' 0\n",
" └─ y{[1 1 1 1 1 ... 1 1 1 1 1]} [id BI]\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7ff53cef9c60>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"replace = {\n",
" memo[b1].owner.inputs[0]: pm.Normal.dist(memo[b1_mean], memo[b1_std], shape=(10,)),\n",
" memo[x].owner.inputs[0]: pt.as_tensor(np.arange(10, dtype=float)),\n",
" memo[y].owner.inputs[0]: pt.as_tensor(np.ones(10, dtype=int))\n",
"}\n",
"\n",
"vect_fgraph = FunctionGraph(\n",
" outputs=vectorize_graph(core_fgraph.outputs, replace=replace)\n",
")\n",
"MergeOptimizer().rewrite(vect_fgraph)\n",
"pytensor.dprint(vect_fgraph, print_type=False)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1ecf5dbf",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 3.0.0 (20220315.2325)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"450pt\" height=\"433pt\"\n",
" viewBox=\"0.00 0.00 450.21 432.86\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 428.86)\">\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-428.86 446.21,-428.86 446.21,4 -4,4\"/>\n",
"<g id=\"clust1\" class=\"cluster\">\n",
"<title>cluster10</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M97.71,-8C97.71,-8 321.71,-8 321.71,-8 327.71,-8 333.71,-14 333.71,-20 333.71,-20 333.71,-309.91 333.71,-309.91 333.71,-315.91 327.71,-321.91 321.71,-321.91 321.71,-321.91 97.71,-321.91 97.71,-321.91 91.71,-321.91 85.71,-315.91 85.71,-309.91 85.71,-309.91 85.71,-20 85.71,-20 85.71,-14 91.71,-8 97.71,-8\"/>\n",
"<text text-anchor=\"middle\" x=\"316.21\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n",
"</g>\n",
"<!-- y -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>y</title>\n",
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M313.21,-92C313.21,-92 222.21,-92 222.21,-92 216.21,-92 210.21,-86 210.21,-80 210.21,-80 210.21,-51 210.21,-51 210.21,-45 216.21,-39 222.21,-39 222.21,-39 313.21,-39 313.21,-39 319.21,-39 325.21,-45 325.21,-51 325.21,-51 325.21,-80 325.21,-80 325.21,-86 319.21,-92 313.21,-92\"/>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-76.8\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-61.8\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-46.8\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n",
"</g>\n",
"<!-- x -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>x</title>\n",
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M313.21,-302.93C313.21,-302.93 222.21,-302.93 222.21,-302.93 216.21,-302.93 210.21,-296.93 210.21,-290.93 210.21,-290.93 210.21,-261.93 210.21,-261.93 210.21,-255.93 216.21,-249.93 222.21,-249.93 222.21,-249.93 313.21,-249.93 313.21,-249.93 319.21,-249.93 325.21,-255.93 325.21,-261.93 325.21,-261.93 325.21,-290.93 325.21,-290.93 325.21,-296.93 319.21,-302.93 313.21,-302.93\"/>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">x</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n",
"</g>\n",
"<!-- obs -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>obs</title>\n",
"<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"267.71\" cy=\"-165.48\" rx=\"57.97\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-176.78\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-161.78\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-146.78\" font-family=\"Times,serif\" font-size=\"14.00\">Bernoulli</text>\n",
"</g>\n",
"<!-- x&#45;&gt;obs -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>x&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M267.71,-249.89C267.71,-238.98 267.71,-225.89 267.71,-213.35\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"271.21,-213 267.71,-203 264.21,-213 271.21,-213\"/>\n",
"</g>\n",
"<!-- b1 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>b1</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"142.71\" cy=\"-276.43\" rx=\"49.49\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">b1</text>\n",
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- b1&#45;&gt;obs -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>b1&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M174.58,-247.65C190.26,-233.99 209.36,-217.34 226.17,-202.68\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"228.47,-205.32 233.71,-196.11 223.87,-200.04 228.47,-205.32\"/>\n",
"</g>\n",
"<!-- obs&#45;&gt;y -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>obs&#45;&gt;y</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M267.71,-127.99C267.71,-119.58 267.71,-110.63 267.71,-102.25\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"271.21,-102.01 267.71,-92.01 264.21,-102.01 271.21,-102.01\"/>\n",
"</g>\n",
"<!-- b1_std -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>b1_std</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"70.71\" cy=\"-387.38\" rx=\"70.92\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-398.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_std</text>\n",
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-383.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-368.68\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n",
"</g>\n",
"<!-- b1_std&#45;&gt;b1 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>b1_std&#45;&gt;b1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M93.69,-351.61C100.57,-341.19 108.19,-329.66 115.33,-318.86\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"118.34,-320.66 120.93,-310.39 112.5,-316.8 118.34,-320.66\"/>\n",
"</g>\n",
"<!-- b1_mean -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>b1_mean</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"215.71\" cy=\"-387.38\" rx=\"56.64\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-398.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_mean</text>\n",
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-383.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-368.68\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- b1_mean&#45;&gt;b1 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>b1_mean&#45;&gt;b1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M193.21,-352.8C186,-342.04 177.93,-330 170.39,-318.75\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"173.23,-316.7 164.76,-310.34 167.42,-320.59 173.23,-316.7\"/>\n",
"</g>\n",
"<!-- b0 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>b0</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"392.71\" cy=\"-276.43\" rx=\"49.49\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">b0</text>\n",
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- b0&#45;&gt;obs -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>b0&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M360.84,-247.65C345.16,-233.99 326.06,-217.34 309.25,-202.68\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"311.55,-200.04 301.71,-196.11 306.95,-205.32 311.55,-200.04\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7ff4b777a2c0>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vect_model = model_from_fgraph(vect_fgraph)\n",
"vect_model.to_graphviz()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cb81b3b2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('b0', -0.92),\n",
" ('b1', -9.19),\n",
" ('b1_mean', -0.92),\n",
" ('b1_std', -0.73),\n",
" ('obs', -6.93)]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sorted(vect_model.point_logps().items())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "35aeb105",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('b0', -0.92),\n",
" ('b1', -9.19),\n",
" ('b1_mean', -0.92),\n",
" ('b1_std', -0.73),\n",
" ('obs', -6.93)]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sorted(target_model.point_logps().items())"
]
},
{
"cell_type": "markdown",
"id": "88546c7d",
"metadata": {},
"source": [
"## From a core model without hyperpriors"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "dcf48a7e",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 3.0.0 (20220315.2325)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"357pt\" height=\"283pt\"\n",
" viewBox=\"0.00 0.00 357.00 282.91\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 278.91)\">\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-278.91 353,-278.91 353,4 -4,4\"/>\n",
"<!-- x -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>x</title>\n",
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M103,-263.93C103,-263.93 12,-263.93 12,-263.93 6,-263.93 0,-257.93 0,-251.93 0,-251.93 0,-222.93 0,-222.93 0,-216.93 6,-210.93 12,-210.93 12,-210.93 103,-210.93 103,-210.93 109,-210.93 115,-216.93 115,-222.93 115,-222.93 115,-251.93 115,-251.93 115,-257.93 109,-263.93 103,-263.93\"/>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-248.73\" font-family=\"Times,serif\" font-size=\"14.00\">x</text>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-233.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-218.73\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n",
"</g>\n",
"<!-- obs -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>obs</title>\n",
"<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"182.5\" cy=\"-126.48\" rx=\"57.97\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-137.78\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-122.78\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-107.78\" font-family=\"Times,serif\" font-size=\"14.00\">Bernoulli</text>\n",
"</g>\n",
"<!-- x&#45;&gt;obs -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>x&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M86.81,-210.89C102.83,-196.92 122.94,-179.39 140.6,-164\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"143.27,-166.31 148.51,-157.1 138.67,-161.04 143.27,-166.31\"/>\n",
"</g>\n",
"<!-- b1 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>b1</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"182.5\" cy=\"-237.43\" rx=\"49.49\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-248.73\" font-family=\"Times,serif\" font-size=\"14.00\">b1</text>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-233.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-218.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- b1&#45;&gt;obs -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>b1&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M182.5,-199.85C182.5,-191.67 182.5,-182.89 182.5,-174.37\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"186,-174.15 182.5,-164.15 179,-174.15 186,-174.15\"/>\n",
"</g>\n",
"<!-- y -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>y</title>\n",
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M228,-53C228,-53 137,-53 137,-53 131,-53 125,-47 125,-41 125,-41 125,-12 125,-12 125,-6 131,0 137,0 137,0 228,0 228,0 234,0 240,-6 240,-12 240,-12 240,-41 240,-41 240,-47 234,-53 228,-53\"/>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-37.8\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-7.8\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n",
"</g>\n",
"<!-- b0 -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>b0</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"299.5\" cy=\"-237.43\" rx=\"49.49\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"299.5\" y=\"-248.73\" font-family=\"Times,serif\" font-size=\"14.00\">b0</text>\n",
"<text text-anchor=\"middle\" x=\"299.5\" y=\"-233.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"299.5\" y=\"-218.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- b0&#45;&gt;obs -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>b0&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M268.76,-207.8C254.69,-194.7 237.84,-179.01 222.77,-164.98\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"224.78,-162.07 215.08,-157.82 220.01,-167.19 224.78,-162.07\"/>\n",
"</g>\n",
"<!-- obs&#45;&gt;y -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>obs&#45;&gt;y</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M182.5,-88.99C182.5,-80.58 182.5,-71.63 182.5,-63.25\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"186,-63.01 182.5,-53.01 179,-63.01 186,-63.01\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7ff4b6ff4f10>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"with pm.Model() as core_model:\n",
" x = pm.ConstantData(\"x\", 0.0)\n",
" y = pm.ConstantData(\"y\", 1)\n",
" \n",
" b0 = pm.Normal(\"b0\")\n",
" b1 = pm.Normal(\"b1\", b1_mean, b1_std, shape=())\n",
"\n",
" logit_p = b0 + b1 * x\n",
" obs = pm.Bernoulli(\"obs\", logit_p=logit_p, observed=y)\n",
" \n",
"core_model.to_graphviz()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "b6928647",
"metadata": {},
"outputs": [],
"source": [
"_, memo = fgraph_from_model(core_model, inlined_views=True)\n",
"# The vectorize dispatch for FreeRV creates new unique value variables every time it is called.\n",
"# The standard representation includes each FreeRV both as an output and wherever it is used in the graph,\n",
"# so vectorize would find it at least twice. We therefore keep only `obs` as the output.\n",
"core_fgraph = FunctionGraph(outputs=[memo[obs]], clone=False)\n",
"# pytensor.dprint(core_fgraph, print_type=False)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "6ee05fb4",
"metadata": {},
"outputs": [],
"source": [
"hyper_priors = (\n",
" model_free_rv(pm.Normal.dist(name=\"b1_mean\"), pt.scalar(name=\"b1_mean\"), None),\n",
" model_free_rv(pm.HalfNormal.dist(name=\"b1_std\"), pt.scalar(name=\"b1_std_log__\"), tr.log),\n",
")\n",
"replace = {\n",
" memo[b1].owner.inputs[0]: pm.Normal.dist(*hyper_priors, shape=(10,)),\n",
" memo[x].owner.inputs[0]: pt.as_tensor(np.arange(10, dtype=float)),\n",
" memo[y].owner.inputs[0]: pt.as_tensor(np.ones(10, dtype=int))\n",
"}\n",
"\n",
"vect_fgraph = FunctionGraph(\n",
" outputs=vectorize_graph(core_fgraph.outputs, replace=replace)\n",
")\n",
"MergeOptimizer().rewrite(vect_fgraph);\n",
"# pytensor.dprint(vect_fgraph, print_type=False)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "bcc32a32",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 3.0.0 (20220315.2325)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"450pt\" height=\"433pt\"\n",
" viewBox=\"0.00 0.00 450.21 432.86\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 428.86)\">\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-428.86 446.21,-428.86 446.21,4 -4,4\"/>\n",
"<g id=\"clust1\" class=\"cluster\">\n",
"<title>cluster10</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M97.71,-8C97.71,-8 321.71,-8 321.71,-8 327.71,-8 333.71,-14 333.71,-20 333.71,-20 333.71,-309.91 333.71,-309.91 333.71,-315.91 327.71,-321.91 321.71,-321.91 321.71,-321.91 97.71,-321.91 97.71,-321.91 91.71,-321.91 85.71,-315.91 85.71,-309.91 85.71,-309.91 85.71,-20 85.71,-20 85.71,-14 91.71,-8 97.71,-8\"/>\n",
"<text text-anchor=\"middle\" x=\"316.21\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n",
"</g>\n",
"<!-- y -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>y</title>\n",
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M313.21,-92C313.21,-92 222.21,-92 222.21,-92 216.21,-92 210.21,-86 210.21,-80 210.21,-80 210.21,-51 210.21,-51 210.21,-45 216.21,-39 222.21,-39 222.21,-39 313.21,-39 313.21,-39 319.21,-39 325.21,-45 325.21,-51 325.21,-51 325.21,-80 325.21,-80 325.21,-86 319.21,-92 313.21,-92\"/>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-76.8\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-61.8\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-46.8\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n",
"</g>\n",
"<!-- x -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>x</title>\n",
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M313.21,-302.93C313.21,-302.93 222.21,-302.93 222.21,-302.93 216.21,-302.93 210.21,-296.93 210.21,-290.93 210.21,-290.93 210.21,-261.93 210.21,-261.93 210.21,-255.93 216.21,-249.93 222.21,-249.93 222.21,-249.93 313.21,-249.93 313.21,-249.93 319.21,-249.93 325.21,-255.93 325.21,-261.93 325.21,-261.93 325.21,-290.93 325.21,-290.93 325.21,-296.93 319.21,-302.93 313.21,-302.93\"/>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">x</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n",
"</g>\n",
"<!-- obs -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>obs</title>\n",
"<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"267.71\" cy=\"-165.48\" rx=\"57.97\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-176.78\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-161.78\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-146.78\" font-family=\"Times,serif\" font-size=\"14.00\">Bernoulli</text>\n",
"</g>\n",
"<!-- x&#45;&gt;obs -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>x&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M267.71,-249.89C267.71,-238.98 267.71,-225.89 267.71,-213.35\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"271.21,-213 267.71,-203 264.21,-213 271.21,-213\"/>\n",
"</g>\n",
"<!-- b1 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>b1</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"142.71\" cy=\"-276.43\" rx=\"49.49\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">b1</text>\n",
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- b1&#45;&gt;obs -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>b1&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M174.58,-247.65C190.26,-233.99 209.36,-217.34 226.17,-202.68\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"228.47,-205.32 233.71,-196.11 223.87,-200.04 228.47,-205.32\"/>\n",
"</g>\n",
"<!-- obs&#45;&gt;y -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>obs&#45;&gt;y</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M267.71,-127.99C267.71,-119.58 267.71,-110.63 267.71,-102.25\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"271.21,-102.01 267.71,-92.01 264.21,-102.01 271.21,-102.01\"/>\n",
"</g>\n",
"<!-- b1_std -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>b1_std</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"70.71\" cy=\"-387.38\" rx=\"70.92\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-398.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_std</text>\n",
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-383.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-368.68\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n",
"</g>\n",
"<!-- b1_std&#45;&gt;b1 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>b1_std&#45;&gt;b1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M93.69,-351.61C100.57,-341.19 108.19,-329.66 115.33,-318.86\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"118.34,-320.66 120.93,-310.39 112.5,-316.8 118.34,-320.66\"/>\n",
"</g>\n",
"<!-- b1_mean -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>b1_mean</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"215.71\" cy=\"-387.38\" rx=\"56.64\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-398.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_mean</text>\n",
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-383.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-368.68\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- b1_mean&#45;&gt;b1 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>b1_mean&#45;&gt;b1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M193.21,-352.8C186,-342.04 177.93,-330 170.39,-318.75\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"173.23,-316.7 164.76,-310.34 167.42,-320.59 173.23,-316.7\"/>\n",
"</g>\n",
"<!-- b0 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>b0</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"392.71\" cy=\"-276.43\" rx=\"49.49\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">b0</text>\n",
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- b0&#45;&gt;obs -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>b0&#45;&gt;obs</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M360.84,-247.65C345.16,-233.99 326.06,-217.34 309.25,-202.68\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"311.55,-200.04 301.71,-196.11 306.95,-205.32 311.55,-200.04\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7ff4b485e8f0>"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vect_model = model_from_fgraph(vect_fgraph)\n",
"vect_model.to_graphviz()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "1a30bd26",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('b0', -0.92),\n",
" ('b1', -9.19),\n",
" ('b1_mean', -0.92),\n",
" ('b1_std', -0.73),\n",
" ('obs', -6.93)]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sorted(vect_model.point_logps().items())"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "a44f6aae",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('b0', -0.92),\n",
" ('b1', -9.19),\n",
" ('b1_mean', -0.92),\n",
" ('b1_std', -0.73),\n",
" ('obs', -6.93)]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sorted(target_model.point_logps().items())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4695184",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "pymc",
"language": "python",
"name": "pymc"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment