Last active
June 15, 2023 09:33
-
-
Save ricardoV94/ae4c51365b871713bc3cca735fe8fa2f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "0a62ff8d", | |
"metadata": {}, | |
"source": [ | |
"## Narrow down the problem" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "5111df4a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from itertools import chain\n", | |
"\n", | |
"import pytensor\n", | |
"import pytensor.tensor as pt\n", | |
"import pymc as pm\n", | |
"import numpy as np\n", | |
"\n", | |
"from pytensor.graph.basic import Variable, Constant, graph_inputs, io_toposort\n", | |
"\n", | |
"\n", | |
"N, M = 2, 2\n", | |
"\n", | |
"with pm.Model(check_bounds=False) as model:\n", | |
" pm.ZeroSumNormal(\"x\", n_zerosum_axes=1, shape=(N, M))\n", | |
"\n", | |
"graph = [model.dlogp()]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "48cf38c8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mode1 = \"FAST_COMPILE\"\n", | |
"mode2 = \"NUMBA\"\n", | |
"test_values = [np.ones((N, M-1))]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "fabd4c39", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def makeiter(a):\n", | |
" if not isinstance(a, (list, tuple)):\n", | |
" return [a]\n", | |
" return a\n", | |
"\n", | |
"def inputvars(a, blockers=None):\n", | |
" return [\n", | |
" v\n", | |
" for v in graph_inputs(makeiter(a), blockers=blockers)\n", | |
" if isinstance(v, Variable) and not isinstance(v, Constant)\n", | |
" ]\n", | |
"\n", | |
"def check_equivalence(\n", | |
" graph_inputs,\n", | |
" graph_outputs, \n", | |
" mode1, \n", | |
" mode2, \n", | |
" test_values, \n", | |
" dprint=False,\n", | |
"):\n", | |
" fn1 = pytensor.function(graph_inputs, graph_outputs, mode=mode1, on_unused_input=\"ignore\")\n", | |
" fn2 = pytensor.function(graph_inputs, graph_outputs, mode=mode2, on_unused_input=\"ignore\")\n", | |
"\n", | |
" if dprint:\n", | |
" pytensor.dprint(fn2, print_fgraph_inputs=True)\n", | |
" \n", | |
" res1 = makeiter(fn1(*test_values))\n", | |
" res2 = makeiter(fn2(*test_values))\n", | |
"\n", | |
" for r1, r2 in zip(res1, res2):\n", | |
" np.testing.assert_allclose(r1, r2, rtol=1e-6, atol=1e-6)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "24d1da89", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0 [Reshape{1}.0]\n", | |
"1 [(d__logp/dx_zerosum__)]\n", | |
"2 [Split{2}.0, Second.0]\n", | |
"Graph is equivalent now\n" | |
] | |
} | |
], | |
"source": [ | |
"inputs = inputvars(graph)\n", | |
"curr_graph = graph\n", | |
"\n", | |
"for i in range(100): \n", | |
" print(i, curr_graph)\n", | |
" try:\n", | |
" check_equivalence(inputs, curr_graph, mode1, mode2, test_values)\n", | |
" except AssertionError:\n", | |
" # Check if the variable inputs to the current graph outputs already differ\n", | |
" curr_graph = list(set(chain.from_iterable(\n", | |
" [i for i in v.owner.inputs if not isinstance(i, Constant)]\n", | |
" for v in curr_graph if v.owner\n", | |
" )))\n", | |
" else:\n", | |
" print(\"Graph is equivalent now\")\n", | |
" break " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "88a49c2c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# The problematic subgraph should lie between these inputs and the original output\n", | |
"nw_curr_graph = list(set(chain.from_iterable(\n", | |
" [i for i in v.owner.inputs if not isinstance(i, Constant)]\n", | |
" for v in curr_graph if v.owner\n", | |
")))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "227ae674", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[array([[1.],\n", | |
" [1.]]),\n", | |
" array([[-0.14644661],\n", | |
" [-0.14644661]]),\n", | |
" array([[-0.35355339, 0.35355339],\n", | |
" [-0.35355339, 0.35355339]]),\n", | |
" array([1, 1])]" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"nw_test_values = pytensor.function(inputs, nw_curr_graph)(*test_values)\n", | |
"nw_test_values" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "5d2176be", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Reshape{1} [id A] 2\n", | |
" ├─ Add [id B] '(d__logp/dx_zerosum__)' 1\n", | |
" │ ├─ Split{2}.0 [id C] 0\n", | |
" │ │ ├─ <Matrix(float64, shape=(?, ?))> [id D]\n", | |
" │ │ ├─ 1 [id E]\n", | |
" │ │ └─ <Vector(int64, shape=(2,))> [id F]\n", | |
" │ └─ <Matrix(float64, shape=(?, 1))> [id G]\n", | |
" └─ [-1] [id H]\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<ipykernel.iostream.OutStream at 0x7ffa1c356aa0>" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"fn1 = pytensor.function(nw_curr_graph, graph, mode=mode1)\n", | |
"fn2 = pytensor.function(nw_curr_graph, graph, mode=mode2)\n", | |
"pytensor.dprint(fn2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "4507d1f4", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[array([-0.5, -0.5])]\n", | |
"[array([-0.5, -0.5])]\n" | |
] | |
} | |
], | |
"source": [ | |
"print(fn1(*nw_test_values))\n", | |
"print(fn1(*nw_test_values))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "a1342409", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[array([-0.5, -0.5])]\n", | |
"[array([-0.64644661, -0.64644661])]\n" | |
] | |
} | |
], | |
"source": [ | |
"print(fn2(*nw_test_values))\n", | |
"print(fn2(*nw_test_values))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "68d29f66", | |
"metadata": {}, | |
"source": [ | |
"## Recreate issue manually" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "ba98dc6b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Add [id A] <Matrix(float64, shape=(?, ?))> 1\n", | |
" ├─ Split{2}.0 [id B] <Matrix(float64, shape=(?, ?))> 0\n", | |
" │ ├─ x1 [id C] <Matrix(float64, shape=(?, ?))>\n", | |
" │ ├─ 1 [id D] <Scalar(int8, shape=())>\n", | |
" │ └─ v [id E] <Vector(int64, shape=(2,))>\n", | |
" └─ x2 [id F] <Matrix(float64, shape=(?, 1))>\n", | |
"[[-0.06889045]\n", | |
" [ 1.86502905]]\n", | |
"[[0.85134045]\n", | |
" [2.44213284]]\n", | |
"[[-0.98912135 -0.36778665]\n", | |
" [ 1.28792526 0.19397442]]\n", | |
"[[ 0.85134045 -0.36778665]\n", | |
" [ 2.44213284 0.19397442]]\n" | |
] | |
} | |
], | |
"source": [ | |
"import pytensor\n", | |
"import pytensor.tensor as pt\n", | |
"import numpy as np\n", | |
"\n", | |
"x1 = pt.matrix(\"x1\")\n", | |
"x2 = pt.matrix(\"x2\", shape=(None, 1))\n", | |
"v = pt.vector(\"v\", shape=(2,), dtype=int)\n", | |
"out = pt.split(x1, v, n_splits=2, axis=1)[0] + x2\n", | |
"\n", | |
"fn = pytensor.function([x1, x2, v], out, \"NUMBA\")\n", | |
"pytensor.dprint(fn, print_type=True)\n", | |
"\n", | |
"# Add [id A] <Matrix(float64, shape=(?, ?))> 1\n", | |
"# ├─ Split{2}.0 [id B] <Matrix(float64, shape=(?, ?))> 0\n", | |
"# │ ├─ x1 [id C] <Matrix(float64, shape=(?, ?))>\n", | |
"# │ ├─ 1 [id D] <Scalar(int8, shape=())>\n", | |
"# │ └─ v [id E] <Vector(int64, shape=(2,))>\n", | |
"# └─ x2 [id F] <Matrix(float64, shape=(?, 1))>\n", | |
"\n", | |
"rng = np.random.default_rng(123)\n", | |
"test_x1 = rng.normal(size=(2, 2))\n", | |
"test_x1_copy = test_x1.copy()\n", | |
"test_x2 = rng.normal(size=(2, 1))\n", | |
"test_v = np.array([1, 1])\n", | |
"\n", | |
"print(fn(test_x1, test_x2, test_v)) # [[-0.06889045], [ 1.86502905]]\n", | |
"print(fn(test_x1, test_x2, test_v)) # [[0.85134045], [2.44213284]]\n", | |
"print(test_x1_copy) # [[-0.98912135 -0.36778665], [ 1.28792526 0.19397442]]\n", | |
"print(test_x1) # [[ 0.85134045 -0.36778665] [ 2.44213284 0.19397442]] " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "75dacc24", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"hide_input": false, | |
"kernelspec": { | |
"display_name": "pymc", | |
"language": "python", | |
"name": "pymc" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.8" | |
}, | |
"toc": { | |
"base_numbering": 1, | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"title_cell": "Table of Contents", | |
"title_sidebar": "Contents", | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": true, | |
"toc_window_display": false | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment