Skip to content

Instantly share code, notes, and snippets.

@cmcaine
Last active June 5, 2017 17:01
Show Gist options
  • Save cmcaine/a49054014cef6fb229ba619379b48a3f to your computer and use it in GitHub Desktop.
Save cmcaine/a49054014cef6fb229ba619379b48a3f to your computer and use it in GitHub Desktop.
einsum-efficiency-problem.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def placeholder_data(P_inner_size, Q_inner_size):\n",
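" \"Returns K, P, Q: K is a random P_inner_size x Q_inner_size matrix normalised to sum to 1; P and Q are random 0/1 float64 arrays of shape (8, 8, n, n), with n = P_inner_size for P and n = Q_inner_size for Q\"\n",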
" Nr = Na = 8\n",
" rs = P_inner_size\n",
" cs = Q_inner_size\n",
" P = np.random.rand(Nr, Na, rs, rs) > 0.5\n",
" Q = np.random.rand(Nr, Na, cs, cs) > 0.5\n",
" K = np.random.rand(rs, cs)\n",
" # K is interpreted as a probability distribution, so it should sum to 1.\n",
" K = K / K.sum()\n",
"\n",
" # Optimisation: casting boolean array to float64 speeds up G() by ~4x\n",
" P = P.astype(\"float64\")\n",
" Q = Q.astype(\"float64\")\n",
" \n",
" return (K, P, Q)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
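"# These three functions are intended to give the same results, up to rounding error.\n",
"# NOTE(review): G and G_dot loop `a` over P.shape[1] and `b` over P.shape[0] but index P[a, b];\n",
"# this only works because placeholder_data uses Nr == Na.\n",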
"\n",
"def G(K, P, Q):\n",
" \"Equation 5, section II.C\"\n",
" SUM = np.zeros_like(K)\n",
" for a in range(P.shape[1]):\n",
" for b in range(P.shape[0]):\n",
" SUM = SUM + P[a, b] @ K @ Q[a, b].transpose() \\\n",
" + P[a, b].transpose() @ K @ Q[a, b]\n",
" return SUM\n",
"\n",
"def G_dot(K, P, Q):\n",
" \"Equation 5, section II.C\"\n",
" SUM = np.zeros_like(K)\n",
" for a in range(P.shape[1]):\n",
" for b in range(P.shape[0]):\n",
" SUM = SUM + P[a, b].dot(K).dot(Q[a, b].transpose()) \\\n",
" + P[a, b].transpose().dot(K).dot(Q[a, b])\n",
" return SUM\n",
"\n",
"def G_einsum(K, P, Q):\n",
" A = np.einsum(\"abcd,de,abfe\", P, K, Q, optimize=True);\n",
" B = np.einsum(\"abdc,de,abef\", P ,K, Q, optimize=True);\n",
" return A+B"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"test_data = tuple(map(lambda size: placeholder_data(size, size), (10, 100, 500, 1000)))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"337 µs ± 2.12 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
"6.08 ms ± 510 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"407 ms ± 29.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"3.2 s ± 41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"for data in test_data:\n",
" %timeit G(*data)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"366 µs ± 1.63 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
"6.03 ms ± 421 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"398 ms ± 23.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"3.19 s ± 38.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"for data in test_data:\n",
" %timeit G_dot(*data)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"570 µs ± 2.94 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
"231 ms ± 1.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"26.3 s ± 53.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"# Size 1000 doesn't terminate for me and size 500 takes ages.\n",
"for data in test_data[:-1]:\n",
" %timeit G_einsum(*data)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
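"# Check that the results agree: G vs G_dot exactly, G vs G_einsum to a relative tolerance.\n",
"# NOTE(review): the final assert uses a signed difference (no abs), so it also passes\n",
"# whenever G_einsum's total exceeds G's, and elementwise errors of opposite sign can cancel.\n",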
"for data in test_data[:-1]:\n",
" GK = G(*data)\n",
" GKd = G_dot(*data)\n",
" assert (GK == GKd).all()\n",
" GKe = G_einsum(*data)\n",
" assert 1e-10 > np.sum(GK - (GKe))/np.sum(GK) # aka \"close enough\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.version.full_version"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda root]",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment