Last active
March 2, 2020 12:33
-
-
Save eric-czech/7bc258417dfa85e96f270d64ee8f2f7d to your computer and use it in GitHub Desktop.
Dask kronecker product attempts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Dask Kronecker Product Implementations" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:03:26.195827Z", | |
"start_time": "2020-02-29T09:03:25.720008Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"import warnings\n", | |
"import numpy as np\n", | |
"from dask.base import tokenize\n", | |
"from dask.highlevelgraph import HighLevelGraph\n", | |
"from dask import core\n", | |
"from dask.array.core import operator\n", | |
"from dask.distributed import Client\n", | |
"import dask.array as da" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:03:27.025398Z", | |
"start_time": "2020-02-29T09:03:26.197844Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table style=\"border: 2px solid white;\">\n", | |
"<tr>\n", | |
"<td style=\"vertical-align: top; border: 0px solid white\">\n", | |
"<h3 style=\"text-align: left;\">Client</h3>\n", | |
"<ul style=\"text-align: left; list-style: none; margin: 0; padding: 0;\">\n", | |
" <li><b>Scheduler: </b>tcp://127.0.0.1:8079</li>\n", | |
" <li><b>Dashboard: </b><a href='http://127.0.0.1:8787/status' target='_blank'>http://127.0.0.1:8787/status</a>\n", | |
"</ul>\n", | |
"</td>\n", | |
"<td style=\"vertical-align: top; border: 0px solid white\">\n", | |
"<h3 style=\"text-align: left;\">Cluster</h3>\n", | |
"<ul style=\"text-align: left; list-style:none; margin: 0; padding: 0;\">\n", | |
" <li><b>Workers: </b>1</li>\n", | |
" <li><b>Cores: </b>1</li>\n", | |
" <li><b>Memory: </b>134.78 GB</li>\n", | |
"</ul>\n", | |
"</td>\n", | |
"</tr>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"<Client: 'tcp://127.0.0.1:8079' processes=1 threads=1, memory=134.78 GB>" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"client = Client(n_workers=1, threads_per_worker=1, processes=False, scheduler_port=8079)\n", | |
"client" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:03:27.031242Z", | |
"start_time": "2020-02-29T09:03:27.027613Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"def validate_result(x, y, z):\n", | |
" z = z.compute()\n", | |
" assert z.shape == z.shape\n", | |
" assert np.all(np.kron(x.compute(), y.compute()) == z)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Build inputs to test with:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:03:27.042852Z", | |
"start_time": "2020-02-29T09:03:27.034107Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table>\n", | |
"<tr>\n", | |
"<td>\n", | |
"<table>\n", | |
" <thead>\n", | |
" <tr><td> </td><th> Array </th><th> Chunk </th></tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr><th> Bytes </th><td> 288 B </td> <td> 48 B </td></tr>\n", | |
" <tr><th> Shape </th><td> (9, 4) </td> <td> (3, 2) </td></tr>\n", | |
" <tr><th> Count </th><td> 14 Tasks </td><td> 6 Chunks </td></tr>\n", | |
" <tr><th> Type </th><td> int64 </td><td> numpy.ndarray </td></tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</td>\n", | |
"<td>\n", | |
"<svg width=\"103\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n", | |
"\n", | |
" <!-- Horizontal lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"53\" y2=\"0\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"0\" y1=\"40\" x2=\"53\" y2=\"40\" />\n", | |
" <line x1=\"0\" y1=\"80\" x2=\"53\" y2=\"80\" />\n", | |
" <line x1=\"0\" y1=\"120\" x2=\"53\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Vertical lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"26\" y1=\"0\" x2=\"26\" y2=\"120\" />\n", | |
" <line x1=\"53\" y1=\"0\" x2=\"53\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Colored Rectangle -->\n", | |
" <polygon points=\"0.000000,0.000000 53.333333,0.000000 53.333333,120.000000 0.000000,120.000000\" style=\"fill:#ECB172A0;stroke-width:0\"/>\n", | |
"\n", | |
" <!-- Text -->\n", | |
" <text x=\"26.666667\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >4</text>\n", | |
" <text x=\"73.333333\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(0,73.333333,60.000000)\">9</text>\n", | |
"</svg>\n", | |
"</td>\n", | |
"</tr>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"dask.array<rechunk-merge, shape=(9, 4), dtype=int64, chunksize=(3, 2), chunktype=numpy.ndarray>" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x = da.arange(36).reshape(9, 4).rechunk((3, 2))\n", | |
"x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:03:27.053869Z", | |
"start_time": "2020-02-29T09:03:27.046388Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table>\n", | |
"<tr>\n", | |
"<td>\n", | |
"<table>\n", | |
" <thead>\n", | |
" <tr><td> </td><th> Array </th><th> Chunk </th></tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr><th> Bytes </th><td> 288 B </td> <td> 48 B </td></tr>\n", | |
" <tr><th> Shape </th><td> (9, 4) </td> <td> (3, 2) </td></tr>\n", | |
" <tr><th> Count </th><td> 14 Tasks </td><td> 6 Chunks </td></tr>\n", | |
" <tr><th> Type </th><td> int64 </td><td> numpy.ndarray </td></tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</td>\n", | |
"<td>\n", | |
"<svg width=\"103\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n", | |
"\n", | |
" <!-- Horizontal lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"53\" y2=\"0\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"0\" y1=\"40\" x2=\"53\" y2=\"40\" />\n", | |
" <line x1=\"0\" y1=\"80\" x2=\"53\" y2=\"80\" />\n", | |
" <line x1=\"0\" y1=\"120\" x2=\"53\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Vertical lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"26\" y1=\"0\" x2=\"26\" y2=\"120\" />\n", | |
" <line x1=\"53\" y1=\"0\" x2=\"53\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Colored Rectangle -->\n", | |
" <polygon points=\"0.000000,0.000000 53.333333,0.000000 53.333333,120.000000 0.000000,120.000000\" style=\"fill:#ECB172A0;stroke-width:0\"/>\n", | |
"\n", | |
" <!-- Text -->\n", | |
" <text x=\"26.666667\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >4</text>\n", | |
" <text x=\"73.333333\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(0,73.333333,60.000000)\">9</text>\n", | |
"</svg>\n", | |
"</td>\n", | |
"</tr>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"dask.array<rechunk-merge, shape=(9, 4), dtype=int64, chunksize=(3, 2), chunktype=numpy.ndarray>" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y = da.ones(36, dtype=x.dtype).reshape(9, 4).rechunk((3, 2))\n", | |
"y" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### V1 - Using ```da.blockwise```" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:03:27.847418Z", | |
"start_time": "2020-02-29T09:03:27.056573Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0, 0, 0, ..., 3, 3, 3],\n", | |
" [ 0, 0, 0, ..., 3, 3, 3],\n", | |
" [ 0, 0, 0, ..., 3, 3, 3],\n", | |
" ...,\n", | |
" [32, 32, 32, ..., 35, 35, 35],\n", | |
" [32, 32, 32, ..., 35, 35, 35],\n", | |
" [32, 32, 32, ..., 35, 35, 35]])" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"def kron_v1(x, y):\n", | |
" # Rechunk left array to single data elements\n", | |
" x = x.rechunk((1, 1))\n", | |
" return da.blockwise(\n", | |
" np.multiply, 'ij', x, 'ij', y, 'xy', concatenate=True, dtype='f8',\n", | |
" adjust_chunks={'i': y.shape[0], 'j': y.shape[1]}\n", | |
" )\n", | |
"z = kron_v1(x, y)\n", | |
"validate_result(x, y, z)\n", | |
"z.compute()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### V2 - Using ```da.block```" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:03:29.816169Z", | |
"start_time": "2020-02-29T09:03:27.852893Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0, 0, 0, ..., 3, 3, 3],\n", | |
" [ 0, 0, 0, ..., 3, 3, 3],\n", | |
" [ 0, 0, 0, ..., 3, 3, 3],\n", | |
" ...,\n", | |
" [32, 32, 32, ..., 35, 35, 35],\n", | |
" [32, 32, 32, ..., 35, 35, 35],\n", | |
" [32, 32, 32, ..., 35, 35, 35]])" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"def kron_v2(x, y):\n", | |
" return da.block([\n", | |
" [x[i, j] * y for j in range(x.shape[1])]\n", | |
" for i in range(x.shape[0])\n", | |
" ])\n", | |
"z = kron_v2(x, y)\n", | |
"validate_result(x, y, z)\n", | |
"z.compute()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0, 0, 0, ..., 3, 3, 3],\n", | |
" [ 0, 0, 0, ..., 3, 3, 3],\n", | |
" [ 0, 0, 0, ..., 3, 3, 3],\n", | |
" ...,\n", | |
" [32, 32, 32, ..., 35, 35, 35],\n", | |
" [32, 32, 32, ..., 35, 35, 35],\n", | |
" [32, 32, 32, ..., 35, 35, 35]])" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Validate with np array argument on left as well\n", | |
"z = kron_v2(x.compute(), y)\n", | |
"validate_result(x, y, z)\n", | |
"z.compute()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### V3 - Using custom graph\n", | |
"\n", | |
"The intent with this approach is to assemble a new graph representing a kronecker product based on compositions of task nodes in the existing graphs for the input arrays. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:03:29.824881Z", | |
"start_time": "2020-02-29T09:03:29.819197Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 0, 0),\n", | |
" ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 0, 1),\n", | |
" ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 1, 0),\n", | |
" ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 1, 1),\n", | |
" ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 2, 0),\n", | |
" ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 2, 1),\n", | |
" ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 0),\n", | |
" ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 1),\n", | |
" ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 2),\n", | |
" ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 3),\n", | |
" ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 4),\n", | |
" ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 5),\n", | |
" ('reshape-126131d495cd2afae9f2085c3fb52292', 0, 0),\n", | |
" ('arange-e599d6c2e52a053c03cd96483fdb779d', 0)]" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Show all the tasks in the graph to create x\n", | |
"list(dict(x.__dask_graph__()))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:03:29.835190Z", | |
"start_time": "2020-02-29T09:03:29.828630Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 0, 0),\n", | |
" ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 1, 0),\n", | |
" ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 2, 0)]" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Show only the tasks that represent results\n", | |
"list(dict(x.__dask_keys__())) " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:03:30.895878Z", | |
"start_time": "2020-02-29T09:03:29.837588Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0, 0, 0, ..., 3, 3, 3],\n", | |
" [ 0, 0, 0, ..., 3, 3, 3],\n", | |
" [ 0, 0, 0, ..., 3, 3, 3],\n", | |
" ...,\n", | |
" [32, 32, 32, ..., 35, 35, 35],\n", | |
" [32, 32, 32, ..., 35, 35, 35],\n", | |
" [32, 32, 32, ..., 35, 35, 35]])" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# See https://docs.dask.org/en/latest/array-design.html#example-eye-function for an example Array graph construction\n", | |
"\n", | |
"def kron_v3(x, y):\n", | |
" \"\"\" Kronecker Product \n", | |
" \n", | |
" Limitations:\n", | |
" - Assumes equally size chunks within each array\n", | |
" - Only works for 2D arrays\n", | |
" \"\"\"\n", | |
" chunks = (y.chunks[0] * x.shape[0], y.chunks[1] * x.shape[1])\n", | |
" \n", | |
" name = 'kron-' + tokenize(x.name, y.name)\n", | |
" \n", | |
" def get_chunk_index(i, j):\n", | |
" \"\"\" Determine index of array coordinates within a chunk \"\"\"\n", | |
" return (i % x.chunksize[0], j % x.chunksize[1])\n", | |
" \n", | |
" def get_chunk_coords(i, j):\n", | |
" \"\"\" Determine chunk coordinates for array indices \"\"\"\n", | |
" return (i // x.chunksize[0], j // x.chunksize[1])\n", | |
" \n", | |
" def array_idx_op(xi, xj):\n", | |
" \"\"\" Build task for selecting single element of a matrix provided scalar indices \"\"\"\n", | |
" chk_idx = get_chunk_index(xi, xj)\n", | |
" chk_crd = get_chunk_coords(xi, xj)\n", | |
" # Example selecting item in second column and row of first chunk:\n", | |
" # (operator.getitem, (x.name, 0, 0), (1, 1))\n", | |
" return (operator.getitem, (x.name, *chk_crd), chk_idx)\n", | |
" \n", | |
" def kron_key(xi, xj, yk):\n", | |
" \"\"\" Build chunk key for resulting block in product \"\"\"\n", | |
" # Result has as many blocks in one dimension as there are chunks\n", | |
" # in y times the length of x along that same dimension\n", | |
" return (name, xi*len(y.chunks[0]) + yk[1], xj*len(y.chunks[1]) + yk[2])\n", | |
" \n", | |
" layer = {\n", | |
" # Map kron product block to operation with copy of y times single element of x\n", | |
" # NOTE: It is crucial here that all operations do not refer to x and y directly \n", | |
" # as this simply nests the graphs for them within this one -- the construction\n", | |
" # here must instead only refer to keys within the graphs of x and y\n", | |
" kron_key(xi, xj, yk): (operator.mul, array_idx_op(xi, xj), yk) \n", | |
" for xi in range(x.shape[0])\n", | |
" for xj in range(x.shape[1])\n", | |
" for yk in core.flatten(y.__dask_keys__())\n", | |
" }\n", | |
" # Many of the task keys reference above don't exist in the layer dict just created, so it \n", | |
" # is crucial that x and y are provided here as dependencies so that the graphs can be merged\n", | |
" dsk = HighLevelGraph.from_collections(name, layer, dependencies=[x, y])\n", | |
" return da.Array(dsk, name, chunks, dtype=x.dtype)\n", | |
"\n", | |
"z = kron_v3(x, y)\n", | |
"validate_result(x, y, z)\n", | |
"z.compute()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Benchmarks" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Create slightly bigger arrays for benchmarking; note that two 50x20 arrays give a 2500x400 (1M elems) result so these should still remain small for local testing:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:04:43.416290Z", | |
"start_time": "2020-02-29T09:04:43.408815Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table>\n", | |
"<tr>\n", | |
"<td>\n", | |
"<table>\n", | |
" <thead>\n", | |
" <tr><td> </td><th> Array </th><th> Chunk </th></tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr><th> Bytes </th><td> 8.00 kB </td> <td> 800 B </td></tr>\n", | |
" <tr><th> Shape </th><td> (50, 20) </td> <td> (25, 4) </td></tr>\n", | |
" <tr><th> Count </th><td> 22 Tasks </td><td> 10 Chunks </td></tr>\n", | |
" <tr><th> Type </th><td> int64 </td><td> numpy.ndarray </td></tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</td>\n", | |
"<td>\n", | |
"<svg width=\"98\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n", | |
"\n", | |
" <!-- Horizontal lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"48\" y2=\"0\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"0\" y1=\"60\" x2=\"48\" y2=\"60\" />\n", | |
" <line x1=\"0\" y1=\"120\" x2=\"48\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Vertical lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"9\" y1=\"0\" x2=\"9\" y2=\"120\" />\n", | |
" <line x1=\"19\" y1=\"0\" x2=\"19\" y2=\"120\" />\n", | |
" <line x1=\"28\" y1=\"0\" x2=\"28\" y2=\"120\" />\n", | |
" <line x1=\"38\" y1=\"0\" x2=\"38\" y2=\"120\" />\n", | |
" <line x1=\"48\" y1=\"0\" x2=\"48\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Colored Rectangle -->\n", | |
" <polygon points=\"0.000000,0.000000 48.000000,0.000000 48.000000,120.000000 0.000000,120.000000\" style=\"fill:#ECB172A0;stroke-width:0\"/>\n", | |
"\n", | |
" <!-- Text -->\n", | |
" <text x=\"24.000000\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >20</text>\n", | |
" <text x=\"68.000000\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(0,68.000000,60.000000)\">50</text>\n", | |
"</svg>\n", | |
"</td>\n", | |
"</tr>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"dask.array<rechunk-merge, shape=(50, 20), dtype=int64, chunksize=(25, 4), chunktype=numpy.ndarray>" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xl = da.arange(50*20).reshape(50, 20).rechunk((25, 4))\n", | |
"xl" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:04:44.025255Z", | |
"start_time": "2020-02-29T09:04:44.017163Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table>\n", | |
"<tr>\n", | |
"<td>\n", | |
"<table>\n", | |
" <thead>\n", | |
" <tr><td> </td><th> Array </th><th> Chunk </th></tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr><th> Bytes </th><td> 8.00 kB </td> <td> 800 B </td></tr>\n", | |
" <tr><th> Shape </th><td> (50, 20) </td> <td> (25, 4) </td></tr>\n", | |
" <tr><th> Count </th><td> 22 Tasks </td><td> 10 Chunks </td></tr>\n", | |
" <tr><th> Type </th><td> int64 </td><td> numpy.ndarray </td></tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</td>\n", | |
"<td>\n", | |
"<svg width=\"98\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n", | |
"\n", | |
" <!-- Horizontal lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"48\" y2=\"0\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"0\" y1=\"60\" x2=\"48\" y2=\"60\" />\n", | |
" <line x1=\"0\" y1=\"120\" x2=\"48\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Vertical lines -->\n", | |
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
" <line x1=\"9\" y1=\"0\" x2=\"9\" y2=\"120\" />\n", | |
" <line x1=\"19\" y1=\"0\" x2=\"19\" y2=\"120\" />\n", | |
" <line x1=\"28\" y1=\"0\" x2=\"28\" y2=\"120\" />\n", | |
" <line x1=\"38\" y1=\"0\" x2=\"38\" y2=\"120\" />\n", | |
" <line x1=\"48\" y1=\"0\" x2=\"48\" y2=\"120\" style=\"stroke-width:2\" />\n", | |
"\n", | |
" <!-- Colored Rectangle -->\n", | |
" <polygon points=\"0.000000,0.000000 48.000000,0.000000 48.000000,120.000000 0.000000,120.000000\" style=\"fill:#ECB172A0;stroke-width:0\"/>\n", | |
"\n", | |
" <!-- Text -->\n", | |
" <text x=\"24.000000\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >20</text>\n", | |
" <text x=\"68.000000\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(0,68.000000,60.000000)\">50</text>\n", | |
"</svg>\n", | |
"</td>\n", | |
"</tr>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"dask.array<rechunk-merge, shape=(50, 20), dtype=int64, chunksize=(25, 4), chunktype=numpy.ndarray>" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"yl = da.ones(xl.size, dtype=xl.dtype).reshape(xl.shape).rechunk(xl.chunksize)\n", | |
"yl" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:06:03.973336Z", | |
"start_time": "2020-02-29T09:04:50.646935Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"8.47 s ± 67.9 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit -n 3 -r 3\n", | |
"with warnings.catch_warnings():\n", | |
" # Ignore \"PerformanceWarning: Increasing number of chunks by factor of X\"; \n", | |
" # we know the rechunking to (1,1) will do this\n", | |
" warnings.simplefilter(action='ignore', category=da.PerformanceWarning)\n", | |
" kron_v1(xl, yl).compute()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:12:20.636237Z", | |
"start_time": "2020-02-29T09:06:03.978754Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"29.1 s ± 266 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit -n 3 -r 3\n", | |
"kron_v2(xl, yl).compute()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:16:04.060013Z", | |
"start_time": "2020-02-29T09:12:20.639577Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"18.1 s ± 217 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit -n 3 -r 3\n", | |
"kron_v3(xl, yl).compute()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:18:25.990804Z", | |
"start_time": "2020-02-29T09:18:24.858101Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"151 ms ± 8.67 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit -n 3 -r 3\n", | |
"np.kron(xl.compute(), yl.compute())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2020-02-29T09:25:54.659284Z", | |
"start_time": "2020-02-29T09:25:53.966050Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"103 ms ± 6.34 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit -n 3 -r 3\n", | |
"# Compare to elementwise multiplication of random matrices having shape \n", | |
"# equal to kronecker product output shape (same number of multiplications)\n", | |
"xr1 = da.random.normal(size=(xl.shape[0]**2, xl.shape[1]**2))\n", | |
"xr2 = da.random.normal(size=(xl.shape[0]**2, xl.shape[1]**2))\n", | |
"(xr1 * xr2).compute()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"57.9 s ± 1.53 s per loop (mean ± std. dev. of 3 runs, 3 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Also test the v2 version with a numpy input on the left --\n", | |
"# This is almost 2x slower somehow?\n", | |
"xla = xl.compute()\n", | |
"%timeit -n 3 -r 3 kron_v2(xla, yl).compute()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Conclusion**: 1x1 rechunking + ```da.blockwise``` performs better than the others, but all of the methods are extremely slow considering that an elementwise multiplication involving the same number of FLOPs takes 1% as much time (.1/8.47 ~= .01). " | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment