Skip to content

Instantly share code, notes, and snippets.

@hellman
Last active March 17, 2024 18:37
Show Gist options
  • Save hellman/0d3792f4b85692a7d4386c95f7f86e38 to your computer and use it in GitHub Desktop.
Save hellman/0d3792f4b85692a7d4386c95f7f86e38 to your computer and use it in GitHub Desktop.
Minimum-depth decision tree of a Boolean function in $O(n \cdot 3^n)$ time
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c27199a3-32c4-442f-8074-8bfdd6416e05",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-17T18:29:39.221336Z",
"iopub.status.busy": "2024-03-17T18:29:39.220525Z",
"iopub.status.idle": "2024-03-17T18:29:39.241264Z",
"shell.execute_reply": "2024-03-17T18:29:39.238959Z",
"shell.execute_reply.started": "2024-03-17T18:29:39.221235Z"
}
},
"outputs": [],
"source": [
"from random import randrange\n",
"import functools\n",
"\n",
"def decision_tree():\n",
" ret = _decision_tree()\n",
" # cache only needed for internal memoization, free memory\n",
" _decision_tree.cache_clear()\n",
" return ret\n",
" \n",
"@functools.cache\n",
"def _decision_tree(mask=None, key=None):\n",
" # (mask, key) represent the current subfunction\n",
" # i.e. the restriction f|x : x & mask = key\n",
" # (restriction in the form of a pattern e.g. 01*00**1*110)\n",
" \n",
" # note: the function f is in globals to avoid slowing down the cache\n",
" # when passing it down\n",
" # (also the number of vars n is in globals)\n",
" # (demonstration only)\n",
" if mask is None and key is None:\n",
" key = mask = 0\n",
" \n",
" assert key & mask == key # representation\n",
" assert len(f) == 1 << n\n",
" \n",
" if mask == (1<<n) - 1:\n",
" return 0, f[key]\n",
"\n",
" depth = n + 1 # inf\n",
" tree = None\n",
" # try each possible variable to branch\n",
" # note: positions indexed lsb to MSB\n",
" for fork_pos in range(n):\n",
" bit = 1 << fork_pos\n",
" if mask & bit:\n",
" continue\n",
" \n",
" key0 = key\n",
" key1 = key | bit\n",
" mask_ = mask | bit\n",
" depth0, tree0 = _decision_tree(mask_, key0)\n",
" depth1, tree1 = _decision_tree(mask_, key1)\n",
" \n",
" # no need to fork?\n",
" if depth0 == depth1 == 0 and f[key0] == f[key1]:\n",
" return 0, f[key0]\n",
"\n",
" depth_ = max(depth0, depth1) + 1\n",
" if depth_ < depth: # TODO: randomize choice here\n",
" depth = depth_\n",
" tree = fork_pos, tree0, tree1\n",
" \n",
" assert tree is not None\n",
" return depth, tree\n",
"\n",
"\n",
"def tree_eval(tree, x):\n",
" # structure: (branch index i, subtree | x[i] = 0, subtree | x[i] = 1)\n",
" # or just 0/1 if the function is constant\n",
" if isinstance(tree, int):\n",
" return tree\n",
" pos, tree0, tree1 = tree\n",
" if x & (1 << pos):\n",
" return tree_eval(tree1, x)\n",
" return tree_eval(tree0, x)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c5eafa9a-c6fc-434e-a02c-7db071b27a12",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-17T18:29:39.661868Z",
"iopub.status.busy": "2024-03-17T18:29:39.658877Z",
"iopub.status.idle": "2024-03-17T18:34:13.819153Z",
"shell.execute_reply": "2024-03-17T18:34:13.817111Z",
"shell.execute_reply.started": "2024-03-17T18:29:39.661575Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"n 1\n",
"143 ns ± 1.38 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n",
"n 2\n",
"901 ns ± 22.4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"n 3\n",
"4.53 µs ± 115 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"n 4\n",
"17.9 µs ± 97.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"n 5\n",
"75.3 µs ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
"n 6\n",
"308 µs ± 7.48 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
"n 7\n",
"1.44 ms ± 45.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
"n 8\n",
"5.87 ms ± 395 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"n 9\n",
"23.8 ms ± 3.52 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"n 10\n",
"86.5 ms ± 1.44 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
"n 11\n",
"308 ms ± 8.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"n 12\n",
"1.17 s ± 60.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"for n in range(1, 13):\n",
" print(\"n\", n)\n",
" for itr in range(100):\n",
" f = [randrange(2) for _ in range(2**n)]\n",
" \n",
" depth, tree = decision_tree()\n",
" for x, y in enumerate(f):\n",
" assert tree_eval(tree, x) == y, x\n",
" %timeit decision_tree()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7cc27195-93a1-4c2a-895e-0bd0b2304ee9",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-17T18:36:58.168749Z",
"iopub.status.busy": "2024-03-17T18:36:58.165722Z",
"iopub.status.idle": "2024-03-17T18:36:58.188711Z",
"shell.execute_reply": "2024-03-17T18:36:58.186162Z",
"shell.execute_reply.started": "2024-03-17T18:36:58.168389Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(2, (1, 0, (2, 0, 1)))"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n = 3\n",
"f = [0] * 2**n\n",
"for i in range(3):\n",
" f[randrange(2**n)] = 1\n",
"decision_tree()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e3eb24de-9167-4503-84c8-06376200c037",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-17T18:37:00.999228Z",
"iopub.status.busy": "2024-03-17T18:37:00.997307Z",
"iopub.status.idle": "2024-03-17T18:37:01.014926Z",
"shell.execute_reply": "2024-03-17T18:37:01.012570Z",
"shell.execute_reply.started": "2024-03-17T18:37:00.998986Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[0, 0, 0, 0, 0, 0, 1, 1]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "PyPy3",
"language": "python",
"name": "pypy3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@hellman
Copy link
Author

hellman commented Mar 17, 2024

Ran on PyPy3 (the timings above were measured with the PyPy3 / Python 3.10.13 kernel).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment