{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "a0454883-d88d-4997-bafe-5abd4daa9134",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import dask\n",
"import dask.dataframe as dd\n",
"from dask.utils_test import hlg_layer\n",
"import operator"
]
},
{
"cell_type": "markdown",
"id": "6cf6c55b-d851-4db2-b1d8-a5c1eca1a701",
"metadata": {},
"source": [
"# Dask Design Proposal: General support for predicate pushdown\n",
"\n",
"**Author**: Rick Zamora (Last update: February 14th, 2022)"
]
},
{
"cell_type": "markdown",
"id": "1c3ee9ae-1cb0-44e9-8cfa-92025c238b0e",
"metadata": {
"tags": []
},
"source": [
"The purpose of the proposed changes is to enable certain `HighLevelGraph` (HLG) layers to store and utilize the necessary information for the \"pushdown\" of general filter expressions. HLG Layers are currently used to define the necessary task graph required for a Dask collection. However, information only flows in one direction between the collection API and the graph (that is: it flows from the collection API to the graph). This limitation is problematic when it comes to high-level optimizations like \"predicate pushdown,\" because the application of filters at the root IO layer can modify the initial partition count and index (`divisions`). In other words, predicate pushdown requires the entire graph to be completely regenerated.\n",
"\n",
"Given that the entire graph needs to be regenerated when a filter is pushed down into the IO layer, it seems likely that the best time to apply the optimization is as soon as a `getitem[Series]` operation is executed in the Dask-Dataframe API. For example [dask/dask#8633](https://github.com/dask/dask/pull/8633) shows how this approach can be used with the existing `Layer`/`Blockwise` implementations to support simple filters like `ddf[ddf[\"b\"] < 10]`. The primary purpose of this proposal is to simplify and generalize the brute-force logic currently used in that PR.\n",
"\n",
"In order to capture arbitrary filter expressions, I propose three seperate (but related) changes to Dask:\n",
"\n",
"- **Proposal 1**: Enable collection/graph regeneration\n",
" - Add new attributes to `Layer`/`Blockwise` to enable collection regeneration\n",
" - Expand/formalize the usage of `creation_info` in `Blockwise`\n",
"- **Proposal 2** Expand dispatching machinery and enable filter-expression extraction\n",
" - Add `on_literal=True` option to `Dispatch` to enable dispatching by the literal value of a function argument (rather than the type). This change allows both Dask a down-stream libraries to register filter expressions for specific callable objects (like `operator.gt`)\n",
" - Add new method to `Layer`/`Blockwise` to enable recursive filter-expression extraction\n",
"- **Poposal 3** Implement eager predicate pushdown\n",
" - Use the above changes to push down general filter operations to the IO layer when a `getitem[Series]` operation is executed in the Dask-Dataframe API"
]
},
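{
"cell_type": "markdown",
"id": "added-a1-target-pattern",
"metadata": {},
"source": [
"For concreteness, the cell below sketches the kind of query this proposal targets, next to the manually \"pushed down\" equivalent that eager optimization should produce automatically. This is an illustration only (`\"./tmpdir\"` is the same toy Parquet dataset used throughout this notebook); note that the DNF `filters` argument may prune data at coarse (e.g. row-group) granularity, so the original mask is still applied on top for exact results."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-a2-target-pattern-code",
"metadata": {},
"outputs": [],
"source": [
"# The pattern this proposal targets: a filter applied just after IO\n",
"ddf_plain = dd.read_parquet(\"./tmpdir\")\n",
"filtered = ddf_plain[ddf_plain[\"b\"] < 10]\n",
"\n",
"# The collection that eager pushdown should regenerate. The DNF\n",
"# `filters` argument lets read_parquet skip data early, while the\n",
"# original mask still guarantees exact results.\n",
"pushed = dd.read_parquet(\"./tmpdir\", filters=[(\"b\", \"<\", 10)])\n",
"pushed = pushed[pushed[\"b\"] < 10]"
]
},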
{
"cell_type": "markdown",
"id": "b5daa9cb-f6b2-4428-84bc-ea9d5abc20ae",
"metadata": {},
"source": [
"## Proposal 1: Enable Collection Regeneration\n",
"\n",
"In order to enable predicate pushdown for general filter operations, it is clear that we need a simple but reliable mechanism for graph regeneration. Since the original graph is typically produced using a collection API, the simplest way to regenerate a graph is to \"replay\" the same collection-API functin calls. One straightforward way to accomplish this is to (optionally) define a `creation_info` attribute at the `Layer` level, storing the original function and arguments used to generate the Dask collection attached to that Layer. Note that this attribute is already defined within the `DataFrameIOLayer` implementation, but has yet to be leveraged for graph optimization purposes. This document proposes that we expand the usage of `creation_info`, by defining it for a variety of simple filtering-related `Blockwise` layers (e.g. `getitem`, and `gt`), and that the following attributes be added to `HighLevelGraph.Layer`:\n",
"\n",
"- `Layer._regenerable` (a property)\n",
" - **description**:\n",
" - This propery tells whether or not the layer is capable of regenerating itself by regenerating the corresponding Dask collection. If `True`, the `_regenerate_collection` method must be implemented.\n",
" - **Default behavior**:\n",
" - `return False`\n",
"- `Layer._regenerate_collection` (an abstract method)\n",
" - **description**:\n",
" - If `_regenerable` is `True`, this method should regenerate and return the appropriate Dask collection for the layer.\n",
" - **Inputs**:\n",
" - `dsk` (`HighLevelGraph`): The full graph (layers and dependencies) for the original collection.\n",
" - `new_kwargs` (`dict[dict]`); Optional: New (non-default) key-word argument to use for the regeneration of each Layer. Keys correspond to original Layer names, and the values correspond to a dictionary of key-word arguments to be used to update default values. For example, if Layer `\"read-parquet-86753\"` was originally generated with `read_parquet(...,filters=None,columns=[\"a\"])`, then `new_kwargs={\"filters\": [(\"a\", \"<\", 10)]}` will update the `filters` argument to `read_parquet`, but will leave the original `columns` argument.\n",
" - `_regen_cache` (`dict`): A (private) regeneration cache to be used in recursive calls to avoid redundant collection generation.\n",
" - **Default behavior**:\n",
" - `raise NotImplementedError`\n",
"\n",
"We define the `_regenerable` property in the base `Layer` class (rather than in `Blockwise`) so that we can ask **any** layer if it is \"regenerable\". Similarly, we define `_regenerate_collection` in `Layer` because non-`Blockwise` classes may want to implement this feature in the future.\n",
"\n",
"**NOTE**: Another alternative to this `Layer._regenerable` approach is to define a distinct class for \"regenerable\" layers. However, that approach will likely lead to a messy class-inheritance structure if/when non-`Blockwise` layers begin adopting collection-regeneration features. Therefore, my clear preference is to avoid something like a `RegenerableLayer` class."
]
},
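{
"cell_type": "markdown",
"id": "added-b1-creation-info-note",
"metadata": {},
"source": [
"To make the `creation_info` convention concrete, the cell below sketches what the attribute might contain for a layer produced by `read_parquet`. The exact schema shown here is hypothetical and open for discussion, but it mirrors the `\"args\"`/`\"kwargs\"` layout already used by `DataFrameIOLayer`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-b2-creation-info-code",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical `creation_info` contents for a layer produced by\n",
"# `dd.read_parquet(\"./tmpdir\", columns=[\"a\", \"b\"])`. Positional\n",
"# arguments live under \"args\" and keyword arguments under \"kwargs\",\n",
"# so the layer can \"replay\" the original collection-API call.\n",
"example_creation_info = {\n",
"    \"func\": dd.read_parquet,\n",
"    \"args\": [\"./tmpdir\"],\n",
"    \"kwargs\": {\"columns\": [\"a\", \"b\"], \"filters\": None},\n",
"}"
]
},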
{
"cell_type": "markdown",
"id": "c7976877-94d4-45b8-a362-39d4fb213372",
"metadata": {
"tags": []
},
"source": [
"#### Sample Implementation"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b0a7b9d9-276a-4f1d-82ad-d275c48df94f",
"metadata": {},
"outputs": [],
"source": [
"def _regenerable(self):\n",
" \"\"\"Whether this layer supports ``_regenerate_collection``\n",
"\n",
" If True, this layer can regenerate a Dask collection\n",
" when provided the required collection inputs.\n",
" \"\"\"\n",
" return False\n",
"\n",
"def _regenerate_collection(\n",
" self,\n",
" dsk: dask.highlevelgraph.HighLevelGraph,\n",
" new_kwargs: dict = None,\n",
" _regen_cache: dict = None,\n",
"):\n",
" \"\"\"Regenerate a Dask collection for this layer using the\n",
" provided inputs and key-word arguments\n",
" \"\"\"\n",
" if self._regenerable:\n",
" raise NotImplementedError\n",
" raise ValueError(\n",
" \"`_regenerate_collection` requires `_regenerable=True`\"\n",
" )\n",
"\n",
"# Monkey patch\n",
"setattr(dask.highlevelgraph.Layer, '_regenerable', property(_regenerable))\n",
"dask.highlevelgraph.Layer._regenerate_collection = _regenerate_collection"
]
},
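{
"cell_type": "markdown",
"id": "added-c1-default-behavior-note",
"metadata": {},
"source": [
"A quick sanity check of the default behavior (a sketch; `MaterializedLayer` is just a convenient non-`Blockwise` stand-in): a generic layer is not regenerable, and asking it to regenerate raises a `ValueError`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-c2-default-behavior-code",
"metadata": {},
"outputs": [],
"source": [
"# Default behavior: a generic Layer is not regenerable\n",
"lay = dask.highlevelgraph.MaterializedLayer({})\n",
"assert lay._regenerable is False\n",
"\n",
"try:\n",
"    lay._regenerate_collection(None)\n",
"except ValueError as e:\n",
"    print(e)"
]
},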
{
"cell_type": "markdown",
"id": "1f9c78d8-fc2c-4ce1-8e27-e85c089cd52e",
"metadata": {},
"source": [
"### Related Blockwise Changes\n",
"\n",
"Since most IO and filtering operations in Dask-Dataframe consist entirely of `Blockwise`-based HLG layers, effective predicate pushdown clearly depends on the implementation of `_regenerable` and `_regenerate_collection` in `Blockwise`. Assuming that we adopt the `creation_info` attribute in `Blockwise` more generally, the necessary logic becomes quite simple:\n",
"\n",
"- `Blockwise._regenerable`: Return `True` if (and only if) a \"proper\" `creation_info` attribute is defined.\n",
"- `Blockwise._regenerate_collection`: Use `creation_info` to regenerate the Dask collection for the current `Blockwise` layer recursively. Since `Blockwise` keeps track of input collections and literals using an `indices` attribute, it is relatively straightforward for the layer to call `_regenerate_collection` on the required input layers before replaying its own `creation_info` function call. \n",
" "
]
},
{
"cell_type": "markdown",
"id": "4496b01d-74c5-49f7-a1f1-408c1e8fde4f",
"metadata": {},
"source": [
"#### Sample Implementation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b0192fd4-c939-4bb2-b3f1-10ec2302222e",
"metadata": {},
"outputs": [],
"source": [
"def _regenerable(self):\n",
" # creation_info must contain the callable\n",
" # function used to generate the collection\n",
" # terminated by this Layer. The kwargs should\n",
" # also be saved if/when applicable. Positional\n",
" # args should be captured by `indices`\n",
" return \"func\" in self.creation_info\n",
"\n",
"def _regenerate_collection(\n",
" self,\n",
" dsk: dask.highlevelgraph.HighLevelGraph,\n",
" new_kwargs: dict = None,\n",
" _regen_cache: dict = None,\n",
"):\n",
"\n",
" # Return regenerated layer if the work was\n",
" # already done\n",
" _regen_cache = _regen_cache or {}\n",
" if self.output in _regen_cache:\n",
" return _regen_cache[self.output]\n",
"\n",
" # Check that this layer is `_regenerable`\n",
" if not self._regenerable:\n",
" raise ValueError(\n",
" \"`_regenerate_collection` requires `_regenerable=True`\"\n",
" )\n",
"\n",
" # Recursively generate necessary inputs to \n",
" # this layer to generate the collection\n",
" inputs = []\n",
" for key, ind in self.indices:\n",
" if ind is None:\n",
" if isinstance(key, (str, tuple)) and key in dsk.layers:\n",
" continue\n",
" inputs.append(key)\n",
" elif key in self.io_deps:\n",
" continue\n",
" elif dsk.layers[key]._regenerable:\n",
" inputs.append(\n",
" dsk.layers[key]._regenerate_collection(\n",
" dsk,\n",
" new_kwargs=new_kwargs,\n",
" _regen_cache=_regen_cache,\n",
" )\n",
" )\n",
" else:\n",
" raise ValueError(\n",
" \"`_regenerate_collection` failed. \"\n",
" \"Not all HLG layers are regenerable.\"\n",
" )\n",
"\n",
" # Extract the callable func and key-word args.\n",
" # Then return a regenerated collection\n",
" func = self.creation_info[\"func\"]\n",
" regen_kwargs = self.creation_info.get(\"kwargs\", {}).copy()\n",
" regen_kwargs.update((new_kwargs or {}).get(self.output, {}))\n",
" result = func(*inputs, **regen_kwargs)\n",
" _regen_cache[self.output] = result\n",
" return result\n",
"\n",
"# Monkey patch\n",
"setattr(dask.blockwise.Blockwise, '_regenerable', property(_regenerable))\n",
"dask.blockwise.Blockwise._regenerate_collection = _regenerate_collection"
]
},
{
"cell_type": "markdown",
"id": "3a9b624f-2a36-4ebe-8986-e895bfd39cf9",
"metadata": {},
"source": [
"### Example Workflow"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f8206395-1d09-449c-91d1-04b53aac3b75",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.50.0 (20220117.2223)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"247pt\" height=\"476pt\"\n",
" viewBox=\"0.00 0.00 247.00 476.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 472)\">\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-472 243,-472 243,4 -4,4\"/>\n",
"<!-- &#45;8825203084944598416 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>&#45;8825203084944598416</title>\n",
"<g id=\"a_node1\"><a xlink:title=\"A Blockwise Layer with 3 Tasks.&#10;Number of Partitions: 3&#10;DataFrame Type: pandas&#10;2 DataFrame Columns: [&#39;a&#39;, &#39;b&#39;]&#10;\">\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"150,-468 20,-468 20,-432 150,-432 150,-468\"/>\n",
"<text text-anchor=\"middle\" x=\"85\" y=\"-445\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">getitem&#45;#0</text>\n",
"</a>\n",
"</g>\n",
"</g>\n",
"<!-- &#45;2633590387208578626 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>&#45;2633590387208578626</title>\n",
"<g id=\"a_node2\"><a xlink:title=\"A DataFrameIO Layer with 3 Tasks.&#10;Number of Partitions: 3&#10;DataFrame Type: pandas&#10;3 DataFrame Columns: [&#39;a&#39;, &#39;b&#39;, &#39;c&#39;]&#10;\">\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"184,-36 0,-36 0,0 184,0 184,-36\"/>\n",
"<text text-anchor=\"middle\" x=\"92\" y=\"-13\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">read&#45;parquet&#45;#1</text>\n",
"</a>\n",
"</g>\n",
"</g>\n",
"<!-- &#45;287342118792325956 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>&#45;287342118792325956</title>\n",
"<g id=\"a_node3\"><a xlink:title=\"A Blockwise Layer with 3 Tasks.&#10;Number of Partitions: 3&#10;DataFrame Type: pandas&#10;2 DataFrame Columns: [&#39;a&#39;, &#39;b&#39;]&#10;\">\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"157,-108 27,-108 27,-72 157,-72 157,-108\"/>\n",
"<text text-anchor=\"middle\" x=\"92\" y=\"-85\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">getitem&#45;#2</text>\n",
"</a>\n",
"</g>\n",
"</g>\n",
"<!-- &#45;2633590387208578626&#45;&gt;&#45;287342118792325956 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>&#45;2633590387208578626&#45;&gt;&#45;287342118792325956</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M92,-36.3C92,-44.02 92,-53.29 92,-61.89\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"88.5,-61.9 92,-71.9 95.5,-61.9 88.5,-61.9\"/>\n",
"</g>\n",
"<!-- &#45;287342118792325956&#45;&gt;&#45;8825203084944598416 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>&#45;287342118792325956&#45;&gt;&#45;8825203084944598416</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M81.81,-108.16C76.17,-118.29 69.42,-131.54 65,-144 51.48,-182.12 46,-192.55 46,-233 46,-233 46,-233 46,-307 46,-348.44 61.64,-394.33 73.11,-422.42\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"70,-424.07 77.11,-431.93 76.45,-421.35 70,-424.07\"/>\n",
"</g>\n",
"<!-- &#45;7716256153655819261 -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>&#45;7716256153655819261</title>\n",
"<g id=\"a_node5\"><a xlink:title=\"A Blockwise Layer with 3 Tasks.&#10;\">\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"204,-180 74,-180 74,-144 204,-144 204,-180\"/>\n",
"<text text-anchor=\"middle\" x=\"139\" y=\"-157\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">getitem&#45;#4</text>\n",
"</a>\n",
"</g>\n",
"</g>\n",
"<!-- &#45;287342118792325956&#45;&gt;&#45;7716256153655819261 -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>&#45;287342118792325956&#45;&gt;&#45;7716256153655819261</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M103.62,-108.3C109.14,-116.53 115.85,-126.52 121.93,-135.58\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"119.04,-137.54 127.52,-143.9 124.85,-133.64 119.04,-137.54\"/>\n",
"</g>\n",
"<!-- 8140518894306660026 -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>8140518894306660026</title>\n",
"<g id=\"a_node4\"><a xlink:title=\"A Blockwise Layer with 3 Tasks.&#10;\">\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"175.5,-324 74.5,-324 74.5,-288 175.5,-288 175.5,-324\"/>\n",
"<text text-anchor=\"middle\" x=\"125\" y=\"-301\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">and_&#45;#3</text>\n",
"</a>\n",
"</g>\n",
"</g>\n",
"<!-- 9026215552911565517 -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>9026215552911565517</title>\n",
"<g id=\"a_node8\"><a xlink:title=\"A Blockwise Layer with 3 Tasks.&#10;\">\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"175.5,-396 74.5,-396 74.5,-360 175.5,-360 175.5,-396\"/>\n",
"<text text-anchor=\"middle\" x=\"125\" y=\"-373\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">fillna&#45;#7</text>\n",
"</a>\n",
"</g>\n",
"</g>\n",
"<!-- 8140518894306660026&#45;&gt;9026215552911565517 -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>8140518894306660026&#45;&gt;9026215552911565517</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M125,-324.3C125,-332.02 125,-341.29 125,-349.89\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"121.5,-349.9 125,-359.9 128.5,-349.9 121.5,-349.9\"/>\n",
"</g>\n",
"<!-- 5538263221748953944 -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>5538263221748953944</title>\n",
"<g id=\"a_node6\"><a xlink:title=\"A Blockwise Layer with 3 Tasks.&#10;\">\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"155,-252 81,-252 81,-216 155,-216 155,-252\"/>\n",
"<text text-anchor=\"middle\" x=\"118\" y=\"-229\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">gt&#45;#5</text>\n",
"</a>\n",
"</g>\n",
"</g>\n",
"<!-- &#45;7716256153655819261&#45;&gt;5538263221748953944 -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>&#45;7716256153655819261&#45;&gt;5538263221748953944</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M133.81,-180.3C131.47,-188.1 128.65,-197.49 126.05,-206.17\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"122.65,-205.31 123.13,-215.9 129.36,-207.32 122.65,-205.31\"/>\n",
"</g>\n",
"<!-- &#45;8776560892563676883 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>&#45;8776560892563676883</title>\n",
"<g id=\"a_node7\"><a xlink:title=\"A Blockwise Layer with 3 Tasks.&#10;\">\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"239,-252 173,-252 173,-216 239,-216 239,-252\"/>\n",
"<text text-anchor=\"middle\" x=\"206\" y=\"-229\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">lt&#45;#6</text>\n",
"</a>\n",
"</g>\n",
"</g>\n",
"<!-- &#45;7716256153655819261&#45;&gt;&#45;8776560892563676883 -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>&#45;7716256153655819261&#45;&gt;&#45;8776560892563676883</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M155.56,-180.3C163.68,-188.78 173.59,-199.14 182.47,-208.42\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"180.19,-211.09 189.63,-215.9 185.24,-206.25 180.19,-211.09\"/>\n",
"</g>\n",
"<!-- 5538263221748953944&#45;&gt;8140518894306660026 -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>5538263221748953944&#45;&gt;8140518894306660026</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M119.73,-252.3C120.5,-260.02 121.43,-269.29 122.29,-277.89\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"118.81,-278.29 123.29,-287.9 125.78,-277.6 118.81,-278.29\"/>\n",
"</g>\n",
"<!-- &#45;8776560892563676883&#45;&gt;8140518894306660026 -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>&#45;8776560892563676883&#45;&gt;8140518894306660026</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M185.98,-252.3C175.87,-261.03 163.46,-271.76 152.48,-281.25\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"150.07,-278.71 144.79,-287.9 154.65,-284.01 150.07,-278.71\"/>\n",
"</g>\n",
"<!-- 9026215552911565517&#45;&gt;&#45;8825203084944598416 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>9026215552911565517&#45;&gt;&#45;8825203084944598416</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M115.11,-396.3C110.51,-404.36 104.94,-414.11 99.85,-423.02\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"96.7,-421.48 94.77,-431.9 102.77,-424.95 96.7,-421.48\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7f1b76286af0>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Example workflow with filtering after read_parquet\n",
"ddf = dd.read_parquet(\"./tmpdir\")[[\"a\", \"b\"]]\n",
"filters = operator.and_(\n",
" ddf[\"b\"]>5,\n",
" ddf[\"b\"]<10,\n",
").fillna(False)\n",
"ddf = ddf[filters]\n",
"\n",
"# Patch layers to define \"proper\" `creation_info`\n",
"def _patch_layers(df):\n",
" for k, layer in df.dask.copy().layers.items():\n",
" if hasattr(layer, \"creation_info\"):\n",
" # This is the read_parquet layer.\n",
" # Need to move the `path` arg to be in \"kwargs\"\n",
" path = layer.creation_info[\"args\"][0]\n",
" kwargs = {\"path\": path}\n",
" kwargs.update(layer.creation_info[\"kwargs\"])\n",
" layer.creation_info[\"kwargs\"] = kwargs\n",
" else:\n",
" # This is not the IO layer.\n",
" # Need to define `creation_info`\n",
" func = layer.dsk[k][0]\n",
" if func == dask.utils.apply:\n",
" func = layer.dsk[k][1]\n",
" if (\n",
" isinstance(func, dask.utils.methodcaller) and\n",
" func.func == \"fillna\"\n",
" ):\n",
" # `methodcaller` makes \"patching\" tricky.\n",
" # In practice, we would certainly want `creation_info`\n",
" # defined by the dd.Series.fillna API itself.\n",
" func = dd.Series.fillna\n",
" layer.creation_info = {\"func\": func}\n",
" df.dask.layers[k] = layer\n",
"_patch_layers(ddf)\n",
"ddf.dask.visualize()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d87874a4-da55-491c-9db5-57c81fdd70c9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Extract full graph and target layer for regeneration\n",
"dsk = ddf.dask\n",
"layer = dsk.layers[ddf._name]\n",
"\n",
"# Regenerate the collection\n",
"ddf2 = layer._regenerate_collection(dsk)\n",
"dd.utils.assert_eq(ddf, ddf2)"
]
},
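{
"cell_type": "markdown",
"id": "added-e1-new-kwargs-note",
"metadata": {},
"source": [
"The example above replays the graph unchanged. As a preview of the mechanism Proposal 3 builds on, the cell below sketches how `new_kwargs` can modify the root IO layer during regeneration (the layer-name lookup here is for illustration only)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-e2-new-kwargs-code",
"metadata": {},
"outputs": [],
"source": [
"# Find the name of the root IO layer (illustration only)\n",
"io_name = next(k for k in dsk.layers if k.startswith(\"read-parquet\"))\n",
"\n",
"# Regenerate the same collection, but pass `filters` through\n",
"# to the read_parquet call in the IO layer\n",
"ddf_filtered = layer._regenerate_collection(\n",
"    dsk,\n",
"    new_kwargs={io_name: {\"filters\": [(\"b\", \">\", 5), (\"b\", \"<\", 10)]}},\n",
")"
]
},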
{
"cell_type": "markdown",
"id": "3de92e4e-d179-4be6-a9fc-0d0c36e41f4e",
"metadata": {},
"source": [
"## Proposal 2: Expand Dispatching to Enable Filter Extraction\n",
"\n",
"In order to enable predicate pushdown for general filter operations. I propose that we expand the current `Dispatch` machinery to support the dispatching of functions based on a specific `callable`-object argument (e.g. `operator.gt`). By adding something like an optional `on_literal=True` argument ot `Dispatch`, we make it possible to register DNF-filter expressions for functions typically used for filtering operations. With such machinery in place, filter extraction can be accomplished by adding a recursive `_dnf_filter_expression` method to `Layer` and `Blockwise`:\n",
"\n",
"- `_dnf_filter_expression`\n",
" - **description**:\n",
" - Return a disjunctive normal form (DNF)-formatted filter expression for the graph terminating at this layer\n",
" - **Inputs**:\n",
" - `dsk` (`HighLevelGraph`): The full graph (layers and dependencies) for the original collection.\n",
" - **Default behavior**:\n",
" - `raise ValueError`"
]
},
{
"cell_type": "markdown",
"id": "995a0577-a80c-4f15-989c-8a8b5341045a",
"metadata": {},
"source": [
"#### Sample Implementation\n",
"\n",
"##### `Dispatch` Changes"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "67b25b99-e88f-48bd-bba1-2868eb88a916",
"metadata": {},
"outputs": [],
"source": [
"class Dispatch:\n",
" \"\"\"Simple single dispatch.\"\"\"\n",
"\n",
" def __init__(self, name=None, on_literal=False):\n",
" self._lookup = {}\n",
" self._lazy = {}\n",
" if name:\n",
" self.__name__ = name\n",
" self.on_literal = on_literal\n",
"\n",
" def register(self, ref, func=None):\n",
" \"\"\"Register dispatch of `func` on arguments of type `type`\"\"\"\n",
"\n",
" def wrapper(func):\n",
" if isinstance(ref, tuple):\n",
" for t in ref:\n",
" self.register(t, func)\n",
" else:\n",
" self._lookup[ref] = func\n",
" return func\n",
"\n",
" return wrapper(func) if func is not None else wrapper\n",
"\n",
" def register_lazy(self, toplevel, func=None):\n",
" \"\"\"\n",
" Register a registration function which will be called if the\n",
" *toplevel* module (e.g. 'pandas') is ever loaded.\n",
" \"\"\"\n",
"\n",
" def wrapper(func):\n",
" self._lazy[toplevel] = func\n",
" return func\n",
"\n",
" return wrapper(func) if func is not None else wrapper\n",
"\n",
" def dispatch(self, ref):\n",
" \"\"\"Return the function implementation for the given ``cls``\"\"\"\n",
" lk = self._lookup\n",
" if self.on_literal:\n",
" try:\n",
" impl = lk[ref]\n",
" except KeyError:\n",
" pass\n",
" else:\n",
" return impl\n",
" raise ValueError(f\"No dispatch registered for {ref}\")\n",
" else:\n",
" cls = ref\n",
" for cls2 in cls.__mro__:\n",
" try:\n",
" impl = lk[cls2]\n",
" except KeyError:\n",
" pass\n",
" else:\n",
" if cls is not cls2:\n",
" # Cache lookup\n",
" lk[cls] = impl\n",
" return impl\n",
" # Is a lazy registration function present?\n",
" toplevel, _, _ = cls2.__module__.partition(\".\")\n",
" try:\n",
" register = self._lazy.pop(toplevel)\n",
" except KeyError:\n",
" pass\n",
" else:\n",
" register()\n",
" return self.dispatch(cls) # recurse\n",
" raise TypeError(f\"No dispatch for {ref}\")\n",
"\n",
" def __call__(self, arg, *args, **kwargs):\n",
" \"\"\"\n",
" Call the corresponding method based on type of argument.\n",
" \"\"\"\n",
" if self.on_literal:\n",
" meth = self.dispatch(arg)\n",
" else:\n",
" meth = self.dispatch(type(arg))\n",
" return meth(arg, *args, **kwargs)"
]
},
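{
"cell_type": "markdown",
"id": "added-f1-on-literal-note",
"metadata": {},
"source": [
"As a minimal demonstration of the `on_literal=True` behavior, the cell below registers an implementation for the literal `operator.gt` object and dispatches on it directly (the names here are hypothetical)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-f2-on-literal-code",
"metadata": {},
"outputs": [],
"source": [
"# Register on the literal function object rather than its type\n",
"demo_dispatch = Dispatch(\"demo_dispatch\", on_literal=True)\n",
"\n",
"@demo_dispatch.register(operator.gt)\n",
"def _demo_gt(op):\n",
"    return f\"registered for {op.__name__}\"\n",
"\n",
"demo_dispatch(operator.gt)  # -> \"registered for gt\""
]
},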
{
"cell_type": "markdown",
"id": "2a522953-dacc-45ce-b677-6de0852ad31d",
"metadata": {},
"source": [
"##### Relevant `Dispatch` Implementation"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "875c497a-f7b2-4895-8229-9af4f8c3e46a",
"metadata": {},
"outputs": [],
"source": [
"dnf_filter_dispatch = Dispatch(\"dnf_filter_dispatch\", on_literal=True)\n",
"\n",
"_comparison_symbols = {\n",
" operator.eq: \"==\",\n",
" operator.ne: \"!=\",\n",
" operator.lt: \"<\",\n",
" operator.le: \"<=\",\n",
" operator.gt: \">\",\n",
" operator.ge: \">=\",\n",
"}\n",
"\n",
"def _get_blockwise_input(input_index, indices, dsk):\n",
" key = indices[input_index][0]\n",
" if indices[input_index][1] is None:\n",
" return key\n",
" return dsk.layers[key]._dnf_filter_expression(dsk)\n",
"\n",
"@dnf_filter_dispatch.register(tuple(_comparison_symbols.keys()))\n",
"def comparison_dnf(op, indices: list, dsk: dask.highlevelgraph.HighLevelGraph):\n",
" left = _get_blockwise_input(0, indices, dsk)\n",
" right = _get_blockwise_input(1, indices, dsk)\n",
" return (left, _comparison_symbols[op], right)\n",
"\n",
"@dnf_filter_dispatch.register((operator.and_, operator.or_))\n",
"def logical_dnf(op, indices: list, dsk: dask.highlevelgraph.HighLevelGraph):\n",
" left = _get_blockwise_input(0, indices, dsk)\n",
" right = _get_blockwise_input(1, indices, dsk)\n",
" if op == operator.or_:\n",
" return [left], [right]\n",
" elif op == operator.and_:\n",
" return (left, right)\n",
" else:\n",
" raise ValueError\n",
"\n",
"@dnf_filter_dispatch.register(operator.getitem)\n",
"def getitem_dnf(op, indices: list, dsk: dask.highlevelgraph.HighLevelGraph):\n",
" # Return dnf of key (selected by getitem)\n",
" key = _get_blockwise_input(1, indices, dsk)\n",
" return key\n",
"\n",
"@dnf_filter_dispatch.register(dd.Series.fillna)\n",
"def fillna_dnf(op, indices: list, dsk: dask.highlevelgraph.HighLevelGraph):\n",
" # Return dnf of input collection\n",
" return _get_blockwise_input(0, indices, dsk)"
]
},
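{
"cell_type": "markdown",
"id": "added-g1-dnf-note",
"metadata": {},
"source": [
"A brief note on the DNF convention assumed by these implementations: an inner sequence of `(column, op, value)` predicates is a conjunction (AND), while an outer list of such sequences is a disjunction (OR). This matches the `filters` format already accepted by `dd.read_parquet`. The cell below is a plain-Python illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-g2-dnf-code",
"metadata": {},
"outputs": [],
"source": [
"# DNF filter convention (illustration only)\n",
"and_expr = [(\"b\", \">\", 5), (\"b\", \"<\", 10)]     # b > 5 AND b < 10\n",
"or_expr = [[(\"b\", \">\", 5)], [(\"b\", \"<\", 10)]]  # b > 5 OR b < 10"
]
},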
{
"cell_type": "markdown",
"id": "c421bbee-e6bd-47d8-9a30-54abc38ce6a0",
"metadata": {},
"source": [
"##### `_dnf_filter_expression` Implementations"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "98d0b3ac-2e4c-45a0-badd-21ad46cc4697",
"metadata": {},
"outputs": [],
"source": [
"def _layer_dnf_filter_expression(self, dsk: dask.highlevelgraph.HighLevelGraph):\n",
" \"\"\"Return a DNF-formatted filter expression for the\n",
" graph terminating at this layer\n",
" \"\"\"\n",
" # Default Implementation returns TypeError\n",
" raise TypeError(f\"No DNF dispatching implemented for {self.__class__}\")\n",
"\n",
"dask.highlevelgraph.Layer._dnf_filter_expression = _layer_dnf_filter_expression"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "44b9f077-355a-403d-8f53-c6d6e2743522",
"metadata": {},
"outputs": [],
"source": [
"def _blockwise_dnf_filter_expression(self, dsk: dask.highlevelgraph.HighLevelGraph):\n",
" \"\"\"Return a DNF-formatted filter expression for the\n",
" graph terminating at this layer\n",
" \"\"\"\n",
" if self._regenerable:\n",
" return dnf_filter_dispatch(\n",
" self.creation_info[\"func\"],\n",
" self.indices,\n",
" dsk,\n",
" )\n",
" return ValueError(f\"DNF dispatching requires `_regenerable==True`\")\n",
"\n",
"dask.blockwise.Blockwise._dnf_filter_expression = _blockwise_dnf_filter_expression"
]
},
{
"cell_type": "markdown",
"id": "d1c0372a-1a03-4afc-9b74-e292219aac55",
"metadata": {},
"source": [
"### Example Workflow"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9919e34c-7358-4ca8-8f38-2b8ad1b5d054",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('b', '>', 5), ('b', '<', 10)]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Extract full graph and target layer for the filter\n",
"dsk = ddf.dask\n",
"layer = dsk.layers[ddf._name]\n",
"\n",
"# Regenerate the collection\n",
"filters = layer._dnf_filter_expression(dsk)\n",
"list(filters)"
]
},
{
"cell_type": "markdown",
"id": "8e6f2e75-eb14-4817-a4cd-6f6d895313ab",
"metadata": {},
"source": [
"## Proposal 3: Implement Eager Predicate Pushdown"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2e9f79cd-7df4-48d4-a6f0-328d29f96835",
"metadata": {},
"outputs": [],
"source": [
"def predicate_pushdown(ddf):\n",
" \n",
" # Get output layer name and HLG\n",
" name = ddf._name\n",
" dsk = ddf.dask\n",
"\n",
" # Extract filters\n",
" try:\n",
" filters = list(dsk.layers[name]._dnf_filter_expression(dsk))\n",
" except (TypeError, ValueError):\n",
" # DNF dispatching failed for 1+ layers\n",
" return ddf\n",
" \n",
" # We were able to extract a DNF filter expression.\n",
" # Check that all layers are regenerable, and that\n",
" # the graph contains an IO layer with filters support.\n",
" # All layers besides the root IO layer should also\n",
" # support DNF dispatching. Otherwise, there could be\n",
" # something like column-assignment or data manipulation\n",
" # between the IO layer and the filter.\n",
" io_layer = []\n",
" for k, v in dsk.layers.items():\n",
" if not v._regenerable:\n",
" return ddf\n",
" if (\n",
" # Real Logic should check: isinstance(v, DataFrameIOLayer)\n",
" v.creation_info[\"func\"] == dd.read_parquet and\n",
" \"filters\" in v.creation_info.get(\"kwargs\", {}) and\n",
" v.creation_info[\"kwargs\"][\"filters\"] is None\n",
" ):\n",
" io_layer.append(k)\n",
" else:\n",
" try:\n",
" dnf_filter_dispatch.dispatch(v.creation_info[\"func\"])\n",
" except (TypeError, ValueError):\n",
" # This is NOT an IO layer OR a filter-safe layer\n",
" return ddf\n",
" if len(io_layer) != 1:\n",
" return ddf\n",
" io_layer = io_layer.pop()\n",
" \n",
" # Regenerate collection with filtered IO layer\n",
" return dsk.layers[name]._regenerate_collection(\n",
" dsk,\n",
" new_kwargs={io_layer: {\"filters\": filters}},\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "90507895-780f-475a-8203-c1fda80c09b7",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>a</th>\n",
" <th>b</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>1</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>2</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>3</td>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>1</td>\n",
" <td>9</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" a b\n",
"6 1 6\n",
"7 2 7\n",
"8 3 8\n",
"9 1 9"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Regenerate the collection\n",
"new = predicate_pushdown(ddf)\n",
"new.compute()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f7ce796f-8b6c-443d-9d4e-a76268ff98fc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original filters: None\n"
]
}
],
"source": [
"# Check that the old `ddf` DataFrame has `filters=None`\n",
"print(\n",
" \"Original filters:\",\n",
" hlg_layer(ddf.dask, \"read-parquet\").creation_info[\"kwargs\"][\"filters\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "9c7bc6ed-ffd1-46f3-a27d-4cb25a726904",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"New filters: [('b', '>', 5), ('b', '<', 10)]\n"
]
}
],
"source": [
"# Check that the `new` DataFrame has the expected `filters`\n",
"print(\n",
" \"New filters:\",\n",
" hlg_layer(new.dask, \"read-parquet\").creation_info[\"kwargs\"][\"filters\"],\n",
")"
]
},
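{
"cell_type": "markdown",
"id": "added-k1-sanity-note",
"metadata": {},
"source": [
"As a final sanity check, the optimized collection should produce the same result as the original. This is a sketch: `check_divisions=False` allows for the fact that IO-level filtering may change the partition structure."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-k2-sanity-code",
"metadata": {},
"outputs": [],
"source": [
"# The pushed-down graph must not change the computed result\n",
"dd.utils.assert_eq(new, ddf, check_divisions=False)"
]
},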
{
"cell_type": "code",
"execution_count": null,
"id": "11942dea-04df-4583-a46d-97fe9cdf0d46",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}