Skip to content

Instantly share code, notes, and snippets.

@eric-czech
Last active March 2, 2020 12:33
Show Gist options
  • Save eric-czech/7bc258417dfa85e96f270d64ee8f2f7d to your computer and use it in GitHub Desktop.
Save eric-czech/7bc258417dfa85e96f270d64ee8f2f7d to your computer and use it in GitHub Desktop.
Dask kronecker product attempts
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dask Kronecker Product Implementations"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:03:26.195827Z",
"start_time": "2020-02-29T09:03:25.720008Z"
}
},
"outputs": [],
"source": [
"import warnings\n",
"import numpy as np\n",
"from dask.base import tokenize\n",
"from dask.highlevelgraph import HighLevelGraph\n",
"from dask import core\n",
"from dask.array.core import operator\n",
"from dask.distributed import Client\n",
"import dask.array as da"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:03:27.025398Z",
"start_time": "2020-02-29T09:03:26.197844Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"<table style=\"border: 2px solid white;\">\n",
"<tr>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3 style=\"text-align: left;\">Client</h3>\n",
"<ul style=\"text-align: left; list-style: none; margin: 0; padding: 0;\">\n",
" <li><b>Scheduler: </b>tcp://127.0.0.1:8079</li>\n",
" <li><b>Dashboard: </b><a href='http://127.0.0.1:8787/status' target='_blank'>http://127.0.0.1:8787/status</a>\n",
"</ul>\n",
"</td>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3 style=\"text-align: left;\">Cluster</h3>\n",
"<ul style=\"text-align: left; list-style:none; margin: 0; padding: 0;\">\n",
" <li><b>Workers: </b>1</li>\n",
" <li><b>Cores: </b>1</li>\n",
" <li><b>Memory: </b>134.78 GB</li>\n",
"</ul>\n",
"</td>\n",
"</tr>\n",
"</table>"
],
"text/plain": [
"<Client: 'tcp://127.0.0.1:8079' processes=1 threads=1, memory=134.78 GB>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"client = Client(n_workers=1, threads_per_worker=1, processes=False, scheduler_port=8079)\n",
"client"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:03:27.031242Z",
"start_time": "2020-02-29T09:03:27.027613Z"
}
},
"outputs": [],
"source": [
"def validate_result(x, y, z):\n",
" z = z.compute()\n",
" assert z.shape == z.shape\n",
" assert np.all(np.kron(x.compute(), y.compute()) == z)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Build inputs test test with:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:03:27.042852Z",
"start_time": "2020-02-29T09:03:27.034107Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tr>\n",
"<td>\n",
"<table>\n",
" <thead>\n",
" <tr><td> </td><th> Array </th><th> Chunk </th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th> Bytes </th><td> 288 B </td> <td> 48 B </td></tr>\n",
" <tr><th> Shape </th><td> (9, 4) </td> <td> (3, 2) </td></tr>\n",
" <tr><th> Count </th><td> 14 Tasks </td><td> 6 Chunks </td></tr>\n",
" <tr><th> Type </th><td> int64 </td><td> numpy.ndarray </td></tr>\n",
" </tbody>\n",
"</table>\n",
"</td>\n",
"<td>\n",
"<svg width=\"103\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n",
"\n",
" <!-- Horizontal lines -->\n",
" <line x1=\"0\" y1=\"0\" x2=\"53\" y2=\"0\" style=\"stroke-width:2\" />\n",
" <line x1=\"0\" y1=\"40\" x2=\"53\" y2=\"40\" />\n",
" <line x1=\"0\" y1=\"80\" x2=\"53\" y2=\"80\" />\n",
" <line x1=\"0\" y1=\"120\" x2=\"53\" y2=\"120\" style=\"stroke-width:2\" />\n",
"\n",
" <!-- Vertical lines -->\n",
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n",
" <line x1=\"26\" y1=\"0\" x2=\"26\" y2=\"120\" />\n",
" <line x1=\"53\" y1=\"0\" x2=\"53\" y2=\"120\" style=\"stroke-width:2\" />\n",
"\n",
" <!-- Colored Rectangle -->\n",
" <polygon points=\"0.000000,0.000000 53.333333,0.000000 53.333333,120.000000 0.000000,120.000000\" style=\"fill:#ECB172A0;stroke-width:0\"/>\n",
"\n",
" <!-- Text -->\n",
" <text x=\"26.666667\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >4</text>\n",
" <text x=\"73.333333\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(0,73.333333,60.000000)\">9</text>\n",
"</svg>\n",
"</td>\n",
"</tr>\n",
"</table>"
],
"text/plain": [
"dask.array<rechunk-merge, shape=(9, 4), dtype=int64, chunksize=(3, 2), chunktype=numpy.ndarray>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = da.arange(36).reshape(9, 4).rechunk((3, 2))\n",
"x"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:03:27.053869Z",
"start_time": "2020-02-29T09:03:27.046388Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tr>\n",
"<td>\n",
"<table>\n",
" <thead>\n",
" <tr><td> </td><th> Array </th><th> Chunk </th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th> Bytes </th><td> 288 B </td> <td> 48 B </td></tr>\n",
" <tr><th> Shape </th><td> (9, 4) </td> <td> (3, 2) </td></tr>\n",
" <tr><th> Count </th><td> 14 Tasks </td><td> 6 Chunks </td></tr>\n",
" <tr><th> Type </th><td> int64 </td><td> numpy.ndarray </td></tr>\n",
" </tbody>\n",
"</table>\n",
"</td>\n",
"<td>\n",
"<svg width=\"103\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n",
"\n",
" <!-- Horizontal lines -->\n",
" <line x1=\"0\" y1=\"0\" x2=\"53\" y2=\"0\" style=\"stroke-width:2\" />\n",
" <line x1=\"0\" y1=\"40\" x2=\"53\" y2=\"40\" />\n",
" <line x1=\"0\" y1=\"80\" x2=\"53\" y2=\"80\" />\n",
" <line x1=\"0\" y1=\"120\" x2=\"53\" y2=\"120\" style=\"stroke-width:2\" />\n",
"\n",
" <!-- Vertical lines -->\n",
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n",
" <line x1=\"26\" y1=\"0\" x2=\"26\" y2=\"120\" />\n",
" <line x1=\"53\" y1=\"0\" x2=\"53\" y2=\"120\" style=\"stroke-width:2\" />\n",
"\n",
" <!-- Colored Rectangle -->\n",
" <polygon points=\"0.000000,0.000000 53.333333,0.000000 53.333333,120.000000 0.000000,120.000000\" style=\"fill:#ECB172A0;stroke-width:0\"/>\n",
"\n",
" <!-- Text -->\n",
" <text x=\"26.666667\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >4</text>\n",
" <text x=\"73.333333\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(0,73.333333,60.000000)\">9</text>\n",
"</svg>\n",
"</td>\n",
"</tr>\n",
"</table>"
],
"text/plain": [
"dask.array<rechunk-merge, shape=(9, 4), dtype=int64, chunksize=(3, 2), chunktype=numpy.ndarray>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = da.ones(36, dtype=x.dtype).reshape(9, 4).rechunk((3, 2))\n",
"y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### V1 - Using ```da.blockwise```"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:03:27.847418Z",
"start_time": "2020-02-29T09:03:27.056573Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0, 0, 0, ..., 3, 3, 3],\n",
" [ 0, 0, 0, ..., 3, 3, 3],\n",
" [ 0, 0, 0, ..., 3, 3, 3],\n",
" ...,\n",
" [32, 32, 32, ..., 35, 35, 35],\n",
" [32, 32, 32, ..., 35, 35, 35],\n",
" [32, 32, 32, ..., 35, 35, 35]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def kron_v1(x, y):\n",
" # Rechunk left array to single data elements\n",
" x = x.rechunk((1, 1))\n",
" return da.blockwise(\n",
" np.multiply, 'ij', x, 'ij', y, 'xy', concatenate=True, dtype='f8',\n",
" adjust_chunks={'i': y.shape[0], 'j': y.shape[1]}\n",
" )\n",
"z = kron_v1(x, y)\n",
"validate_result(x, y, z)\n",
"z.compute()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### V2 - Using ```da.block```"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:03:29.816169Z",
"start_time": "2020-02-29T09:03:27.852893Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0, 0, 0, ..., 3, 3, 3],\n",
" [ 0, 0, 0, ..., 3, 3, 3],\n",
" [ 0, 0, 0, ..., 3, 3, 3],\n",
" ...,\n",
" [32, 32, 32, ..., 35, 35, 35],\n",
" [32, 32, 32, ..., 35, 35, 35],\n",
" [32, 32, 32, ..., 35, 35, 35]])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def kron_v2(x, y):\n",
" return da.block([\n",
" [x[i, j] * y for j in range(x.shape[1])]\n",
" for i in range(x.shape[0])\n",
" ])\n",
"z = kron_v2(x, y)\n",
"validate_result(x, y, z)\n",
"z.compute()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0, 0, 0, ..., 3, 3, 3],\n",
" [ 0, 0, 0, ..., 3, 3, 3],\n",
" [ 0, 0, 0, ..., 3, 3, 3],\n",
" ...,\n",
" [32, 32, 32, ..., 35, 35, 35],\n",
" [32, 32, 32, ..., 35, 35, 35],\n",
" [32, 32, 32, ..., 35, 35, 35]])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Validate with np array argument on left as well\n",
"z = kron_v2(x.compute(), y)\n",
"validate_result(x, y, z)\n",
"z.compute()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## V3 - Using custom graph\n",
"\n",
"The intent with this approach is to assemble a new graph representing a kronecker product based on compositions of task nodes in the existing graphs for the input arrays. "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:03:29.824881Z",
"start_time": "2020-02-29T09:03:29.819197Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 0, 0),\n",
" ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 0, 1),\n",
" ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 1, 0),\n",
" ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 1, 1),\n",
" ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 2, 0),\n",
" ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 2, 1),\n",
" ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 0),\n",
" ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 1),\n",
" ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 2),\n",
" ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 3),\n",
" ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 4),\n",
" ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 5),\n",
" ('reshape-126131d495cd2afae9f2085c3fb52292', 0, 0),\n",
" ('arange-e599d6c2e52a053c03cd96483fdb779d', 0)]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show all the tasks in the graph to create x\n",
"list(dict(x.__dask_graph__()))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:03:29.835190Z",
"start_time": "2020-02-29T09:03:29.828630Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 0, 0),\n",
" ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 1, 0),\n",
" ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 2, 0)]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show only the tasks that represent results\n",
"list(dict(x.__dask_keys__())) "
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:03:30.895878Z",
"start_time": "2020-02-29T09:03:29.837588Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0, 0, 0, ..., 3, 3, 3],\n",
" [ 0, 0, 0, ..., 3, 3, 3],\n",
" [ 0, 0, 0, ..., 3, 3, 3],\n",
" ...,\n",
" [32, 32, 32, ..., 35, 35, 35],\n",
" [32, 32, 32, ..., 35, 35, 35],\n",
" [32, 32, 32, ..., 35, 35, 35]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# See https://docs.dask.org/en/latest/array-design.html#example-eye-function for an example Array graph construction\n",
"\n",
"def kron_v3(x, y):\n",
" \"\"\" Kronecker Product \n",
" \n",
" Limitations:\n",
" - Assumes equally size chunks within each array\n",
" - Only works for 2D arrays\n",
" \"\"\"\n",
" chunks = (y.chunks[0] * x.shape[0], y.chunks[1] * x.shape[1])\n",
" \n",
" name = 'kron-' + tokenize(x.name, y.name)\n",
" \n",
" def get_chunk_index(i, j):\n",
" \"\"\" Determine index of array coordinates within a chunk \"\"\"\n",
" return (i % x.chunksize[0], j % x.chunksize[1])\n",
" \n",
" def get_chunk_coords(i, j):\n",
" \"\"\" Determine chunk coordinates for array indices \"\"\"\n",
" return (i // x.chunksize[0], j // x.chunksize[1])\n",
" \n",
" def array_idx_op(xi, xj):\n",
" \"\"\" Build task for selecting single element of a matrix provided scalar indices \"\"\"\n",
" chk_idx = get_chunk_index(xi, xj)\n",
" chk_crd = get_chunk_coords(xi, xj)\n",
" # Example selecting item in second column and row of first chunk:\n",
" # (operator.getitem, (x.name, 0, 0), (1, 1))\n",
" return (operator.getitem, (x.name, *chk_crd), chk_idx)\n",
" \n",
" def kron_key(xi, xj, yk):\n",
" \"\"\" Build chunk key for resulting block in product \"\"\"\n",
" # Result has as many blocks in one dimension as there are chunks\n",
" # in y times the length of x along that same dimension\n",
" return (name, xi*len(y.chunks[0]) + yk[1], xj*len(y.chunks[1]) + yk[2])\n",
" \n",
" layer = {\n",
" # Map kron product block to operation with copy of y times single element of x\n",
" # NOTE: It is crucial here that all operations do not refer to x and y directly \n",
" # as this simply nests the graphs for them within this one -- the construction\n",
" # here must instead only refer to keys within the graphs of x and y\n",
" kron_key(xi, xj, yk): (operator.mul, array_idx_op(xi, xj), yk) \n",
" for xi in range(x.shape[0])\n",
" for xj in range(x.shape[1])\n",
" for yk in core.flatten(y.__dask_keys__())\n",
" }\n",
" # Many of the task keys reference above don't exist in the layer dict just created, so it \n",
" # is crucial that x and y are provided here as dependencies so that the graphs can be merged\n",
" dsk = HighLevelGraph.from_collections(name, layer, dependencies=[x, y])\n",
" return da.Array(dsk, name, chunks, dtype=x.dtype)\n",
"\n",
"z = kron_v3(x, y)\n",
"validate_result(x, y, z)\n",
"z.compute()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Benchmarks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a little bigger arrays for benchmarking; note that two 50x20 arrays give a 2500x400 (1M elems) result so these should still remain small for local testing:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:04:43.416290Z",
"start_time": "2020-02-29T09:04:43.408815Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tr>\n",
"<td>\n",
"<table>\n",
" <thead>\n",
" <tr><td> </td><th> Array </th><th> Chunk </th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th> Bytes </th><td> 8.00 kB </td> <td> 800 B </td></tr>\n",
" <tr><th> Shape </th><td> (50, 20) </td> <td> (25, 4) </td></tr>\n",
" <tr><th> Count </th><td> 22 Tasks </td><td> 10 Chunks </td></tr>\n",
" <tr><th> Type </th><td> int64 </td><td> numpy.ndarray </td></tr>\n",
" </tbody>\n",
"</table>\n",
"</td>\n",
"<td>\n",
"<svg width=\"98\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n",
"\n",
" <!-- Horizontal lines -->\n",
" <line x1=\"0\" y1=\"0\" x2=\"48\" y2=\"0\" style=\"stroke-width:2\" />\n",
" <line x1=\"0\" y1=\"60\" x2=\"48\" y2=\"60\" />\n",
" <line x1=\"0\" y1=\"120\" x2=\"48\" y2=\"120\" style=\"stroke-width:2\" />\n",
"\n",
" <!-- Vertical lines -->\n",
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n",
" <line x1=\"9\" y1=\"0\" x2=\"9\" y2=\"120\" />\n",
" <line x1=\"19\" y1=\"0\" x2=\"19\" y2=\"120\" />\n",
" <line x1=\"28\" y1=\"0\" x2=\"28\" y2=\"120\" />\n",
" <line x1=\"38\" y1=\"0\" x2=\"38\" y2=\"120\" />\n",
" <line x1=\"48\" y1=\"0\" x2=\"48\" y2=\"120\" style=\"stroke-width:2\" />\n",
"\n",
" <!-- Colored Rectangle -->\n",
" <polygon points=\"0.000000,0.000000 48.000000,0.000000 48.000000,120.000000 0.000000,120.000000\" style=\"fill:#ECB172A0;stroke-width:0\"/>\n",
"\n",
" <!-- Text -->\n",
" <text x=\"24.000000\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >20</text>\n",
" <text x=\"68.000000\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(0,68.000000,60.000000)\">50</text>\n",
"</svg>\n",
"</td>\n",
"</tr>\n",
"</table>"
],
"text/plain": [
"dask.array<rechunk-merge, shape=(50, 20), dtype=int64, chunksize=(25, 4), chunktype=numpy.ndarray>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xl = da.arange(50*20).reshape(50, 20).rechunk((25, 4))\n",
"xl"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:04:44.025255Z",
"start_time": "2020-02-29T09:04:44.017163Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<tr>\n",
"<td>\n",
"<table>\n",
" <thead>\n",
" <tr><td> </td><th> Array </th><th> Chunk </th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th> Bytes </th><td> 8.00 kB </td> <td> 800 B </td></tr>\n",
" <tr><th> Shape </th><td> (50, 20) </td> <td> (25, 4) </td></tr>\n",
" <tr><th> Count </th><td> 22 Tasks </td><td> 10 Chunks </td></tr>\n",
" <tr><th> Type </th><td> int64 </td><td> numpy.ndarray </td></tr>\n",
" </tbody>\n",
"</table>\n",
"</td>\n",
"<td>\n",
"<svg width=\"98\" height=\"170\" style=\"stroke:rgb(0,0,0);stroke-width:1\" >\n",
"\n",
" <!-- Horizontal lines -->\n",
" <line x1=\"0\" y1=\"0\" x2=\"48\" y2=\"0\" style=\"stroke-width:2\" />\n",
" <line x1=\"0\" y1=\"60\" x2=\"48\" y2=\"60\" />\n",
" <line x1=\"0\" y1=\"120\" x2=\"48\" y2=\"120\" style=\"stroke-width:2\" />\n",
"\n",
" <!-- Vertical lines -->\n",
" <line x1=\"0\" y1=\"0\" x2=\"0\" y2=\"120\" style=\"stroke-width:2\" />\n",
" <line x1=\"9\" y1=\"0\" x2=\"9\" y2=\"120\" />\n",
" <line x1=\"19\" y1=\"0\" x2=\"19\" y2=\"120\" />\n",
" <line x1=\"28\" y1=\"0\" x2=\"28\" y2=\"120\" />\n",
" <line x1=\"38\" y1=\"0\" x2=\"38\" y2=\"120\" />\n",
" <line x1=\"48\" y1=\"0\" x2=\"48\" y2=\"120\" style=\"stroke-width:2\" />\n",
"\n",
" <!-- Colored Rectangle -->\n",
" <polygon points=\"0.000000,0.000000 48.000000,0.000000 48.000000,120.000000 0.000000,120.000000\" style=\"fill:#ECB172A0;stroke-width:0\"/>\n",
"\n",
" <!-- Text -->\n",
" <text x=\"24.000000\" y=\"140.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" >20</text>\n",
" <text x=\"68.000000\" y=\"60.000000\" font-size=\"1.0rem\" font-weight=\"100\" text-anchor=\"middle\" transform=\"rotate(0,68.000000,60.000000)\">50</text>\n",
"</svg>\n",
"</td>\n",
"</tr>\n",
"</table>"
],
"text/plain": [
"dask.array<rechunk-merge, shape=(50, 20), dtype=int64, chunksize=(25, 4), chunktype=numpy.ndarray>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"yl = da.ones(xl.size, dtype=xl.dtype).reshape(xl.shape).rechunk(xl.chunksize)\n",
"yl"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:06:03.973336Z",
"start_time": "2020-02-29T09:04:50.646935Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8.47 s ± 67.9 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)\n"
]
}
],
"source": [
"%%timeit -n 3 -r 3\n",
"with warnings.catch_warnings():\n",
" # Ignore \"PerformanceWarning: Increasing number of chunks by factor of X\"; \n",
" # we know the rechunking to (1,1) will do this\n",
" warnings.simplefilter(action='ignore', category=da.PerformanceWarning)\n",
" kron_v1(xl, yl).compute()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:12:20.636237Z",
"start_time": "2020-02-29T09:06:03.978754Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"29.1 s ± 266 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)\n"
]
}
],
"source": [
"%%timeit -n 3 -r 3\n",
"kron_v2(xl, yl).compute()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:16:04.060013Z",
"start_time": "2020-02-29T09:12:20.639577Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"18.1 s ± 217 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)\n"
]
}
],
"source": [
"%%timeit -n 3 -r 3\n",
"kron_v3(xl, yl).compute()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:18:25.990804Z",
"start_time": "2020-02-29T09:18:24.858101Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"151 ms ± 8.67 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)\n"
]
}
],
"source": [
"%%timeit -n 3 -r 3\n",
"np.kron(xl.compute(), yl.compute())"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"ExecuteTime": {
"end_time": "2020-02-29T09:25:54.659284Z",
"start_time": "2020-02-29T09:25:53.966050Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"103 ms ± 6.34 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)\n"
]
}
],
"source": [
"%%timeit -n 3 -r 3\n",
"# Compare to elementwise multiplication of random matrices having shape \n",
"# equal to kronecker product output shape (same number of multiplications)\n",
"xr1 = da.random.normal(size=(xl.shape[0]**2, xl.shape[1]**2))\n",
"xr2 = da.random.normal(size=(xl.shape[0]**2, xl.shape[1]**2))\n",
"(xr1 * xr2).compute()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"57.9 s ± 1.53 s per loop (mean ± std. dev. of 3 runs, 3 loops each)\n"
]
}
],
"source": [
"# Also test the v2 version with a numpy input on the left --\n",
"# This is almost 2x slower somehow?\n",
"xla = xl.compute()\n",
"%timeit -n 3 -r 3 kron_v2(xla, yl).compute()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Conclusion**: 1x1 rechunking + ```da.blockwise``` performs better than the others, but all of the methods are extremely slow considering that an elementwise multiplication involving the same number of FLOPs takes 1% as much time (.1/8.47 ~= .01). "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment