Last active
June 5, 2017 17:01
-
-
Save cmcaine/a49054014cef6fb229ba619379b48a3f to your computer and use it in GitHub Desktop.
einsum-efficiency-problem.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"\n", | |
def placeholder_data(P_inner_size, Q_inner_size, seed=None):
    """Return ``(K, P, Q)``: random arrays with plausible values.

    Parameters
    ----------
    P_inner_size : int
        Side length ``rs`` of every square matrix ``P[a, b]``.
    Q_inner_size : int
        Side length ``cs`` of every square matrix ``Q[a, b]``.
    seed : int or None, optional
        When given, the data is drawn from a private
        ``np.random.RandomState(seed)`` so repeated calls return identical
        arrays. With the default ``None`` the global ``np.random`` state is
        used, exactly as before this parameter existed.

    Returns
    -------
    tuple
        ``(K, P, Q)`` where ``K`` has shape ``(rs, cs)`` and sums to 1,
        ``P`` has shape ``(8, 8, rs, rs)`` and ``Q`` has shape
        ``(8, 8, cs, cs)``. ``P`` and ``Q`` contain only 0.0 and 1.0 and are
        stored as float64.
    """
    Nr = Na = 8
    rs = P_inner_size
    cs = Q_inner_size

    # Both the ``np.random`` module and ``RandomState`` expose ``rand``, so the
    # draws below are identical in form for the seeded and unseeded paths.
    rng = np.random if seed is None else np.random.RandomState(seed)

    P = rng.rand(Nr, Na, rs, rs) > 0.5
    Q = rng.rand(Nr, Na, cs, cs) > 0.5
    K = rng.rand(rs, cs)
    # K is interpreted as a probability distribution, so it should sum to 1.
    K = K / K.sum()

    # Optimisation: casting boolean array to float64 speeds up G() by ~4x
    P = P.astype("float64")
    Q = Q.astype("float64")

    return (K, P, Q)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# These three functions are intended to give the same results, up to rounding error.\n", | |
"\n", | |
def G(K, P, Q):
    """Equation 5, section II.C.

    Computes the sum over every leading index pair ``(a, b)`` of
    ``P[a, b] @ K @ Q[a, b].T + P[a, b].T @ K @ Q[a, b]``.

    Parameters
    ----------
    K : ndarray, shape (rs, cs)
    P : ndarray, shape (Nr, Na, rs, rs)
        Each ``P[a, b]`` must be square because it is used both as-is and
        transposed on the left of ``K``.
    Q : ndarray, shape (Nr, Na, cs, cs)
        Each ``Q[a, b]`` must be square for the same reason on the right.

    Returns
    -------
    ndarray, shape (rs, cs)
    """
    SUM = np.zeros_like(K)
    # `a` walks axis 0 and `b` walks axis 1, matching the P[a, b] / Q[a, b]
    # indexing below. The bounds used to be swapped, which only worked when
    # the two leading axes happened to have the same length.
    for a in range(P.shape[0]):
        for b in range(P.shape[1]):
            SUM = SUM + (
                P[a, b] @ K @ Q[a, b].transpose()
                + P[a, b].transpose() @ K @ Q[a, b]
            )
    return SUM
"\n", | |
def G_dot(K, P, Q):
    """Equation 5, section II.C, written with ``ndarray.dot`` instead of ``@``.

    Computes the sum over every leading index pair ``(a, b)`` of
    ``P[a, b].dot(K).dot(Q[a, b].T) + P[a, b].T.dot(K).dot(Q[a, b])``.

    Parameters
    ----------
    K : ndarray, shape (rs, cs)
    P : ndarray, shape (Nr, Na, rs, rs)
        Each ``P[a, b]`` must be square because it is used both as-is and
        transposed on the left of ``K``.
    Q : ndarray, shape (Nr, Na, cs, cs)
        Each ``Q[a, b]`` must be square for the same reason on the right.

    Returns
    -------
    ndarray, shape (rs, cs)
    """
    SUM = np.zeros_like(K)
    # `a` walks axis 0 and `b` walks axis 1, matching the P[a, b] / Q[a, b]
    # indexing below. The bounds used to be swapped, which only worked when
    # the two leading axes happened to have the same length.
    for a in range(P.shape[0]):
        for b in range(P.shape[1]):
            SUM = SUM + (
                P[a, b].dot(K).dot(Q[a, b].transpose())
                + P[a, b].transpose().dot(K).dot(Q[a, b])
            )
    return SUM
"\n", | |
def G_einsum(K, P, Q):
    """Equation 5, section II.C, evaluated with two ``np.einsum`` contractions.

    Gives the same result as ``G`` and ``G_dot`` up to floating-point
    rounding: the sum over ``(a, b)`` of
    ``P[a, b] @ K @ Q[a, b].T + P[a, b].T @ K @ Q[a, b]``.

    Parameters
    ----------
    K : ndarray, shape (rs, cs)
    P : ndarray, shape (Nr, Na, rs, rs)
    Q : ndarray, shape (Nr, Na, cs, cs)

    Returns
    -------
    ndarray, shape (rs, cs)
    """
    # The output subscripts are spelled out ("->cf") so the result layout does
    # not depend on einsum's implicit alphabetical ordering of free indices.
    # First term: sum_ab P[a, b] @ K @ Q[a, b].T
    A = np.einsum("abcd,de,abfe->cf", P, K, Q, optimize=True)
    # Second term: sum_ab P[a, b].T @ K @ Q[a, b]
    B = np.einsum("abdc,de,abef->cf", P, K, Q, optimize=True)
    return A + B
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# One (K, P, Q) triple per benchmark size; P and Q use the same inner size.
test_data = tuple(placeholder_data(size, size) for size in (10, 100, 500, 1000))
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"337 µs ± 2.12 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", | |
"6.08 ms ± 510 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
"407 ms ± 29.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", | |
"3.2 s ± 41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"for data in test_data:\n", | |
" %timeit G(*data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"366 µs ± 1.63 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", | |
"6.03 ms ± 421 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
"398 ms ± 23.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", | |
"3.19 s ± 38.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"for data in test_data:\n", | |
" %timeit G_dot(*data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"570 µs ± 2.94 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", | |
"231 ms ± 1.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", | |
"26.3 s ± 53.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Size 1000 doesn't terminate for me and size 500 takes ages.\n", | |
"for data in test_data[:-1]:\n", | |
" %timeit G_einsum(*data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Check that all three implementations agree.
# The largest size is skipped because G_einsum does not finish on it.
for data in test_data[:-1]:
    GK = G(*data)
    GKd = G_dot(*data)
    # G and G_dot perform the same matrix products, so they must match exactly.
    assert (GK == GKd).all()
    GKe = G_einsum(*data)
    # Compare element-wise: a signed sum of differences lets errors of opposite
    # sign cancel and passes trivially whenever the net difference is negative.
    np.testing.assert_allclose(GKe, GK, rtol=1e-10)  # aka "close enough"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.version.full_version" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python [conda root]", | |
"language": "python", | |
"name": "conda-root-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment