Skip to content

Instantly share code, notes, and snippets.

@crusaderky
Created November 1, 2023 13:57
Show Gist options
  • Save crusaderky/3e11fd4be8b61d06109a01781dda9c83 to your computer and use it in GitHub Desktop.
Save crusaderky/3e11fd4be8b61d06109a01781dda9c83 to your computer and use it in GitHub Desktop.
dask/distributed#8318 - Speed up network transfer of small buffers
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "b504f018-cd4f-40b7-a6e6-b50d8e15c588",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import pickle\n",
"import numpy\n",
"from dask.sizeof import sizeof\n",
"from dask.utils import format_bytes\n",
"from distributed.comm import connect\n",
"from distributed.comm.tcp import TCPListener\n",
"from distributed.protocol.serialize import to_serialize\n",
"from distributed.shuffle._buffer import _List\n",
"\n",
"\n",
"# Serialization strategies under test, keyed by the ``serialize`` mode.\n",
"# Each value is an (encode, decode) pair applied around the comm layer:\n",
"# - None: the payload is sent unchanged\n",
"# - \"whole\": the entire payload is pickled into a single bytes object\n",
"# - \"elements\": each item is pickled separately and the container type\n",
"#   (list or _List in this notebook) is preserved\n",
"serializers = {\n",
"    None: (lambda obj: obj, lambda obj: obj),\n",
"    \"whole\": (pickle.dumps, pickle.loads),\n",
"    \"elements\": (\n",
"        lambda seq: type(seq)(pickle.dumps(item) for item in seq),\n",
"        lambda seq: type(seq)(pickle.loads(item) for item in seq),\n",
"    ),\n",
"}\n",
"\n",
"\n",
"async def handle_comm(comm):\n",
"    \"\"\"Connection handler passed to ``TCPListener`` below; echoes payloads back.\n",
"\n",
"    Replies to every message with a ``\"pong\"`` carrying the same payload until\n",
"    an ``{\"op\": \"close\"}`` message arrives, then closes ``comm`` and returns.\n",
"    The payload is decoded and re-encoded with the pair selected by\n",
"    ``msg[\"serialize\"]`` so the server pays the same (de)serialization cost\n",
"    as the client.\n",
"    \"\"\"\n",
"    while True:\n",
"        msg = await comm.read()\n",
"        if msg[\"op\"] == \"close\":\n",
"            await comm.close()\n",
"            return\n",
"\n",
"        encode, decode = serializers[msg[\"serialize\"]]\n",
"        reply_data = to_serialize(encode(decode(msg[\"data\"])))\n",
"        await comm.write({\"op\": \"pong\", \"data\": reply_data})\n",
"\n",
"\n",
"# Serve on loopback and open a client connection to it.\n",
"# NOTE(review): no port is given; presumably the OS picks a free one -- confirm.\n",
"listener = await TCPListener(\"127.0.0.1\", handle_comm)\n",
"comm = await connect(listener.contact_address)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "91d5923e-7cba-4fad-aca7-f5be020c24d3",
"metadata": {},
"outputs": [],
"source": [
"async def bench(label, payload, serialize, n_roundtrips=500):\n",
"    \"\"\"Measure round-trip bandwidth of ``payload`` through the echo server.\n",
"\n",
"    Sends ``payload`` over the global ``comm`` ``n_roundtrips`` times, encoding\n",
"    it with ``serializers[serialize]`` on the way out and decoding the reply on\n",
"    the way back. Prints and returns the bandwidth in MiB/s, counting the\n",
"    payload twice per round trip (once in each direction).\n",
"\n",
"    Raises ValueError if ``n_roundtrips`` is less than 1.\n",
"    \"\"\"\n",
"    if n_roundtrips < 1:\n",
"        raise ValueError(f\"n_roundtrips must be >= 1, got {n_roundtrips}\")\n",
"\n",
"    nbytes = sum(map(sizeof, payload))\n",
"    encode, decode = serializers[serialize]  # loop-invariant: look up once\n",
"\n",
"    # perf_counter is monotonic and high-resolution, unlike time.time\n",
"    t0 = time.perf_counter()\n",
"    for _ in range(n_roundtrips):\n",
"        data = to_serialize(encode(payload))\n",
"        await comm.write({\"op\": \"ping\", \"serialize\": serialize, \"data\": data})\n",
"        msg = await comm.read()\n",
"        assert msg[\"op\"] == \"pong\"\n",
"        data = decode(msg[\"data\"])\n",
"    elapsed = time.perf_counter() - t0\n",
"\n",
"    bandwidth = (2 * n_roundtrips * nbytes) / elapsed / 2**20\n",
"    # Pad in the format spec so the columns line up however the label is padded.\n",
"    print(f\"{label:<14} {bandwidth:4.0f} MiB/s\")\n",
"\n",
"    # Sanity check: the last round trip returned the original arrays.\n",
"    assert len(payload) == len(data)\n",
"    assert all((a == b).all() for a, b in zip(payload, data))\n",
"    return bandwidth\n",
"\n",
"async def bench_suite(shard_size, max_message_size=2 * 2**20):\n",
"    \"\"\"Run every serialization strategy for a single shard size.\n",
"\n",
"    The payload is a list of float64 arrays of ``shard_size`` bytes each, with\n",
"    as many shards as fit in ``max_message_size`` bytes (at least one, so a\n",
"    shard larger than the limit is sent on its own).\n",
"\n",
"    Returns a table row ``[shard_size, bw_1, ..., bw_5]`` of bandwidths in\n",
"    MiB/s, in the same order as the benchmarks below.\n",
"    \"\"\"\n",
"    n_shards = max(1, max_message_size // shard_size)\n",
"    # Generator.random yields float64 (8 bytes), so shard_size // 8 elements\n",
"    # make an array of shard_size bytes.\n",
"    rng = numpy.random.default_rng(0)  # fixed seed: reproducible payloads\n",
"    payload = [rng.random(shard_size // 8) for _ in range(n_shards)]\n",
"    print(\"=\" * 80)\n",
"    print(format_bytes(sum(map(sizeof, payload))), \"payload;\", n_shards, \"x\", format_bytes(shard_size), \"shards\")\n",
"\n",
"    # Labels are padded to a common width of 14 so the printed columns line up.\n",
"    return [\n",
"        shard_size,\n",
"        await bench(\"list[ndarray] \", payload, None),\n",
"        await bench(\"_List[ndarray]\", _List(payload), None),\n",
"        await bench(\"bytes         \", payload, \"whole\"),\n",
"        await bench(\"list[bytes]   \", payload, \"elements\"),\n",
"        await bench(\"_List[bytes]  \", _List(payload), \"elements\"),\n",
"    ]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "083f74ec-abc6-46e3-a9c8-6fdafef28ed3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"================================================================================\n",
"2.00 MiB payload; 1024 x 2.00 kiB shards\n",
"list[ndarray] 51 MiB/s\n",
"_List[ndarray] 70 MiB/s\n",
"bytes 343 MiB/s\n",
"list[bytes] 57 MiB/s\n",
"_List[bytes] 157 MiB/s\n",
"================================================================================\n",
"2.00 MiB payload; 512 x 4.00 kiB shards\n",
"list[ndarray] 100 MiB/s\n",
"_List[ndarray] 135 MiB/s\n",
"bytes 598 MiB/s\n",
"list[bytes] 109 MiB/s\n",
"_List[bytes] 258 MiB/s\n",
"================================================================================\n",
"2.00 MiB payload; 256 x 8.00 kiB shards\n",
"list[ndarray] 191 MiB/s\n",
"_List[ndarray] 264 MiB/s\n",
"bytes 894 MiB/s\n",
"list[bytes] 200 MiB/s\n",
"_List[bytes] 417 MiB/s\n",
"================================================================================\n",
"2.00 MiB payload; 128 x 16.00 kiB shards\n",
"list[ndarray] 371 MiB/s\n",
"_List[ndarray] 493 MiB/s\n",
"bytes 1101 MiB/s\n",
"list[bytes] 359 MiB/s\n",
"_List[bytes] 577 MiB/s\n",
"================================================================================\n",
"2.00 MiB payload; 64 x 32.00 kiB shards\n",
"list[ndarray] 670 MiB/s\n",
"_List[ndarray] 851 MiB/s\n",
"bytes 819 MiB/s\n",
"list[bytes] 584 MiB/s\n",
"_List[bytes] 746 MiB/s\n",
"================================================================================\n",
"2.00 MiB payload; 32 x 64.00 kiB shards\n",
"list[ndarray] 1067 MiB/s\n",
"_List[ndarray] 1282 MiB/s\n",
"bytes 1339 MiB/s\n",
"list[bytes] 834 MiB/s\n",
"_List[bytes] 859 MiB/s\n",
"================================================================================\n",
"2.00 MiB payload; 16 x 128.00 kiB shards\n",
"list[ndarray] 1559 MiB/s\n",
"_List[ndarray] 1777 MiB/s\n",
"bytes 1414 MiB/s\n",
"list[bytes] 1073 MiB/s\n",
"_List[bytes] 872 MiB/s\n",
"================================================================================\n",
"2.00 MiB payload; 8 x 256.00 kiB shards\n",
"list[ndarray] 2061 MiB/s\n",
"_List[ndarray] 2190 MiB/s\n",
"bytes 1452 MiB/s\n",
"list[bytes] 1313 MiB/s\n",
"_List[bytes] 952 MiB/s\n",
"================================================================================\n",
"2.00 MiB payload; 4 x 512.00 kiB shards\n",
"list[ndarray] 2508 MiB/s\n",
"_List[ndarray] 2563 MiB/s\n",
"bytes 1589 MiB/s\n",
"list[bytes] 1403 MiB/s\n",
"_List[bytes] 931 MiB/s\n",
"================================================================================\n",
"2.00 MiB payload; 2 x 1.00 MiB shards\n",
"list[ndarray] 2815 MiB/s\n",
"_List[ndarray] 2810 MiB/s\n",
"bytes 1487 MiB/s\n",
"list[bytes] 1549 MiB/s\n",
"_List[bytes] 924 MiB/s\n",
"================================================================================\n",
"2.00 MiB payload; 1 x 2.00 MiB shards\n",
"list[ndarray] 2953 MiB/s\n",
"_List[ndarray] 2863 MiB/s\n",
"bytes 1564 MiB/s\n",
"list[bytes] 1664 MiB/s\n",
"_List[bytes] 1012 MiB/s\n",
"================================================================================\n",
"4.00 MiB payload; 1 x 4.00 MiB shards\n",
"list[ndarray] 3232 MiB/s\n",
"_List[ndarray] 3178 MiB/s\n",
"bytes 937 MiB/s\n",
"list[bytes] 943 MiB/s\n",
"_List[bytes] 600 MiB/s\n",
"================================================================================\n",
"8.00 MiB payload; 1 x 8.00 MiB shards\n",
"list[ndarray] 2997 MiB/s\n",
"_List[ndarray] 2945 MiB/s\n",
"bytes 671 MiB/s\n",
"list[bytes] 658 MiB/s\n",
"_List[bytes] 439 MiB/s\n",
"================================================================================\n",
"16.00 MiB payload; 1 x 16.00 MiB shards\n",
"list[ndarray] 3207 MiB/s\n",
"_List[ndarray] 3129 MiB/s\n",
"bytes 660 MiB/s\n",
"list[bytes] 660 MiB/s\n",
"_List[bytes] 433 MiB/s\n"
]
}
],
"source": [
"# One benchmark suite per shard size, doubling from 2 kiB (2**11) to 16 MiB (2**24).\n",
"shard_sizes = [2**exponent for exponent in range(11, 25)]\n",
"rows = []\n",
"for shard_size in shard_sizes:\n",
"    rows.append(await bench_suite(shard_size))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0c4ae109-a76c-4e67-80f2-5385f2da55cd",
"metadata": {},
"outputs": [],
"source": [
"# Tell the server-side handler to close its end, then close ours.\n",
"await comm.write({\"op\": \"close\"})\n",
"await comm.close()\n",
"# Stop listening for new connections so the listener does not outlive the benchmark.\n",
"listener.stop()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "5cec89e5-8793-4c5c-889c-4657e01ad16e",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>list[ndarray]</th>\n",
" <th>opaque list[ndarray]</th>\n",
" <th>bytes</th>\n",
" <th>list[bytes]</th>\n",
" <th>opaque list[bytes]</th>\n",
" </tr>\n",
" <tr>\n",
" <th>shard size (kiB)</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>50.903776</td>\n",
" <td>69.728695</td>\n",
" <td>343.086332</td>\n",
" <td>57.066776</td>\n",
" <td>157.222938</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>100.316773</td>\n",
" <td>134.653209</td>\n",
" <td>598.372030</td>\n",
" <td>109.309751</td>\n",
" <td>258.107037</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>190.839629</td>\n",
" <td>263.591943</td>\n",
" <td>894.373489</td>\n",
" <td>200.283609</td>\n",
" <td>417.440357</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>371.209654</td>\n",
" <td>493.385479</td>\n",
" <td>1100.924116</td>\n",
" <td>358.662300</td>\n",
" <td>576.616526</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32</th>\n",
" <td>670.387683</td>\n",
" <td>851.325741</td>\n",
" <td>818.691276</td>\n",
" <td>583.892921</td>\n",
" <td>746.006159</td>\n",
" </tr>\n",
" <tr>\n",
" <th>64</th>\n",
" <td>1067.119197</td>\n",
" <td>1281.562256</td>\n",
" <td>1338.790461</td>\n",
" <td>833.731683</td>\n",
" <td>858.750140</td>\n",
" </tr>\n",
" <tr>\n",
" <th>128</th>\n",
" <td>1558.516878</td>\n",
" <td>1777.398838</td>\n",
" <td>1413.686997</td>\n",
" <td>1073.416605</td>\n",
" <td>872.072096</td>\n",
" </tr>\n",
" <tr>\n",
" <th>256</th>\n",
" <td>2060.841518</td>\n",
" <td>2190.007782</td>\n",
" <td>1451.755557</td>\n",
" <td>1313.176872</td>\n",
" <td>952.193650</td>\n",
" </tr>\n",
" <tr>\n",
" <th>512</th>\n",
" <td>2508.483760</td>\n",
" <td>2562.543932</td>\n",
" <td>1588.998893</td>\n",
" <td>1402.692111</td>\n",
" <td>930.712748</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1024</th>\n",
" <td>2814.930399</td>\n",
" <td>2809.793638</td>\n",
" <td>1486.866258</td>\n",
" <td>1548.600527</td>\n",
" <td>924.047949</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2048</th>\n",
" <td>2952.556268</td>\n",
" <td>2862.814638</td>\n",
" <td>1564.160716</td>\n",
" <td>1664.199645</td>\n",
" <td>1012.460643</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4096</th>\n",
" <td>3232.150866</td>\n",
" <td>3177.575850</td>\n",
" <td>936.658378</td>\n",
" <td>943.470656</td>\n",
" <td>599.556940</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8192</th>\n",
" <td>2996.758480</td>\n",
" <td>2945.381580</td>\n",
" <td>670.755918</td>\n",
" <td>658.465733</td>\n",
" <td>438.512926</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16384</th>\n",
" <td>3206.737483</td>\n",
" <td>3128.812223</td>\n",
" <td>659.806899</td>\n",
" <td>659.605696</td>\n",
" <td>433.148824</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" list[ndarray] opaque list[ndarray] bytes \\\n",
"shard size (kiB) \n",
"2 50.903776 69.728695 343.086332 \n",
"4 100.316773 134.653209 598.372030 \n",
"8 190.839629 263.591943 894.373489 \n",
"16 371.209654 493.385479 1100.924116 \n",
"32 670.387683 851.325741 818.691276 \n",
"64 1067.119197 1281.562256 1338.790461 \n",
"128 1558.516878 1777.398838 1413.686997 \n",
"256 2060.841518 2190.007782 1451.755557 \n",
"512 2508.483760 2562.543932 1588.998893 \n",
"1024 2814.930399 2809.793638 1486.866258 \n",
"2048 2952.556268 2862.814638 1564.160716 \n",
"4096 3232.150866 3177.575850 936.658378 \n",
"8192 2996.758480 2945.381580 670.755918 \n",
"16384 3206.737483 3128.812223 659.806899 \n",
"\n",
" list[bytes] opaque list[bytes] \n",
"shard size (kiB) \n",
"2 57.066776 157.222938 \n",
"4 109.309751 258.107037 \n",
"8 200.283609 417.440357 \n",
"16 358.662300 576.616526 \n",
"32 583.892921 746.006159 \n",
"64 833.731683 858.750140 \n",
"128 1073.416605 872.072096 \n",
"256 1313.176872 952.193650 \n",
"512 1402.692111 930.712748 \n",
"1024 1548.600527 924.047949 \n",
"2048 1664.199645 1012.460643 \n",
"4096 943.470656 599.556940 \n",
"8192 658.465733 438.512926 \n",
"16384 659.605696 433.148824 "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas\n",
"\n",
"# One row per shard size; the first column is the shard size in bytes.\n",
"# The \"opaque list\" columns correspond to the _List runs in bench_suite.\n",
"df = (\n",
"    pandas.DataFrame(\n",
"        rows,\n",
"        columns=[\n",
"            \"shard size (kiB)\",\n",
"            \"list[ndarray]\",\n",
"            \"opaque list[ndarray]\",\n",
"            \"bytes\",\n",
"            \"list[bytes]\",\n",
"            \"opaque list[bytes]\",\n",
"        ],\n",
"    )\n",
"    # bytes -> kiB, stored as strings (presumably so the plot spaces the\n",
"    # doubling sizes evenly as category labels)\n",
"    .assign(**{\"shard size (kiB)\": lambda t: (t[\"shard size (kiB)\"] // 1024).astype(str)})\n",
"    .set_index(\"shard size (kiB)\")\n",
")\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "492eb846-7ba4-4496-b3e5-f34991b64120",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Axes: xlabel='shard size (kiB)', ylabel='Bandwidth (MiB/s)'>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"df.plot(grid=True, ylabel=\"Bandwidth (MiB/s)\", title=\"Loopback TCP round-trip bandwidth by shard size\");"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa0f5f20-326a-4d70-9919-15e1532a64f9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment