Last active
November 18, 2023 10:09
-
-
Save ricardoV94/99c53fbb8b2e9a68e1b2c6c4d761eaf4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "fc9b5316", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pytensor\n", | |
"import pytensor.tensor as pt\n", | |
"from pytensor.graph.fg import FunctionGraph\n", | |
"from pytensor.graph.replace import vectorize_graph\n", | |
"from pytensor.graph.replace import _vectorize_node\n", | |
"from pytensor.graph.rewriting.basic import MergeOptimizer\n", | |
"\n", | |
"import numpy as np\n", | |
"import pymc as pm\n", | |
"\n", | |
"from pymc.distributions import transforms as tr\n", | |
"\n", | |
"from pymc.model.fgraph import (\n", | |
" fgraph_from_model,\n", | |
" model_from_fgraph,\n", | |
" ModelFreeRV,\n", | |
" ModelObservedRV,\n", | |
" ModelNamed,\n", | |
" model_free_rv,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "4acb13f5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@_vectorize_node.register(ModelNamed)\n", | |
"def vectorize_model_named(op, node, new_value):\n", | |
" name = node.outputs[0].name\n", | |
" new_value.name = name\n", | |
" new_model_named = op(new_value)\n", | |
" new_model_named.name = name \n", | |
" return new_model_named.owner\n", | |
"\n", | |
"@_vectorize_node.register(ModelObservedRV)\n", | |
"@_vectorize_node.register(ModelFreeRV)\n", | |
"def vectorize_model_free_rv(op, node, new_rv, new_value):\n", | |
" old_rv, old_value = node.inputs\n", | |
" \n", | |
"# print(node)\n", | |
" \n", | |
" # Check what changed\n", | |
" batch_dims = new_rv.ndim - old_rv.ndim\n", | |
" batch_shape = ()\n", | |
" if batch_dims:\n", | |
" batch_type_shape = new_rv.type.shape[:batch_dims]\n", | |
" batch_shape = new_rv.shape[:batch_dims]\n", | |
" \n", | |
" if new_value.ndim > old_value.ndim:\n", | |
" # Raise NotImplementdeError for partial batching \n", | |
" if new_value.ndim - old_value.ndim != batch_dims:\n", | |
" raise NotImplementedError()\n", | |
" else:\n", | |
" # Batch values\n", | |
" if isinstance(op, ModelFreeRV):\n", | |
" new_type_shape = tuple(batch_type_shape) + new_value.type.shape\n", | |
" new_value = new_value.type.clone(shape=new_type_shape)(old_value.name)\n", | |
" else:\n", | |
" new_shape = tuple(batch_shape) + tuple(new_value.shape)\n", | |
" new_value = pt.broadcast_to(new_value, new_shape)\n", | |
" \n", | |
" new_rv.name = old_rv.name\n", | |
" out = op(new_rv, new_value)\n", | |
" out.name = old_rv.name\n", | |
" return out.owner" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "fafba8e8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
"<!-- Generated by graphviz version 3.0.0 (20220315.2325)\n", | |
" -->\n", | |
"<!-- Pages: 1 -->\n", | |
"<svg width=\"450pt\" height=\"433pt\"\n", | |
" viewBox=\"0.00 0.00 450.21 432.86\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 428.86)\">\n", | |
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-428.86 446.21,-428.86 446.21,4 -4,4\"/>\n", | |
"<g id=\"clust1\" class=\"cluster\">\n", | |
"<title>cluster10</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M97.71,-8C97.71,-8 321.71,-8 321.71,-8 327.71,-8 333.71,-14 333.71,-20 333.71,-20 333.71,-309.91 333.71,-309.91 333.71,-315.91 327.71,-321.91 321.71,-321.91 321.71,-321.91 97.71,-321.91 97.71,-321.91 91.71,-321.91 85.71,-315.91 85.71,-309.91 85.71,-309.91 85.71,-20 85.71,-20 85.71,-14 91.71,-8 97.71,-8\"/>\n", | |
"<text text-anchor=\"middle\" x=\"316.21\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n", | |
"</g>\n", | |
"<!-- y -->\n", | |
"<g id=\"node1\" class=\"node\">\n", | |
"<title>y</title>\n", | |
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M313.21,-92C313.21,-92 222.21,-92 222.21,-92 216.21,-92 210.21,-86 210.21,-80 210.21,-80 210.21,-51 210.21,-51 210.21,-45 216.21,-39 222.21,-39 222.21,-39 313.21,-39 313.21,-39 319.21,-39 325.21,-45 325.21,-51 325.21,-51 325.21,-80 325.21,-80 325.21,-86 319.21,-92 313.21,-92\"/>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-76.8\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-61.8\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-46.8\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n", | |
"</g>\n", | |
"<!-- x -->\n", | |
"<g id=\"node2\" class=\"node\">\n", | |
"<title>x</title>\n", | |
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M313.21,-302.93C313.21,-302.93 222.21,-302.93 222.21,-302.93 216.21,-302.93 210.21,-296.93 210.21,-290.93 210.21,-290.93 210.21,-261.93 210.21,-261.93 210.21,-255.93 216.21,-249.93 222.21,-249.93 222.21,-249.93 313.21,-249.93 313.21,-249.93 319.21,-249.93 325.21,-255.93 325.21,-261.93 325.21,-261.93 325.21,-290.93 325.21,-290.93 325.21,-296.93 319.21,-302.93 313.21,-302.93\"/>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">x</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n", | |
"</g>\n", | |
"<!-- obs -->\n", | |
"<g id=\"node4\" class=\"node\">\n", | |
"<title>obs</title>\n", | |
"<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"267.71\" cy=\"-165.48\" rx=\"57.97\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-176.78\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-161.78\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-146.78\" font-family=\"Times,serif\" font-size=\"14.00\">Bernoulli</text>\n", | |
"</g>\n", | |
"<!-- x->obs -->\n", | |
"<g id=\"edge5\" class=\"edge\">\n", | |
"<title>x->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M267.71,-249.89C267.71,-238.98 267.71,-225.89 267.71,-213.35\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"271.21,-213 267.71,-203 264.21,-213 271.21,-213\"/>\n", | |
"</g>\n", | |
"<!-- b1 -->\n", | |
"<g id=\"node3\" class=\"node\">\n", | |
"<title>b1</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"142.71\" cy=\"-276.43\" rx=\"49.49\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">b1</text>\n", | |
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", | |
"</g>\n", | |
"<!-- b1->obs -->\n", | |
"<g id=\"edge6\" class=\"edge\">\n", | |
"<title>b1->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M174.58,-247.65C190.26,-233.99 209.36,-217.34 226.17,-202.68\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"228.47,-205.32 233.71,-196.11 223.87,-200.04 228.47,-205.32\"/>\n", | |
"</g>\n", | |
"<!-- obs->y -->\n", | |
"<g id=\"edge1\" class=\"edge\">\n", | |
"<title>obs->y</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M267.71,-127.99C267.71,-119.58 267.71,-110.63 267.71,-102.25\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"271.21,-102.01 267.71,-92.01 264.21,-102.01 271.21,-102.01\"/>\n", | |
"</g>\n", | |
"<!-- b1_std -->\n", | |
"<g id=\"node5\" class=\"node\">\n", | |
"<title>b1_std</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"70.71\" cy=\"-387.38\" rx=\"70.92\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-398.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_std</text>\n", | |
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-383.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-368.68\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n", | |
"</g>\n", | |
"<!-- b1_std->b1 -->\n", | |
"<g id=\"edge2\" class=\"edge\">\n", | |
"<title>b1_std->b1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M93.69,-351.61C100.57,-341.19 108.19,-329.66 115.33,-318.86\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"118.34,-320.66 120.93,-310.39 112.5,-316.8 118.34,-320.66\"/>\n", | |
"</g>\n", | |
"<!-- b1_mean -->\n", | |
"<g id=\"node6\" class=\"node\">\n", | |
"<title>b1_mean</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"215.71\" cy=\"-387.38\" rx=\"56.64\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-398.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_mean</text>\n", | |
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-383.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-368.68\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", | |
"</g>\n", | |
"<!-- b1_mean->b1 -->\n", | |
"<g id=\"edge3\" class=\"edge\">\n", | |
"<title>b1_mean->b1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M193.21,-352.8C186,-342.04 177.93,-330 170.39,-318.75\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"173.23,-316.7 164.76,-310.34 167.42,-320.59 173.23,-316.7\"/>\n", | |
"</g>\n", | |
"<!-- b0 -->\n", | |
"<g id=\"node7\" class=\"node\">\n", | |
"<title>b0</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"392.71\" cy=\"-276.43\" rx=\"49.49\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">b0</text>\n", | |
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", | |
"</g>\n", | |
"<!-- b0->obs -->\n", | |
"<g id=\"edge4\" class=\"edge\">\n", | |
"<title>b0->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M360.84,-247.65C345.16,-233.99 326.06,-217.34 309.25,-202.68\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"311.55,-200.04 301.71,-196.11 306.95,-205.32 311.55,-200.04\"/>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>\n" | |
], | |
"text/plain": [ | |
"<graphviz.graphs.Digraph at 0x7ff4b7bbe620>" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"with pm.Model() as target_model:\n", | |
" x = pm.ConstantData(\"x\", np.arange(10, dtype=float))\n", | |
" y = pm.ConstantData(\"y\", np.ones(10, dtype=int))\n", | |
" \n", | |
" b1_mean = pm.Normal(\"b1_mean\", 0, 1)\n", | |
" b1_std = pm.HalfNormal(\"b1_std\", 1)\n", | |
" b0 = pm.Normal(\"b0\")\n", | |
" b1 = pm.Normal(\"b1\", b1_mean, b1_std, shape=(10,))\n", | |
" \n", | |
" logit_p = b0 + b1 * x\n", | |
" pm.Bernoulli(\"obs\", logit_p=logit_p, observed=y)\n", | |
" \n", | |
"target_model.to_graphviz()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "cc09069c", | |
"metadata": {}, | |
"source": [ | |
"## From a core model with hyperpriors" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "d843a303", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
"<!-- Generated by graphviz version 3.0.0 (20220315.2325)\n", | |
" -->\n", | |
"<!-- Pages: 1 -->\n", | |
"<svg width=\"357pt\" height=\"394pt\"\n", | |
" viewBox=\"0.00 0.00 357.00 393.86\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 389.86)\">\n", | |
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-389.86 353,-389.86 353,4 -4,4\"/>\n", | |
"<!-- x -->\n", | |
"<g id=\"node1\" class=\"node\">\n", | |
"<title>x</title>\n", | |
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M103,-263.93C103,-263.93 12,-263.93 12,-263.93 6,-263.93 0,-257.93 0,-251.93 0,-251.93 0,-222.93 0,-222.93 0,-216.93 6,-210.93 12,-210.93 12,-210.93 103,-210.93 103,-210.93 109,-210.93 115,-216.93 115,-222.93 115,-222.93 115,-251.93 115,-251.93 115,-257.93 109,-263.93 103,-263.93\"/>\n", | |
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-248.73\" font-family=\"Times,serif\" font-size=\"14.00\">x</text>\n", | |
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-233.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-218.73\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n", | |
"</g>\n", | |
"<!-- obs -->\n", | |
"<g id=\"node7\" class=\"node\">\n", | |
"<title>obs</title>\n", | |
"<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"182.5\" cy=\"-126.48\" rx=\"57.97\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-137.78\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-122.78\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-107.78\" font-family=\"Times,serif\" font-size=\"14.00\">Bernoulli</text>\n", | |
"</g>\n", | |
"<!-- x->obs -->\n", | |
"<g id=\"edge5\" class=\"edge\">\n", | |
"<title>x->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M86.81,-210.89C102.83,-196.92 122.94,-179.39 140.6,-164\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"143.27,-166.31 148.51,-157.1 138.67,-161.04 143.27,-166.31\"/>\n", | |
"</g>\n", | |
"<!-- b1 -->\n", | |
"<g id=\"node2\" class=\"node\">\n", | |
"<title>b1</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"182.5\" cy=\"-237.43\" rx=\"49.49\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-248.73\" font-family=\"Times,serif\" font-size=\"14.00\">b1</text>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-233.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-218.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", | |
"</g>\n", | |
"<!-- b1->obs -->\n", | |
"<g id=\"edge6\" class=\"edge\">\n", | |
"<title>b1->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M182.5,-199.85C182.5,-191.67 182.5,-182.89 182.5,-174.37\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"186,-174.15 182.5,-164.15 179,-174.15 186,-174.15\"/>\n", | |
"</g>\n", | |
"<!-- y -->\n", | |
"<g id=\"node3\" class=\"node\">\n", | |
"<title>y</title>\n", | |
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M228,-53C228,-53 137,-53 137,-53 131,-53 125,-47 125,-41 125,-41 125,-12 125,-12 125,-6 131,0 137,0 137,0 228,0 228,0 234,0 240,-6 240,-12 240,-12 240,-41 240,-41 240,-47 234,-53 228,-53\"/>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-37.8\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-7.8\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n", | |
"</g>\n", | |
"<!-- b0 -->\n", | |
"<g id=\"node4\" class=\"node\">\n", | |
"<title>b0</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"299.5\" cy=\"-237.43\" rx=\"49.49\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"299.5\" y=\"-248.73\" font-family=\"Times,serif\" font-size=\"14.00\">b0</text>\n", | |
"<text text-anchor=\"middle\" x=\"299.5\" y=\"-233.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"299.5\" y=\"-218.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", | |
"</g>\n", | |
"<!-- b0->obs -->\n", | |
"<g id=\"edge4\" class=\"edge\">\n", | |
"<title>b0->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M268.76,-207.8C254.69,-194.7 237.84,-179.01 222.77,-164.98\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"224.78,-162.07 215.08,-157.82 220.01,-167.19 224.78,-162.07\"/>\n", | |
"</g>\n", | |
"<!-- b1_std -->\n", | |
"<g id=\"node5\" class=\"node\">\n", | |
"<title>b1_std</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"110.5\" cy=\"-348.38\" rx=\"70.92\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"110.5\" y=\"-359.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_std</text>\n", | |
"<text text-anchor=\"middle\" x=\"110.5\" y=\"-344.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"110.5\" y=\"-329.68\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n", | |
"</g>\n", | |
"<!-- b1_std->b1 -->\n", | |
"<g id=\"edge2\" class=\"edge\">\n", | |
"<title>b1_std->b1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M133.48,-312.61C140.36,-302.19 147.98,-290.66 155.12,-279.86\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"158.13,-281.66 160.72,-271.39 152.29,-277.8 158.13,-281.66\"/>\n", | |
"</g>\n", | |
"<!-- b1_mean -->\n", | |
"<g id=\"node6\" class=\"node\">\n", | |
"<title>b1_mean</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"255.5\" cy=\"-348.38\" rx=\"56.64\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"255.5\" y=\"-359.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_mean</text>\n", | |
"<text text-anchor=\"middle\" x=\"255.5\" y=\"-344.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"255.5\" y=\"-329.68\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", | |
"</g>\n", | |
"<!-- b1_mean->b1 -->\n", | |
"<g id=\"edge3\" class=\"edge\">\n", | |
"<title>b1_mean->b1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M233,-313.8C225.79,-303.04 217.72,-291 210.18,-279.75\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"213.02,-277.7 204.55,-271.34 207.21,-281.59 213.02,-277.7\"/>\n", | |
"</g>\n", | |
"<!-- obs->y -->\n", | |
"<g id=\"edge1\" class=\"edge\">\n", | |
"<title>obs->y</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M182.5,-88.99C182.5,-80.58 182.5,-71.63 182.5,-63.25\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"186,-63.01 182.5,-53.01 179,-63.01 186,-63.01\"/>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>\n" | |
], | |
"text/plain": [ | |
"<graphviz.graphs.Digraph at 0x7ff4b77dfd00>" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"with pm.Model() as core_model:\n", | |
" x = pm.ConstantData(\"x\", 0.0)\n", | |
" y = pm.ConstantData(\"y\", 1)\n", | |
" \n", | |
" b1_mean = pm.Normal(\"b1_mean\", 0, 1)\n", | |
" b1_std = pm.HalfNormal(\"b1_std\", 1)\n", | |
" \n", | |
" b0 = pm.Normal(\"b0\")\n", | |
" b1 = pm.Normal(\"b1\", b1_mean, b1_std, shape=())\n", | |
"\n", | |
" \n", | |
" logit_p = b0 + b1 * x\n", | |
" obs = pm.Bernoulli(\"obs\", logit_p=logit_p, observed=y)\n", | |
" \n", | |
"core_model.to_graphviz()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "f7d1cd8c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"ModelObservedRV{transform=None} [id A] 'obs' 15\n", | |
" ├─ bernoulli_rv{0, (0,), int64, False}.1 [id B] 'obs' 14\n", | |
" │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A8580>) [id C]\n", | |
" │ ├─ [] [id D]\n", | |
" │ ├─ 4 [id E]\n", | |
" │ └─ Sigmoid [id F] 13\n", | |
" │ └─ Add [id G] 12\n", | |
" │ ├─ ModelFreeRV{transform=None} [id H] 'b0' 11\n", | |
" │ │ ├─ normal_rv{0, (0, 0), floatX, False}.1 [id I] 'b0' 10\n", | |
" │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A9700>) [id J]\n", | |
" │ │ │ ├─ [] [id K]\n", | |
" │ │ │ ├─ 11 [id L]\n", | |
" │ │ │ ├─ 0 [id M]\n", | |
" │ │ │ └─ 1.0 [id N]\n", | |
" │ │ └─ b0 [id O]\n", | |
" │ └─ Mul [id P] 9\n", | |
" │ ├─ ModelFreeRV{transform=None} [id Q] 'b1' 8\n", | |
" │ │ ├─ normal_rv{0, (0, 0), floatX, False}.1 [id R] 'b1' 7\n", | |
" │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A92A0>) [id S]\n", | |
" │ │ │ ├─ [] [id T]\n", | |
" │ │ │ ├─ 11 [id U]\n", | |
" │ │ │ ├─ ModelFreeRV{transform=None} [id V] 'b1_mean' 6\n", | |
" │ │ │ │ ├─ normal_rv{0, (0, 0), floatX, False}.1 [id W] 'b1_mean' 5\n", | |
" │ │ │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A9A80>) [id X]\n", | |
" │ │ │ │ │ ├─ [] [id Y]\n", | |
" │ │ │ │ │ ├─ 11 [id Z]\n", | |
" │ │ │ │ │ ├─ 0 [id BA]\n", | |
" │ │ │ │ │ └─ 1.0 [id BB]\n", | |
" │ │ │ │ └─ b1_mean [id BC]\n", | |
" │ │ │ └─ ModelFreeRV{transform=<pymc.logprob.transforms.LogTransform object at 0x7ff4c45ba6b0>} [id BD] 'b1_std' 4\n", | |
" │ │ │ ├─ halfnormal_rv{0, (0, 0), floatX, False}.1 [id BE] 'b1_std' 3\n", | |
" │ │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A9540>) [id BF]\n", | |
" │ │ │ │ ├─ [] [id BG]\n", | |
" │ │ │ │ ├─ 11 [id BH]\n", | |
" │ │ │ │ ├─ 0.0 [id BI]\n", | |
" │ │ │ │ └─ 1.0 [id BJ]\n", | |
" │ │ │ └─ b1_std_log__ [id BK]\n", | |
" │ │ └─ b1 [id BL]\n", | |
" │ └─ ModelNamed [id BM] 'x' 2\n", | |
" │ └─ x{0.0} [id BN]\n", | |
" └─ Cast{int64} [id BO] 1\n", | |
" └─ ModelNamed [id BP] 'y' 0\n", | |
" └─ y{1.0} [id BQ]\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<ipykernel.iostream.OutStream at 0x7ff53cef9c60>" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"_, memo = fgraph_from_model(core_model, inlined_views=True)\n", | |
"# The vectorize for FreeRV will create multiple unique value variables everytime it is called. \n", | |
"# The standard representation includes the each FreeRV as an output and wherever it is used in the graph\n", | |
"# Vectorize would therefore find it at least twice\n", | |
"core_fgraph = FunctionGraph(outputs=[memo[obs]], clone=False)\n", | |
"pytensor.dprint(core_fgraph, print_type=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "84222351", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"ModelObservedRV{transform=None} [id A] 'obs' 16\n", | |
" ├─ bernoulli_rv{0, (0,), int64, False}.1 [id B] 'obs' 15\n", | |
" │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A8580>) [id C]\n", | |
" │ ├─ [] [id D]\n", | |
" │ ├─ 4 [id E]\n", | |
" │ └─ Sigmoid [id F] 14\n", | |
" │ └─ Add [id G] 13\n", | |
" │ ├─ ExpandDims{axis=0} [id H] 12\n", | |
" │ │ └─ ModelFreeRV{transform=None} [id I] 'b0' 11\n", | |
" │ │ ├─ normal_rv{0, (0, 0), floatX, False}.1 [id J] 'b0' 10\n", | |
" │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A9700>) [id K]\n", | |
" │ │ │ ├─ [] [id D]\n", | |
" │ │ │ ├─ 11 [id L]\n", | |
" │ │ │ ├─ 0 [id M]\n", | |
" │ │ │ └─ 1.0 [id N]\n", | |
" │ │ └─ b0 [id O]\n", | |
" │ └─ Mul [id P] 9\n", | |
" │ ├─ ModelFreeRV{transform=None} [id Q] 'b1' 8\n", | |
" │ │ ├─ normal_rv{0, (0, 0), floatX, False}.1 [id R] 'b1' 7\n", | |
" │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A8820>) [id S]\n", | |
" │ │ │ ├─ [10] [id T]\n", | |
" │ │ │ ├─ 11 [id L]\n", | |
" │ │ │ ├─ ModelFreeRV{transform=None} [id U] 'b1_mean' 6\n", | |
" │ │ │ │ ├─ normal_rv{0, (0, 0), floatX, False}.1 [id V] 'b1_mean' 5\n", | |
" │ │ │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A9A80>) [id W]\n", | |
" │ │ │ │ │ ├─ [] [id D]\n", | |
" │ │ │ │ │ ├─ 11 [id L]\n", | |
" │ │ │ │ │ ├─ 0 [id M]\n", | |
" │ │ │ │ │ └─ 1.0 [id N]\n", | |
" │ │ │ │ └─ b1_mean [id X]\n", | |
" │ │ │ └─ ModelFreeRV{transform=<pymc.logprob.transforms.LogTransform object at 0x7ff4c45ba6b0>} [id Y] 'b1_std' 4\n", | |
" │ │ │ ├─ halfnormal_rv{0, (0, 0), floatX, False}.1 [id Z] 'b1_std' 3\n", | |
" │ │ │ │ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FF4B77A9540>) [id BA]\n", | |
" │ │ │ │ ├─ [] [id D]\n", | |
" │ │ │ │ ├─ 11 [id L]\n", | |
" │ │ │ │ ├─ 0.0 [id BB]\n", | |
" │ │ │ │ └─ 1.0 [id N]\n", | |
" │ │ │ └─ b1_std_log__ [id BC]\n", | |
" │ │ └─ b1 [id BD]\n", | |
" │ └─ ModelNamed [id BE] 'x' 2\n", | |
" │ └─ x{[0. 1. 2. ... 7. 8. 9.]} [id BF]\n", | |
" └─ Cast{int64} [id BG] 1\n", | |
" └─ ModelNamed [id BH] 'y' 0\n", | |
" └─ y{[1 1 1 1 1 ... 1 1 1 1 1]} [id BI]\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<ipykernel.iostream.OutStream at 0x7ff53cef9c60>" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"replace = {\n", | |
" memo[b1].owner.inputs[0]: pm.Normal.dist(memo[b1_mean], memo[b1_std], shape=(10,)),\n", | |
" memo[x].owner.inputs[0]: pt.as_tensor(np.arange(10, dtype=float)),\n", | |
" memo[y].owner.inputs[0]: pt.as_tensor(np.ones(10, dtype=int))\n", | |
"}\n", | |
"\n", | |
"vect_fgraph = FunctionGraph(\n", | |
" outputs=vectorize_graph(core_fgraph.outputs, replace=replace)\n", | |
")\n", | |
"MergeOptimizer().rewrite(vect_fgraph)\n", | |
"pytensor.dprint(vect_fgraph, print_type=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "1ecf5dbf", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
"<!-- Generated by graphviz version 3.0.0 (20220315.2325)\n", | |
" -->\n", | |
"<!-- Pages: 1 -->\n", | |
"<svg width=\"450pt\" height=\"433pt\"\n", | |
" viewBox=\"0.00 0.00 450.21 432.86\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 428.86)\">\n", | |
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-428.86 446.21,-428.86 446.21,4 -4,4\"/>\n", | |
"<g id=\"clust1\" class=\"cluster\">\n", | |
"<title>cluster10</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M97.71,-8C97.71,-8 321.71,-8 321.71,-8 327.71,-8 333.71,-14 333.71,-20 333.71,-20 333.71,-309.91 333.71,-309.91 333.71,-315.91 327.71,-321.91 321.71,-321.91 321.71,-321.91 97.71,-321.91 97.71,-321.91 91.71,-321.91 85.71,-315.91 85.71,-309.91 85.71,-309.91 85.71,-20 85.71,-20 85.71,-14 91.71,-8 97.71,-8\"/>\n", | |
"<text text-anchor=\"middle\" x=\"316.21\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n", | |
"</g>\n", | |
"<!-- y -->\n", | |
"<g id=\"node1\" class=\"node\">\n", | |
"<title>y</title>\n", | |
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M313.21,-92C313.21,-92 222.21,-92 222.21,-92 216.21,-92 210.21,-86 210.21,-80 210.21,-80 210.21,-51 210.21,-51 210.21,-45 216.21,-39 222.21,-39 222.21,-39 313.21,-39 313.21,-39 319.21,-39 325.21,-45 325.21,-51 325.21,-51 325.21,-80 325.21,-80 325.21,-86 319.21,-92 313.21,-92\"/>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-76.8\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-61.8\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-46.8\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n", | |
"</g>\n", | |
"<!-- x -->\n", | |
"<g id=\"node2\" class=\"node\">\n", | |
"<title>x</title>\n", | |
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M313.21,-302.93C313.21,-302.93 222.21,-302.93 222.21,-302.93 216.21,-302.93 210.21,-296.93 210.21,-290.93 210.21,-290.93 210.21,-261.93 210.21,-261.93 210.21,-255.93 216.21,-249.93 222.21,-249.93 222.21,-249.93 313.21,-249.93 313.21,-249.93 319.21,-249.93 325.21,-255.93 325.21,-261.93 325.21,-261.93 325.21,-290.93 325.21,-290.93 325.21,-296.93 319.21,-302.93 313.21,-302.93\"/>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">x</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n", | |
"</g>\n", | |
"<!-- obs -->\n", | |
"<g id=\"node4\" class=\"node\">\n", | |
"<title>obs</title>\n", | |
"<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"267.71\" cy=\"-165.48\" rx=\"57.97\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-176.78\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-161.78\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-146.78\" font-family=\"Times,serif\" font-size=\"14.00\">Bernoulli</text>\n", | |
"</g>\n", | |
"<!-- x->obs -->\n", | |
"<g id=\"edge5\" class=\"edge\">\n", | |
"<title>x->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M267.71,-249.89C267.71,-238.98 267.71,-225.89 267.71,-213.35\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"271.21,-213 267.71,-203 264.21,-213 271.21,-213\"/>\n", | |
"</g>\n", | |
"<!-- b1 -->\n", | |
"<g id=\"node3\" class=\"node\">\n", | |
"<title>b1</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"142.71\" cy=\"-276.43\" rx=\"49.49\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">b1</text>\n", | |
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", | |
"</g>\n", | |
"<!-- b1->obs -->\n", | |
"<g id=\"edge6\" class=\"edge\">\n", | |
"<title>b1->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M174.58,-247.65C190.26,-233.99 209.36,-217.34 226.17,-202.68\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"228.47,-205.32 233.71,-196.11 223.87,-200.04 228.47,-205.32\"/>\n", | |
"</g>\n", | |
"<!-- obs->y -->\n", | |
"<g id=\"edge1\" class=\"edge\">\n", | |
"<title>obs->y</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M267.71,-127.99C267.71,-119.58 267.71,-110.63 267.71,-102.25\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"271.21,-102.01 267.71,-92.01 264.21,-102.01 271.21,-102.01\"/>\n", | |
"</g>\n", | |
"<!-- b1_std -->\n", | |
"<g id=\"node5\" class=\"node\">\n", | |
"<title>b1_std</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"70.71\" cy=\"-387.38\" rx=\"70.92\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-398.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_std</text>\n", | |
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-383.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-368.68\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n", | |
"</g>\n", | |
"<!-- b1_std->b1 -->\n", | |
"<g id=\"edge2\" class=\"edge\">\n", | |
"<title>b1_std->b1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M93.69,-351.61C100.57,-341.19 108.19,-329.66 115.33,-318.86\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"118.34,-320.66 120.93,-310.39 112.5,-316.8 118.34,-320.66\"/>\n", | |
"</g>\n", | |
"<!-- b1_mean -->\n", | |
"<g id=\"node6\" class=\"node\">\n", | |
"<title>b1_mean</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"215.71\" cy=\"-387.38\" rx=\"56.64\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-398.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_mean</text>\n", | |
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-383.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-368.68\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", | |
"</g>\n", | |
"<!-- b1_mean->b1 -->\n", | |
"<g id=\"edge3\" class=\"edge\">\n", | |
"<title>b1_mean->b1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M193.21,-352.8C186,-342.04 177.93,-330 170.39,-318.75\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"173.23,-316.7 164.76,-310.34 167.42,-320.59 173.23,-316.7\"/>\n", | |
"</g>\n", | |
"<!-- b0 -->\n", | |
"<g id=\"node7\" class=\"node\">\n", | |
"<title>b0</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"392.71\" cy=\"-276.43\" rx=\"49.49\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">b0</text>\n", | |
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", | |
"</g>\n", | |
"<!-- b0->obs -->\n", | |
"<g id=\"edge4\" class=\"edge\">\n", | |
"<title>b0->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M360.84,-247.65C345.16,-233.99 326.06,-217.34 309.25,-202.68\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"311.55,-200.04 301.71,-196.11 306.95,-205.32 311.55,-200.04\"/>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>\n" | |
], | |
"text/plain": [ | |
"<graphviz.graphs.Digraph at 0x7ff4b777a2c0>" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Rebuild a PyMC model from the vectorized FunctionGraph, then render its DAG\n", | |
"vect_model = model_from_fgraph(fgraph=vect_fgraph)\n", | |
"vect_model.to_graphviz()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "cb81b3b2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[('b0', -0.92),\n", | |
" ('b1', -9.19),\n", | |
" ('b1_mean', -0.92),\n", | |
" ('b1_std', -0.73),\n", | |
" ('obs', -6.93)]" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Per-variable log-probabilities at the initial point, ordered by variable name\n", | |
"sorted(vect_model.point_logps().items(), key=lambda name_logp: name_logp[0])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "35aeb105", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[('b0', -0.92),\n", | |
" ('b1', -9.19),\n", | |
" ('b1_mean', -0.92),\n", | |
" ('b1_std', -0.73),\n", | |
" ('obs', -6.93)]" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# The vectorized model must reproduce the hand-written target model exactly;\n", | |
"# assert it so Restart & Run All fails loudly instead of relying on eyeballing.\n", | |
"target_logps = sorted(target_model.point_logps().items())\n", | |
"assert sorted(vect_model.point_logps().items()) == target_logps, \"vectorized model logps differ from target model\"\n", | |
"target_logps" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "88546c7d", | |
"metadata": {}, | |
"source": [ | |
"## From a core model without hyperpriors" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "dcf48a7e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
"<!-- Generated by graphviz version 3.0.0 (20220315.2325)\n", | |
" -->\n", | |
"<!-- Pages: 1 -->\n", | |
"<svg width=\"357pt\" height=\"283pt\"\n", | |
" viewBox=\"0.00 0.00 357.00 282.91\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 278.91)\">\n", | |
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-278.91 353,-278.91 353,4 -4,4\"/>\n", | |
"<!-- x -->\n", | |
"<g id=\"node1\" class=\"node\">\n", | |
"<title>x</title>\n", | |
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M103,-263.93C103,-263.93 12,-263.93 12,-263.93 6,-263.93 0,-257.93 0,-251.93 0,-251.93 0,-222.93 0,-222.93 0,-216.93 6,-210.93 12,-210.93 12,-210.93 103,-210.93 103,-210.93 109,-210.93 115,-216.93 115,-222.93 115,-222.93 115,-251.93 115,-251.93 115,-257.93 109,-263.93 103,-263.93\"/>\n", | |
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-248.73\" font-family=\"Times,serif\" font-size=\"14.00\">x</text>\n", | |
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-233.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"57.5\" y=\"-218.73\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n", | |
"</g>\n", | |
"<!-- obs -->\n", | |
"<g id=\"node5\" class=\"node\">\n", | |
"<title>obs</title>\n", | |
"<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"182.5\" cy=\"-126.48\" rx=\"57.97\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-137.78\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-122.78\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-107.78\" font-family=\"Times,serif\" font-size=\"14.00\">Bernoulli</text>\n", | |
"</g>\n", | |
"<!-- x->obs -->\n", | |
"<g id=\"edge3\" class=\"edge\">\n", | |
"<title>x->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M86.81,-210.89C102.83,-196.92 122.94,-179.39 140.6,-164\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"143.27,-166.31 148.51,-157.1 138.67,-161.04 143.27,-166.31\"/>\n", | |
"</g>\n", | |
"<!-- b1 -->\n", | |
"<g id=\"node2\" class=\"node\">\n", | |
"<title>b1</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"182.5\" cy=\"-237.43\" rx=\"49.49\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-248.73\" font-family=\"Times,serif\" font-size=\"14.00\">b1</text>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-233.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-218.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", | |
"</g>\n", | |
"<!-- b1->obs -->\n", | |
"<g id=\"edge4\" class=\"edge\">\n", | |
"<title>b1->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M182.5,-199.85C182.5,-191.67 182.5,-182.89 182.5,-174.37\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"186,-174.15 182.5,-164.15 179,-174.15 186,-174.15\"/>\n", | |
"</g>\n", | |
"<!-- y -->\n", | |
"<g id=\"node3\" class=\"node\">\n", | |
"<title>y</title>\n", | |
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M228,-53C228,-53 137,-53 137,-53 131,-53 125,-47 125,-41 125,-41 125,-12 125,-12 125,-6 131,0 137,0 137,0 228,0 228,0 234,0 240,-6 240,-12 240,-12 240,-41 240,-41 240,-47 234,-53 228,-53\"/>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-37.8\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"182.5\" y=\"-7.8\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n", | |
"</g>\n", | |
"<!-- b0 -->\n", | |
"<g id=\"node4\" class=\"node\">\n", | |
"<title>b0</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"299.5\" cy=\"-237.43\" rx=\"49.49\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"299.5\" y=\"-248.73\" font-family=\"Times,serif\" font-size=\"14.00\">b0</text>\n", | |
"<text text-anchor=\"middle\" x=\"299.5\" y=\"-233.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"299.5\" y=\"-218.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", | |
"</g>\n", | |
"<!-- b0->obs -->\n", | |
"<g id=\"edge2\" class=\"edge\">\n", | |
"<title>b0->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M268.76,-207.8C254.69,-194.7 237.84,-179.01 222.77,-164.98\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"224.78,-162.07 215.08,-157.82 220.01,-167.19 224.78,-162.07\"/>\n", | |
"</g>\n", | |
"<!-- obs->y -->\n", | |
"<g id=\"edge1\" class=\"edge\">\n", | |
"<title>obs->y</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M182.5,-88.99C182.5,-80.58 182.5,-71.63 182.5,-63.25\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"186,-63.01 182.5,-53.01 179,-63.01 186,-63.01\"/>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>\n" | |
], | |
"text/plain": [ | |
"<graphviz.graphs.Digraph at 0x7ff4b6ff4f10>" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"with pm.Model() as core_model:\n", | |
"    x = pm.ConstantData(\"x\", 0.0)\n", | |
"    y = pm.ConstantData(\"y\", 1)\n", | |
"\n", | |
"    b0 = pm.Normal(\"b0\")\n", | |
"    # No hyperpriors here: b1 is a plain scalar prior. It is swapped for a\n", | |
"    # batched Normal with hyperpriors when the graph is vectorized below.\n", | |
"    b1 = pm.Normal(\"b1\", shape=())\n", | |
"\n", | |
"    logit_p = b0 + b1 * x\n", | |
"    obs = pm.Bernoulli(\"obs\", logit_p=logit_p, observed=y)\n", | |
"\n", | |
"core_model.to_graphviz()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "b6928647", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"_, memo = fgraph_from_model(core_model, inlined_views=True)\n", | |
"# Vectorizing a ModelFreeRV creates a new value variable every time the node is visited.\n", | |
"# The standard fgraph representation includes each FreeRV both as an output and wherever\n", | |
"# it is used in the graph, so vectorize would visit it at least twice and duplicate it.\n", | |
"# Building a FunctionGraph with only the observed variable as output avoids this.\n", | |
"core_fgraph = FunctionGraph(outputs=[memo[obs]], clone=False)\n", | |
"# pytensor.dprint(core_fgraph, print_type=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"id": "6ee05fb4", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Batch length of the vectorized b1; x and y below must have the same length\n", | |
"n_groups = 10\n", | |
"\n", | |
"hyper_priors = (\n", | |
"    model_free_rv(pm.Normal.dist(name=\"b1_mean\"), pt.scalar(name=\"b1_mean\"), None),\n", | |
"    model_free_rv(pm.HalfNormal.dist(name=\"b1_std\"), pt.scalar(name=\"b1_std_log__\"), tr.log),\n", | |
")\n", | |
"replace = {\n", | |
"    # Replace the scalar b1 prior with a batched Normal driven by the hyperpriors\n", | |
"    memo[b1].owner.inputs[0]: pm.Normal.dist(*hyper_priors, shape=(n_groups,)),\n", | |
"    # Give the data the same batch length as b1\n", | |
"    memo[x].owner.inputs[0]: pt.as_tensor(np.arange(n_groups, dtype=float)),\n", | |
"    memo[y].owner.inputs[0]: pt.as_tensor(np.ones(n_groups, dtype=int)),\n", | |
"}\n", | |
"\n", | |
"vect_fgraph = FunctionGraph(\n", | |
"    outputs=vectorize_graph(core_fgraph.outputs, replace=replace)\n", | |
")\n", | |
"# Merge structurally identical nodes left in the graph after vectorization\n", | |
"MergeOptimizer().rewrite(vect_fgraph);\n", | |
"# pytensor.dprint(vect_fgraph, print_type=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "bcc32a32", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
"<!-- Generated by graphviz version 3.0.0 (20220315.2325)\n", | |
" -->\n", | |
"<!-- Pages: 1 -->\n", | |
"<svg width=\"450pt\" height=\"433pt\"\n", | |
" viewBox=\"0.00 0.00 450.21 432.86\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 428.86)\">\n", | |
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-428.86 446.21,-428.86 446.21,4 -4,4\"/>\n", | |
"<g id=\"clust1\" class=\"cluster\">\n", | |
"<title>cluster10</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M97.71,-8C97.71,-8 321.71,-8 321.71,-8 327.71,-8 333.71,-14 333.71,-20 333.71,-20 333.71,-309.91 333.71,-309.91 333.71,-315.91 327.71,-321.91 321.71,-321.91 321.71,-321.91 97.71,-321.91 97.71,-321.91 91.71,-321.91 85.71,-315.91 85.71,-309.91 85.71,-309.91 85.71,-20 85.71,-20 85.71,-14 91.71,-8 97.71,-8\"/>\n", | |
"<text text-anchor=\"middle\" x=\"316.21\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n", | |
"</g>\n", | |
"<!-- y -->\n", | |
"<g id=\"node1\" class=\"node\">\n", | |
"<title>y</title>\n", | |
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M313.21,-92C313.21,-92 222.21,-92 222.21,-92 216.21,-92 210.21,-86 210.21,-80 210.21,-80 210.21,-51 210.21,-51 210.21,-45 216.21,-39 222.21,-39 222.21,-39 313.21,-39 313.21,-39 319.21,-39 325.21,-45 325.21,-51 325.21,-51 325.21,-80 325.21,-80 325.21,-86 319.21,-92 313.21,-92\"/>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-76.8\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-61.8\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-46.8\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n", | |
"</g>\n", | |
"<!-- x -->\n", | |
"<g id=\"node2\" class=\"node\">\n", | |
"<title>x</title>\n", | |
"<path fill=\"lightgrey\" stroke=\"black\" d=\"M313.21,-302.93C313.21,-302.93 222.21,-302.93 222.21,-302.93 216.21,-302.93 210.21,-296.93 210.21,-290.93 210.21,-290.93 210.21,-261.93 210.21,-261.93 210.21,-255.93 216.21,-249.93 222.21,-249.93 222.21,-249.93 313.21,-249.93 313.21,-249.93 319.21,-249.93 325.21,-255.93 325.21,-261.93 325.21,-261.93 325.21,-290.93 325.21,-290.93 325.21,-296.93 319.21,-302.93 313.21,-302.93\"/>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">x</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">ConstantData</text>\n", | |
"</g>\n", | |
"<!-- obs -->\n", | |
"<g id=\"node4\" class=\"node\">\n", | |
"<title>obs</title>\n", | |
"<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"267.71\" cy=\"-165.48\" rx=\"57.97\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-176.78\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-161.78\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"267.71\" y=\"-146.78\" font-family=\"Times,serif\" font-size=\"14.00\">Bernoulli</text>\n", | |
"</g>\n", | |
"<!-- x->obs -->\n", | |
"<g id=\"edge5\" class=\"edge\">\n", | |
"<title>x->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M267.71,-249.89C267.71,-238.98 267.71,-225.89 267.71,-213.35\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"271.21,-213 267.71,-203 264.21,-213 271.21,-213\"/>\n", | |
"</g>\n", | |
"<!-- b1 -->\n", | |
"<g id=\"node3\" class=\"node\">\n", | |
"<title>b1</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"142.71\" cy=\"-276.43\" rx=\"49.49\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">b1</text>\n", | |
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"142.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", | |
"</g>\n", | |
"<!-- b1->obs -->\n", | |
"<g id=\"edge6\" class=\"edge\">\n", | |
"<title>b1->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M174.58,-247.65C190.26,-233.99 209.36,-217.34 226.17,-202.68\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"228.47,-205.32 233.71,-196.11 223.87,-200.04 228.47,-205.32\"/>\n", | |
"</g>\n", | |
"<!-- obs->y -->\n", | |
"<g id=\"edge1\" class=\"edge\">\n", | |
"<title>obs->y</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M267.71,-127.99C267.71,-119.58 267.71,-110.63 267.71,-102.25\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"271.21,-102.01 267.71,-92.01 264.21,-102.01 271.21,-102.01\"/>\n", | |
"</g>\n", | |
"<!-- b1_std -->\n", | |
"<g id=\"node5\" class=\"node\">\n", | |
"<title>b1_std</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"70.71\" cy=\"-387.38\" rx=\"70.92\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-398.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_std</text>\n", | |
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-383.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-368.68\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n", | |
"</g>\n", | |
"<!-- b1_std->b1 -->\n", | |
"<g id=\"edge2\" class=\"edge\">\n", | |
"<title>b1_std->b1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M93.69,-351.61C100.57,-341.19 108.19,-329.66 115.33,-318.86\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"118.34,-320.66 120.93,-310.39 112.5,-316.8 118.34,-320.66\"/>\n", | |
"</g>\n", | |
"<!-- b1_mean -->\n", | |
"<g id=\"node6\" class=\"node\">\n", | |
"<title>b1_mean</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"215.71\" cy=\"-387.38\" rx=\"56.64\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-398.68\" font-family=\"Times,serif\" font-size=\"14.00\">b1_mean</text>\n", | |
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-383.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"215.71\" y=\"-368.68\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", | |
"</g>\n", | |
"<!-- b1_mean->b1 -->\n", | |
"<g id=\"edge3\" class=\"edge\">\n", | |
"<title>b1_mean->b1</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M193.21,-352.8C186,-342.04 177.93,-330 170.39,-318.75\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"173.23,-316.7 164.76,-310.34 167.42,-320.59 173.23,-316.7\"/>\n", | |
"</g>\n", | |
"<!-- b0 -->\n", | |
"<g id=\"node7\" class=\"node\">\n", | |
"<title>b0</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"392.71\" cy=\"-276.43\" rx=\"49.49\" ry=\"37.45\"/>\n", | |
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-287.73\" font-family=\"Times,serif\" font-size=\"14.00\">b0</text>\n", | |
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-272.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n", | |
"<text text-anchor=\"middle\" x=\"392.71\" y=\"-257.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n", | |
"</g>\n", | |
"<!-- b0->obs -->\n", | |
"<g id=\"edge4\" class=\"edge\">\n", | |
"<title>b0->obs</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M360.84,-247.65C345.16,-233.99 326.06,-217.34 309.25,-202.68\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"311.55,-200.04 301.71,-196.11 306.95,-205.32 311.55,-200.04\"/>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>\n" | |
], | |
"text/plain": [ | |
"<graphviz.graphs.Digraph at 0x7ff4b485e8f0>" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Rebuild a PyMC model from the vectorized FunctionGraph, then render its DAG\n", | |
"vect_model = model_from_fgraph(fgraph=vect_fgraph)\n", | |
"vect_model.to_graphviz()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"id": "1a30bd26", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[('b0', -0.92),\n", | |
" ('b1', -9.19),\n", | |
" ('b1_mean', -0.92),\n", | |
" ('b1_std', -0.73),\n", | |
" ('obs', -6.93)]" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Per-variable log-probabilities at the initial point, ordered by variable name\n", | |
"sorted(vect_model.point_logps().items(), key=lambda name_logp: name_logp[0])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"id": "a44f6aae", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[('b0', -0.92),\n", | |
" ('b1', -9.19),\n", | |
" ('b1_mean', -0.92),\n", | |
" ('b1_std', -0.73),\n", | |
" ('obs', -6.93)]" | |
] | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# The vectorized model must reproduce the hand-written target model exactly;\n", | |
"# assert it so Restart & Run All fails loudly instead of relying on eyeballing.\n", | |
"target_logps = sorted(target_model.point_logps().items())\n", | |
"assert sorted(vect_model.point_logps().items()) == target_logps, \"vectorized model logps differ from target model\"\n", | |
"target_logps" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "c4695184", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"hide_input": false, | |
"kernelspec": { | |
"display_name": "pymc", | |
"language": "python", | |
"name": "pymc" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.8" | |
}, | |
"toc": { | |
"base_numbering": 1, | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"title_cell": "Table of Contents", | |
"title_sidebar": "Contents", | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": true, | |
"toc_window_display": true | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment