Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save ricardoV94/ae4c51365b871713bc3cca735fe8fa2f to your computer and use it in GitHub Desktop.
Save ricardoV94/ae4c51365b871713bc3cca735fe8fa2f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "0a62ff8d",
"metadata": {},
"source": [
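"## Narrow down the problem\n",
"\n",
"Compile the model's `dlogp` graph with `FAST_COMPILE` and `NUMBA`, then walk upwards from the outputs until both modes agree, to isolate the subgraph where their results diverge."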
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5111df4a",
"metadata": {},
"outputs": [],
"source": [
"from itertools import chain\n",
"\n",
"import pytensor\n",
"import pytensor.tensor as pt\n",
"import pymc as pm\n",
"import numpy as np\n",
"\n",
"from pytensor.graph.basic import Variable, Constant, graph_inputs, io_toposort\n",
"\n",
"\n",
"N = 2\n",
"M = 2\n",
"\n",
"# Model whose logp gradient is compiled under two different modes below\n",
"with pm.Model(check_bounds=False) as model:\n",
"    pm.ZeroSumNormal(\"x\", n_zerosum_axes=1, shape=(N, M))\n",
"\n",
"graph = [model.dlogp()]"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "48cf38c8",
"metadata": {},
"outputs": [],
"source": [
"# The two compilation modes whose results are compared\n",
"mode1 = \"FAST_COMPILE\"\n",
"mode2 = \"NUMBA\"\n",
"\n",
"# One value per free input of `graph`; the (N, M - 1) shape presumably matches the\n",
"# transformed zero-sum value variable -- TODO confirm against `inputvars(graph)`\n",
"test_values = [np.ones((N, M - 1))]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "fabd4c39",
"metadata": {},
"outputs": [],
"source": [
"def makeiter(a):\n",
"    \"\"\"Return `a` unchanged if it is a list or tuple, otherwise wrap it as `[a]`.\"\"\"\n",
"    if not isinstance(a, (list, tuple)):\n",
"        return [a]\n",
"    return a\n",
"\n",
"\n",
"def inputvars(a, blockers=None):\n",
"    \"\"\"Return the non-constant variables among the root inputs of the graph(s) `a`.\n",
"\n",
"    `a` may be a single variable or a list/tuple of them; `blockers` is forwarded\n",
"    unchanged to pytensor's `graph_inputs`.\n",
"    \"\"\"\n",
"    return [\n",
"        v\n",
"        for v in graph_inputs(makeiter(a), blockers=blockers)\n",
"        if isinstance(v, Variable) and not isinstance(v, Constant)\n",
"    ]\n",
"\n",
"\n",
"def check_equivalence(\n",
"    graph_inputs,\n",
"    graph_outputs,\n",
"    mode1,\n",
"    mode2,\n",
"    test_values,\n",
"    dprint=False,\n",
"):\n",
"    \"\"\"Compile `graph_outputs` under `mode1` and `mode2` and assert both give close results.\n",
"\n",
"    Unused inputs are ignored, so the full input list can be passed for any subgraph.\n",
"    If `dprint` is true, the `mode2` function is printed before evaluation.\n",
"    Raises AssertionError if the output counts differ or any output pair differs\n",
"    beyond rtol=atol=1e-6.\n",
"    \"\"\"\n",
"    # The `graph_inputs` parameter shadows the imported pytensor helper, but only inside this function.\n",
"    fn1 = pytensor.function(graph_inputs, graph_outputs, mode=mode1, on_unused_input=\"ignore\")\n",
"    fn2 = pytensor.function(graph_inputs, graph_outputs, mode=mode2, on_unused_input=\"ignore\")\n",
"\n",
"    if dprint:\n",
"        pytensor.dprint(fn2, print_fgraph_inputs=True)\n",
"\n",
"    # NOTE(review): fn1 runs before fn2 on the same `test_values`; if a compiled function\n",
"    # mutates its inputs in place (the reproducer at the end of this notebook shows NUMBA\n",
"    # can), later calls see modified values -- consider passing copies if that matters.\n",
"    res1 = makeiter(fn1(*test_values))\n",
"    res2 = makeiter(fn2(*test_values))\n",
"\n",
"    # zip() would silently skip extra outputs, so check the counts explicitly.\n",
"    assert len(res1) == len(res2), f\"{mode1} returned {len(res1)} outputs, {mode2} returned {len(res2)}\"\n",
"\n",
"    for idx, (r1, r2) in enumerate(zip(res1, res2)):\n",
"        np.testing.assert_allclose(\n",
"            r1, r2, rtol=1e-6, atol=1e-6,\n",
"            err_msg=f\"output {idx} differs between {mode1} and {mode2}\",\n",
"        )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "24d1da89",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 [Reshape{1}.0]\n",
"1 [(d__logp/dx_zerosum__)]\n",
"2 [Split{2}.0, Second.0]\n",
"Graph is equivalent now\n"
]
}
],
"source": [
"inputs = inputvars(graph)\n",
"curr_graph = graph\n",
"\n",
"# Walk upwards from the outputs: while the two modes disagree, replace the current\n",
"# variables by the non-constant inputs of their owner nodes.\n",
"for step in range(100):\n",
"    print(step, curr_graph)\n",
"    try:\n",
"        check_equivalence(inputs, curr_graph, mode1, mode2, test_values)\n",
"    except AssertionError:\n",
"        # Check if the variable inputs to the current graph outputs already differ.\n",
"        # dict.fromkeys deduplicates like set() but keeps a deterministic order.\n",
"        curr_graph = list(dict.fromkeys(chain.from_iterable(\n",
"            [inp for inp in v.owner.inputs if not isinstance(inp, Constant)]\n",
"            for v in curr_graph if v.owner\n",
"        )))\n",
"    else:\n",
"        print(\"Graph is equivalent now\")\n",
"        break\n",
"else:\n",
"    print(\"No equivalent subgraph found within 100 steps\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "88a49c2c",
"metadata": {},
"outputs": [],
"source": [
"# Problematic graph should be from these inputs to the original output.\n",
"# dict.fromkeys deduplicates while keeping a deterministic order (set() order can\n",
"# change between runs, which would reorder the positional values used below).\n",
"nw_curr_graph = list(dict.fromkeys(chain.from_iterable(\n",
"    [inp for inp in v.owner.inputs if not isinstance(inp, Constant)]\n",
"    for v in curr_graph if v.owner\n",
")))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "227ae674",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array([[1.],\n",
" [1.]]),\n",
" array([[-0.14644661],\n",
" [-0.14644661]]),\n",
" array([[-0.35355339, 0.35355339],\n",
" [-0.35355339, 0.35355339]]),\n",
" array([1, 1])]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Evaluate the intermediate variables on the original test values (default compilation mode)\n",
"nw_test_values = pytensor.function(\n",
"    inputs, nw_curr_graph\n",
")(*test_values)\n",
"nw_test_values"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5d2176be",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reshape{1} [id A] 2\n",
" ├─ Add [id B] '(d__logp/dx_zerosum__)' 1\n",
" │ ├─ Split{2}.0 [id C] 0\n",
" │ │ ├─ <Matrix(float64, shape=(?, ?))> [id D]\n",
" │ │ ├─ 1 [id E]\n",
" │ │ └─ <Vector(int64, shape=(2,))> [id F]\n",
" │ └─ <Matrix(float64, shape=(?, 1))> [id G]\n",
" └─ [-1] [id H]\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7ffa1c356aa0>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Compile the original outputs as functions of the intermediate variables only\n",
"fn1 = pytensor.function(nw_curr_graph, graph, mode=mode1)\n",
"fn2 = pytensor.function(nw_curr_graph, graph, mode=mode2)\n",
"# Trailing `;` suppresses the echoed return value of dprint (the output stream object)\n",
"pytensor.dprint(fn2);"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "4507d1f4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[array([-0.5, -0.5])]\n",
"[array([-0.5, -0.5])]\n"
]
}
],
"source": [
"# Call twice: FAST_COMPILE gives the same result on repeated calls\n",
"for attempt in range(2):\n",
"    print(fn1(*nw_test_values))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a1342409",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[array([-0.5, -0.5])]\n",
"[array([-0.64644661, -0.64644661])]\n"
]
}
],
"source": [
"# Call twice: NUMBA gives a different result on the second call\n",
"for attempt in range(2):\n",
"    print(fn2(*nw_test_values))"
]
},
{
"cell_type": "markdown",
"id": "68d29f66",
"metadata": {},
"source": [
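"## Recreate the issue manually\n",
"\n",
"A standalone reproducer: with `NUMBA`, adding the first output of a `Split` to another matrix gives different results on repeated calls, and the input array is modified in place."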
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "ba98dc6b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Add [id A] <Matrix(float64, shape=(?, ?))> 1\n",
" ├─ Split{2}.0 [id B] <Matrix(float64, shape=(?, ?))> 0\n",
" │ ├─ x1 [id C] <Matrix(float64, shape=(?, ?))>\n",
" │ ├─ 1 [id D] <Scalar(int8, shape=())>\n",
" │ └─ v [id E] <Vector(int64, shape=(2,))>\n",
" └─ x2 [id F] <Matrix(float64, shape=(?, 1))>\n",
"[[-0.06889045]\n",
" [ 1.86502905]]\n",
"[[0.85134045]\n",
" [2.44213284]]\n",
"[[-0.98912135 -0.36778665]\n",
" [ 1.28792526 0.19397442]]\n",
"[[ 0.85134045 -0.36778665]\n",
" [ 2.44213284 0.19397442]]\n"
]
}
],
"source": [
"import pytensor\n",
"import pytensor.tensor as pt\n",
"import numpy as np\n",
"\n",
"# Standalone reproducer: first output of a Split plus a (?, 1) matrix, compiled with NUMBA\n",
"x1 = pt.matrix(\"x1\")\n",
"x2 = pt.matrix(\"x2\", shape=(None, 1))\n",
"v = pt.vector(\"v\", shape=(2,), dtype=int)\n",
"out = pt.split(x1, v, n_splits=2, axis=1)[0] + x2\n",
"\n",
"fn = pytensor.function([x1, x2, v], out, mode=\"NUMBA\")\n",
"pytensor.dprint(fn, print_type=True)\n",
"\n",
"# Add [id A] <Matrix(float64, shape=(?, ?))> 1\n",
"# ├─ Split{2}.0 [id B] <Matrix(float64, shape=(?, ?))> 0\n",
"# │ ├─ x1 [id C] <Matrix(float64, shape=(?, ?))>\n",
"# │ ├─ 1 [id D] <Scalar(int8, shape=())>\n",
"# │ └─ v [id E] <Vector(int64, shape=(2,))>\n",
"# └─ x2 [id F] <Matrix(float64, shape=(?, 1))>\n",
"\n",
"rng = np.random.default_rng(123)\n",
"test_x1 = rng.normal(size=(2, 2))\n",
"test_x2 = rng.normal(size=(2, 1))\n",
"test_v = np.array([1, 1])\n",
"# Snapshot of the original input, to compare against after the calls\n",
"test_x1_copy = test_x1.copy()\n",
"\n",
"# Two identical calls give different results:\n",
"#   1st -> [[-0.06889045], [ 1.86502905]]\n",
"#   2nd -> [[0.85134045], [2.44213284]]\n",
"for attempt in range(2):\n",
"    print(fn(test_x1, test_x2, test_v))\n",
"\n",
"# test_x1 no longer matches its snapshot: its first column now equals the 2nd result\n",
"print(test_x1_copy)  # [[-0.98912135 -0.36778665], [ 1.28792526  0.19397442]]\n",
"print(test_x1)  # [[ 0.85134045 -0.36778665] [ 2.44213284  0.19397442]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75dacc24",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "pymc",
"language": "python",
"name": "pymc"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment