Last active
February 16, 2022 16:30
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "a0454883-d88d-4997-bafe-5abd4daa9134", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd\n", | |
"import dask\n", | |
"import dask.dataframe as dd\n", | |
"from dask.utils_test import hlg_layer\n", | |
"import operator" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6cf6c55b-d851-4db2-b1d8-a5c1eca1a701", | |
"metadata": {}, | |
"source": [ | |
"# Dask Design Proposal: General support for predicate pushdown\n", | |
"\n", | |
"**Author**: Rick Zamora (Last update: February 14th, 2022)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1c3ee9ae-1cb0-44e9-8cfa-92025c238b0e", | |
"metadata": { | |
"tags": [] | |
}, | |
"source": [ | |
"The purpose of the proposed changes is to enable certain `HighLevelGraph` (HLG) layers to store and utilize the necessary information for the \"pushdown\" of general filter expressions. HLG Layers are currently used to define the task graph required for a Dask collection. However, information currently flows in only one direction: from the collection API to the graph. This limitation is problematic when it comes to high-level optimizations like \"predicate pushdown,\" because the application of filters at the root IO layer can modify the initial partition count and index (`divisions`). In other words, predicate pushdown requires the entire graph to be completely regenerated.\n", | |
"\n", | |
"Given that the entire graph needs to be regenerated when a filter is pushed down into the IO layer, it seems likely that the best time to apply the optimization is as soon as a `getitem[Series]` operation is executed in the Dask-Dataframe API. For example [dask/dask#8633](https://github.com/dask/dask/pull/8633) shows how this approach can be used with the existing `Layer`/`Blockwise` implementations to support simple filters like `ddf[ddf[\"b\"] < 10]`. The primary purpose of this proposal is to simplify and generalize the brute-force logic currently used in that PR.\n", | |
"\n", | |
"In order to capture arbitrary filter expressions, I propose three separate (but related) changes to Dask:\n", | |
"\n", | |
"- **Proposal 1**: Enable collection/graph regeneration\n", | |
" - Add new attributes to `Layer`/`Blockwise` to enable collection regeneration\n", | |
" - Expand/formalize the usage of `creation_info` in `Blockwise`\n", | |
"- **Proposal 2**: Expand dispatching machinery and enable filter-expression extraction\n", | |
" - Add `on_literal=True` option to `Dispatch` to enable dispatching by the literal value of a function argument (rather than its type). This change allows both Dask and downstream libraries to register filter expressions for specific callable objects (like `operator.gt`)\n", | |
" - Add new method to `Layer`/`Blockwise` to enable recursive filter-expression extraction\n", | |
"- **Proposal 3**: Implement eager predicate pushdown\n", | |
" - Use the above changes to push down general filter operations to the IO layer when a `getitem[Series]` operation is executed in the Dask-Dataframe API" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "b5daa9cb-f6b2-4428-84bc-ea9d5abc20ae", | |
"metadata": {}, | |
"source": [ | |
"## Proposal 1: Enable Collection Regeneration\n", | |
"\n", | |
"In order to enable predicate pushdown for general filter operations, it is clear that we need a simple but reliable mechanism for graph regeneration. Since the original graph is typically produced using a collection API, the simplest way to regenerate a graph is to \"replay\" the same collection-API function calls. One straightforward way to accomplish this is to (optionally) define a `creation_info` attribute at the `Layer` level, storing the original function and arguments used to generate the Dask collection attached to that Layer. Note that this attribute is already defined within the `DataFrameIOLayer` implementation, but has yet to be leveraged for graph optimization purposes. This document proposes that we expand the usage of `creation_info` by defining it for a variety of simple filtering-related `Blockwise` layers (e.g. `getitem` and `gt`), and that the following attributes be added to `HighLevelGraph.Layer`:\n", | |
"\n", | |
"- `Layer._regenerable` (a property)\n", | |
" - **description**:\n", | |
" - This property indicates whether the layer is capable of regenerating itself by regenerating the corresponding Dask collection. If `True`, the `_regenerate_collection` method must be implemented.\n", | |
" - **Default behavior**:\n", | |
" - `return False`\n", | |
"- `Layer._regenerate_collection` (an abstract method)\n", | |
" - **description**:\n", | |
" - If `_regenerable` is `True`, this method should regenerate and return the appropriate Dask collection for the layer.\n", | |
" - **Inputs**:\n", | |
" - `dsk` (`HighLevelGraph`): The full graph (layers and dependencies) for the original collection.\n", | |
" - `new_kwargs` (`dict[dict]`, optional): New (non-default) keyword arguments to use for the regeneration of each Layer. Keys correspond to original Layer names, and values are dictionaries of keyword arguments used to update the defaults. For example, if Layer `\"read-parquet-86753\"` was originally generated with `read_parquet(...,filters=None,columns=[\"a\"])`, then `new_kwargs={\"read-parquet-86753\": {\"filters\": [(\"a\", \"<\", 10)]}}` will update the `filters` argument to `read_parquet`, but will leave the original `columns` argument unchanged.\n", | |
" - `_regen_cache` (`dict`): A (private) regeneration cache to be used in recursive calls to avoid redundant collection generation.\n", | |
" - **Default behavior**:\n", | |
" - `raise NotImplementedError`\n", | |
"\n", | |
"We define the `_regenerable` property in the base `Layer` class (rather than in `Blockwise`) so that we can ask **any** layer if it is \"regenerable\". Similarly, we define `_regenerate_collection` in `Layer` because non-`Blockwise` classes may want to implement this feature in the future.\n", | |
"\n", | |
"**NOTE**: An alternative to this `Layer._regenerable` approach is to define a distinct class for \"regenerable\" layers. However, that approach will likely lead to a messy class-inheritance structure if/when non-`Blockwise` layers begin adopting collection-regeneration features. Therefore, my clear preference is to avoid something like a `RegenerableLayer` class." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c7976877-94d4-45b8-a362-39d4fb213372", | |
"metadata": { | |
"tags": [] | |
}, | |
"source": [ | |
"#### Sample Implementation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "b0a7b9d9-276a-4f1d-82ad-d275c48df94f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def _regenerable(self):\n", | |
" \"\"\"Whether this layer supports ``_regenerate_collection``\n", | |
"\n", | |
" If True, this layer can regenerate a Dask collection\n", | |
" when provided the required collection inputs.\n", | |
" \"\"\"\n", | |
" return False\n", | |
"\n", | |
"def _regenerate_collection(\n", | |
" self,\n", | |
" dsk: dask.highlevelgraph.HighLevelGraph,\n", | |
" new_kwargs: dict = None,\n", | |
" _regen_cache: dict = None,\n", | |
"):\n", | |
" \"\"\"Regenerate a Dask collection for this layer using the\n", | |
" provided inputs and key-word arguments\n", | |
" \"\"\"\n", | |
" if self._regenerable:\n", | |
" raise NotImplementedError\n", | |
" raise ValueError(\n", | |
" \"`_regenerate_collection` requires `_regenerable=True`\"\n", | |
" )\n", | |
"\n", | |
"# Monkey patch\n", | |
"setattr(dask.highlevelgraph.Layer, '_regenerable', property(_regenerable))\n", | |
"dask.highlevelgraph.Layer._regenerate_collection = _regenerate_collection" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1f9c78d8-fc2c-4ce1-8e27-e85c089cd52e", | |
"metadata": {}, | |
"source": [ | |
"### Related Blockwise Changes\n", | |
"\n", | |
"Since most IO and filtering operations in Dask-Dataframe consist entirely of `Blockwise`-based HLG layers, effective predicate pushdown clearly depends on the implementation of `_regenerable` and `_regenerate_collection` in `Blockwise`. Assuming that we adopt the `creation_info` attribute in `Blockwise` more generally, the necessary logic becomes quite simple:\n", | |
"\n", | |
"- `Blockwise._regenerable`: Return `True` if (and only if) a \"proper\" `creation_info` attribute is defined.\n", | |
"- `Blockwise._regenerate_collection`: Use `creation_info` to regenerate the Dask collection for the current `Blockwise` layer recursively. Since `Blockwise` keeps track of input collections and literals using an `indices` attribute, it is relatively straightforward for the layer to call `_regenerate_collection` on the required input layers before replaying its own `creation_info` function call. \n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4496b01d-74c5-49f7-a1f1-408c1e8fde4f", | |
"metadata": {}, | |
"source": [ | |
"#### Sample Implementation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "b0192fd4-c939-4bb2-b3f1-10ec2302222e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def _regenerable(self):\n", | |
" # creation_info must contain the callable\n", | |
" # function used to generate the collection\n", | |
" # terminated by this Layer. The kwargs should\n", | |
" # also be saved if/when applicable. Positional\n", | |
" # args should be captured by `indices`\n", | |
" return \"func\" in getattr(self, \"creation_info\", {})\n", | |
"\n", | |
"def _regenerate_collection(\n", | |
" self,\n", | |
" dsk: dask.highlevelgraph.HighLevelGraph,\n", | |
" new_kwargs: dict = None,\n", | |
" _regen_cache: dict = None,\n", | |
"):\n", | |
"\n", | |
" # Return regenerated layer if the work was\n", | |
" # already done\n", | |
" _regen_cache = _regen_cache or {}\n", | |
" if self.output in _regen_cache:\n", | |
" return _regen_cache[self.output]\n", | |
"\n", | |
" # Check that this layer is `_regenerable`\n", | |
" if not self._regenerable:\n", | |
" raise ValueError(\n", | |
" \"`_regenerate_collection` requires `_regenerable=True`\"\n", | |
" )\n", | |
"\n", | |
" # Recursively generate necessary inputs to \n", | |
" # this layer to generate the collection\n", | |
" inputs = []\n", | |
" for key, ind in self.indices:\n", | |
" if ind is None:\n", | |
" if isinstance(key, (str, tuple)) and key in dsk.layers:\n", | |
" continue\n", | |
" inputs.append(key)\n", | |
" elif key in self.io_deps:\n", | |
" continue\n", | |
" elif dsk.layers[key]._regenerable:\n", | |
" inputs.append(\n", | |
" dsk.layers[key]._regenerate_collection(\n", | |
" dsk,\n", | |
" new_kwargs=new_kwargs,\n", | |
" _regen_cache=_regen_cache,\n", | |
" )\n", | |
" )\n", | |
" else:\n", | |
" raise ValueError(\n", | |
" \"`_regenerate_collection` failed. \"\n", | |
" \"Not all HLG layers are regenerable.\"\n", | |
" )\n", | |
"\n", | |
" # Extract the callable func and key-word args.\n", | |
" # Then return a regenerated collection\n", | |
" func = self.creation_info[\"func\"]\n", | |
" regen_kwargs = self.creation_info.get(\"kwargs\", {}).copy()\n", | |
" regen_kwargs.update((new_kwargs or {}).get(self.output, {}))\n", | |
" result = func(*inputs, **regen_kwargs)\n", | |
" _regen_cache[self.output] = result\n", | |
" return result\n", | |
"\n", | |
"# Monkey patch\n", | |
"setattr(dask.blockwise.Blockwise, '_regenerable', property(_regenerable))\n", | |
"dask.blockwise.Blockwise._regenerate_collection = _regenerate_collection" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "3a9b624f-2a36-4ebe-8986-e895bfd39cf9", | |
"metadata": {}, | |
"source": [ | |
"### Example Workflow" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "f8206395-1d09-449c-91d1-04b53aac3b75", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
"<!-- Generated by graphviz version 2.50.0 (20220117.2223)\n", | |
" -->\n", | |
"<!-- Pages: 1 -->\n", | |
"<svg width=\"247pt\" height=\"476pt\"\n", | |
" viewBox=\"0.00 0.00 247.00 476.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 472)\">\n", | |
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-472 243,-472 243,4 -4,4\"/>\n", | |
"<!-- -8825203084944598416 -->\n", | |
"<g id=\"node1\" class=\"node\">\n", | |
"<title>-8825203084944598416</title>\n", | |
"<g id=\"a_node1\"><a xlink:title=\"A Blockwise Layer with 3 Tasks. Number of Partitions: 3 DataFrame Type: pandas 2 DataFrame Columns: ['a', 'b'] \">\n", | |
"<polygon fill=\"none\" stroke=\"black\" points=\"150,-468 20,-468 20,-432 150,-432 150,-468\"/>\n", | |
"<text text-anchor=\"middle\" x=\"85\" y=\"-445\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">getitem-#0</text>\n", | |
"</a>\n", | |
"</g>\n", | |
"</g>\n", | |
"<!-- -2633590387208578626 -->\n", | |
"<g id=\"node2\" class=\"node\">\n", | |
"<title>-2633590387208578626</title>\n", | |
"<g id=\"a_node2\"><a xlink:title=\"A DataFrameIO Layer with 3 Tasks. Number of Partitions: 3 DataFrame Type: pandas 3 DataFrame Columns: ['a', 'b', 'c'] \">\n", | |
"<polygon fill=\"none\" stroke=\"black\" points=\"184,-36 0,-36 0,0 184,0 184,-36\"/>\n", | |
"<text text-anchor=\"middle\" x=\"92\" y=\"-13\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">read-parquet-#1</text>\n", | |
"</a>\n", | |
"</g>\n", | |
"</g>\n", | |
"<!-- -287342118792325956 -->\n", | |
"<g id=\"node3\" class=\"node\">\n", | |
"<title>-287342118792325956</title>\n", | |
"<g id=\"a_node3\"><a xlink:title=\"A Blockwise Layer with 3 Tasks. Number of Partitions: 3 DataFrame Type: pandas 2 DataFrame Columns: ['a', 'b'] \">\n", | |
"<polygon fill=\"none\" stroke=\"black\" points=\"157,-108 27,-108 27,-72 157,-72 157,-108\"/>\n", | |
"<text text-anchor=\"middle\" x=\"92\" y=\"-85\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">getitem-#2</text>\n", | |
"</a>\n", | |
"</g>\n", | |
"</g>\n", | |
"<!-- -2633590387208578626->-287342118792325956 -->\n", | |
"<g id=\"edge3\" class=\"edge\">\n", | |
"<title>-2633590387208578626->-287342118792325956</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M92,-36.3C92,-44.02 92,-53.29 92,-61.89\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"88.5,-61.9 92,-71.9 95.5,-61.9 88.5,-61.9\"/>\n", | |
"</g>\n", | |
"<!-- -287342118792325956->-8825203084944598416 -->\n", | |
"<g id=\"edge1\" class=\"edge\">\n", | |
"<title>-287342118792325956->-8825203084944598416</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M81.81,-108.16C76.17,-118.29 69.42,-131.54 65,-144 51.48,-182.12 46,-192.55 46,-233 46,-233 46,-233 46,-307 46,-348.44 61.64,-394.33 73.11,-422.42\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"70,-424.07 77.11,-431.93 76.45,-421.35 70,-424.07\"/>\n", | |
"</g>\n", | |
"<!-- -7716256153655819261 -->\n", | |
"<g id=\"node5\" class=\"node\">\n", | |
"<title>-7716256153655819261</title>\n", | |
"<g id=\"a_node5\"><a xlink:title=\"A Blockwise Layer with 3 Tasks. \">\n", | |
"<polygon fill=\"none\" stroke=\"black\" points=\"204,-180 74,-180 74,-144 204,-144 204,-180\"/>\n", | |
"<text text-anchor=\"middle\" x=\"139\" y=\"-157\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">getitem-#4</text>\n", | |
"</a>\n", | |
"</g>\n", | |
"</g>\n", | |
"<!-- -287342118792325956->-7716256153655819261 -->\n", | |
"<g id=\"edge6\" class=\"edge\">\n", | |
"<title>-287342118792325956->-7716256153655819261</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M103.62,-108.3C109.14,-116.53 115.85,-126.52 121.93,-135.58\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"119.04,-137.54 127.52,-143.9 124.85,-133.64 119.04,-137.54\"/>\n", | |
"</g>\n", | |
"<!-- 8140518894306660026 -->\n", | |
"<g id=\"node4\" class=\"node\">\n", | |
"<title>8140518894306660026</title>\n", | |
"<g id=\"a_node4\"><a xlink:title=\"A Blockwise Layer with 3 Tasks. \">\n", | |
"<polygon fill=\"none\" stroke=\"black\" points=\"175.5,-324 74.5,-324 74.5,-288 175.5,-288 175.5,-324\"/>\n", | |
"<text text-anchor=\"middle\" x=\"125\" y=\"-301\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">and_-#3</text>\n", | |
"</a>\n", | |
"</g>\n", | |
"</g>\n", | |
"<!-- 9026215552911565517 -->\n", | |
"<g id=\"node8\" class=\"node\">\n", | |
"<title>9026215552911565517</title>\n", | |
"<g id=\"a_node8\"><a xlink:title=\"A Blockwise Layer with 3 Tasks. \">\n", | |
"<polygon fill=\"none\" stroke=\"black\" points=\"175.5,-396 74.5,-396 74.5,-360 175.5,-360 175.5,-396\"/>\n", | |
"<text text-anchor=\"middle\" x=\"125\" y=\"-373\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">fillna-#7</text>\n", | |
"</a>\n", | |
"</g>\n", | |
"</g>\n", | |
"<!-- 8140518894306660026->9026215552911565517 -->\n", | |
"<g id=\"edge9\" class=\"edge\">\n", | |
"<title>8140518894306660026->9026215552911565517</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M125,-324.3C125,-332.02 125,-341.29 125,-349.89\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"121.5,-349.9 125,-359.9 128.5,-349.9 121.5,-349.9\"/>\n", | |
"</g>\n", | |
"<!-- 5538263221748953944 -->\n", | |
"<g id=\"node6\" class=\"node\">\n", | |
"<title>5538263221748953944</title>\n", | |
"<g id=\"a_node6\"><a xlink:title=\"A Blockwise Layer with 3 Tasks. \">\n", | |
"<polygon fill=\"none\" stroke=\"black\" points=\"155,-252 81,-252 81,-216 155,-216 155,-252\"/>\n", | |
"<text text-anchor=\"middle\" x=\"118\" y=\"-229\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">gt-#5</text>\n", | |
"</a>\n", | |
"</g>\n", | |
"</g>\n", | |
"<!-- -7716256153655819261->5538263221748953944 -->\n", | |
"<g id=\"edge7\" class=\"edge\">\n", | |
"<title>-7716256153655819261->5538263221748953944</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M133.81,-180.3C131.47,-188.1 128.65,-197.49 126.05,-206.17\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"122.65,-205.31 123.13,-215.9 129.36,-207.32 122.65,-205.31\"/>\n", | |
"</g>\n", | |
"<!-- -8776560892563676883 -->\n", | |
"<g id=\"node7\" class=\"node\">\n", | |
"<title>-8776560892563676883</title>\n", | |
"<g id=\"a_node7\"><a xlink:title=\"A Blockwise Layer with 3 Tasks. \">\n", | |
"<polygon fill=\"none\" stroke=\"black\" points=\"239,-252 173,-252 173,-216 239,-216 239,-252\"/>\n", | |
"<text text-anchor=\"middle\" x=\"206\" y=\"-229\" font-family=\"Helvetica,sans-Serif\" font-size=\"20.00\">lt-#6</text>\n", | |
"</a>\n", | |
"</g>\n", | |
"</g>\n", | |
"<!-- -7716256153655819261->-8776560892563676883 -->\n", | |
"<g id=\"edge8\" class=\"edge\">\n", | |
"<title>-7716256153655819261->-8776560892563676883</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M155.56,-180.3C163.68,-188.78 173.59,-199.14 182.47,-208.42\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"180.19,-211.09 189.63,-215.9 185.24,-206.25 180.19,-211.09\"/>\n", | |
"</g>\n", | |
"<!-- 5538263221748953944->8140518894306660026 -->\n", | |
"<g id=\"edge4\" class=\"edge\">\n", | |
"<title>5538263221748953944->8140518894306660026</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M119.73,-252.3C120.5,-260.02 121.43,-269.29 122.29,-277.89\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"118.81,-278.29 123.29,-287.9 125.78,-277.6 118.81,-278.29\"/>\n", | |
"</g>\n", | |
"<!-- -8776560892563676883->8140518894306660026 -->\n", | |
"<g id=\"edge5\" class=\"edge\">\n", | |
"<title>-8776560892563676883->8140518894306660026</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M185.98,-252.3C175.87,-261.03 163.46,-271.76 152.48,-281.25\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"150.07,-278.71 144.79,-287.9 154.65,-284.01 150.07,-278.71\"/>\n", | |
"</g>\n", | |
"<!-- 9026215552911565517->-8825203084944598416 -->\n", | |
"<g id=\"edge2\" class=\"edge\">\n", | |
"<title>9026215552911565517->-8825203084944598416</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M115.11,-396.3C110.51,-404.36 104.94,-414.11 99.85,-423.02\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"96.7,-421.48 94.77,-431.9 102.77,-424.95 96.7,-421.48\"/>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>\n" | |
], | |
"text/plain": [ | |
"<graphviz.graphs.Digraph at 0x7f1b76286af0>" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Example workflow with filtering after read_parquet\n", | |
"ddf = dd.read_parquet(\"./tmpdir\")[[\"a\", \"b\"]]\n", | |
"filters = operator.and_(\n", | |
" ddf[\"b\"]>5,\n", | |
" ddf[\"b\"]<10,\n", | |
").fillna(False)\n", | |
"ddf = ddf[filters]\n", | |
"\n", | |
"# Patch layers to define \"proper\" `creation_info`\n", | |
"def _patch_layers(df):\n", | |
" for k, layer in df.dask.copy().layers.items():\n", | |
" if hasattr(layer, \"creation_info\"):\n", | |
" # This is the read_parquet layer.\n", | |
" # Need to move the `path` arg to be in \"kwargs\"\n", | |
" path = layer.creation_info[\"args\"][0]\n", | |
" kwargs = {\"path\": path}\n", | |
" kwargs.update(layer.creation_info[\"kwargs\"])\n", | |
" layer.creation_info[\"kwargs\"] = kwargs\n", | |
" else:\n", | |
" # This is not the IO layer.\n", | |
" # Need to define `creation_info`\n", | |
" func = layer.dsk[k][0]\n", | |
" if func == dask.utils.apply:\n", | |
" func = layer.dsk[k][1]\n", | |
" if (\n", | |
" isinstance(func, dask.utils.methodcaller) and\n", | |
" func.func == \"fillna\"\n", | |
" ):\n", | |
" # `methodcaller` makes \"patching\" tricky.\n", | |
" # In practice, we would certainly want `creation_info`\n", | |
" # defined by the dd.Series.fillna API itself.\n", | |
" func = dd.Series.fillna\n", | |
" layer.creation_info = {\"func\": func}\n", | |
" df.dask.layers[k] = layer\n", | |
"_patch_layers(ddf)\n", | |
"ddf.dask.visualize()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "d87874a4-da55-491c-9db5-57c81fdd70c9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Extract full graph and target layer for regeneration\n", | |
"dsk = ddf.dask\n", | |
"layer = dsk.layers[ddf._name]\n", | |
"\n", | |
"# Regenerate the collection\n", | |
"ddf2 = layer._regenerate_collection(dsk)\n", | |
"dd.utils.assert_eq(ddf, ddf2)" | |
] | |
}, | |
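{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "regenerate-with-new-kwargs-sketch", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Untested sketch: regenerate the same collection, but override the\n", | |
"# `filters` kwarg of the IO layer via `new_kwargs`. Layer names are\n", | |
"# hashed, so we look the read-parquet layer name up rather than\n", | |
"# hard-coding it (the `startswith` check is an assumption about the\n", | |
"# layer-name prefix).\n", | |
"io_name = next(k for k in dsk.layers if k.startswith(\"read-parquet\"))\n", | |
"ddf3 = layer._regenerate_collection(\n", | |
"    dsk,\n", | |
"    new_kwargs={io_name: {\"filters\": [(\"b\", \">\", 5)]}},\n", | |
")" | |
] | |
}, | |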
{ | |
"cell_type": "markdown", | |
"id": "3de92e4e-d179-4be6-a9fc-0d0c36e41f4e", | |
"metadata": {}, | |
"source": [ | |
"## Proposal 2: Expand Dispatching to Enable Filter Extraction\n", | |
"\n", | |
"In order to enable predicate pushdown for general filter operations, I propose that we expand the current `Dispatch` machinery to support the dispatching of functions based on a specific `callable`-object argument (e.g. `operator.gt`). By adding something like an optional `on_literal=True` argument to `Dispatch`, we make it possible to register DNF-filter expressions for functions typically used for filtering operations. With such machinery in place, filter extraction can be accomplished by adding a recursive `_dnf_filter_expression` method to `Layer` and `Blockwise`:\n", | |
"\n", | |
"- `_dnf_filter_expression`\n", | |
" - **description**:\n", | |
" - Return a disjunctive normal form (DNF)-formatted filter expression for the graph terminating at this layer\n", | |
" - **Inputs**:\n", | |
" - `dsk` (`HighLevelGraph`): The full graph (layers and dependencies) for the original collection.\n", | |
" - **Default behavior**:\n", | |
" - `raise TypeError`" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "995a0577-a80c-4f15-989c-8a8b5341045a", | |
"metadata": {}, | |
"source": [ | |
"#### Sample Implementation\n", | |
"\n", | |
"##### `Dispatch` Changes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "67b25b99-e88f-48bd-bba1-2868eb88a916", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Dispatch:\n", | |
" \"\"\"Simple single dispatch.\"\"\"\n", | |
"\n", | |
" def __init__(self, name=None, on_literal=False):\n", | |
" self._lookup = {}\n", | |
" self._lazy = {}\n", | |
" if name:\n", | |
" self.__name__ = name\n", | |
" self.on_literal = on_literal\n", | |
"\n", | |
" def register(self, ref, func=None):\n", | |
" \"\"\"Register dispatch of `func` on `ref` (a type, or a literal value when `on_literal=True`)\"\"\"\n", | |
"\n", | |
" def wrapper(func):\n", | |
" if isinstance(ref, tuple):\n", | |
" for t in ref:\n", | |
" self.register(t, func)\n", | |
" else:\n", | |
" self._lookup[ref] = func\n", | |
" return func\n", | |
"\n", | |
" return wrapper(func) if func is not None else wrapper\n", | |
"\n", | |
" def register_lazy(self, toplevel, func=None):\n", | |
" \"\"\"\n", | |
" Register a registration function which will be called if the\n", | |
" *toplevel* module (e.g. 'pandas') is ever loaded.\n", | |
" \"\"\"\n", | |
"\n", | |
" def wrapper(func):\n", | |
" self._lazy[toplevel] = func\n", | |
" return func\n", | |
"\n", | |
" return wrapper(func) if func is not None else wrapper\n", | |
"\n", | |
" def dispatch(self, ref):\n", | |
" \"\"\"Return the function implementation for the given ``ref``\"\"\"\n", | |
" lk = self._lookup\n", | |
" if self.on_literal:\n", | |
" try:\n", | |
" impl = lk[ref]\n", | |
" except KeyError:\n", | |
" pass\n", | |
" else:\n", | |
" return impl\n", | |
" raise ValueError(f\"No dispatch registered for {ref}\")\n", | |
" else:\n", | |
" cls = ref\n", | |
" for cls2 in cls.__mro__:\n", | |
" try:\n", | |
" impl = lk[cls2]\n", | |
" except KeyError:\n", | |
" pass\n", | |
" else:\n", | |
" if cls is not cls2:\n", | |
" # Cache lookup\n", | |
" lk[cls] = impl\n", | |
" return impl\n", | |
" # Is a lazy registration function present?\n", | |
" toplevel, _, _ = cls2.__module__.partition(\".\")\n", | |
" try:\n", | |
" register = self._lazy.pop(toplevel)\n", | |
" except KeyError:\n", | |
" pass\n", | |
" else:\n", | |
" register()\n", | |
" return self.dispatch(cls) # recurse\n", | |
" raise TypeError(f\"No dispatch for {ref}\")\n", | |
"\n", | |
" def __call__(self, arg, *args, **kwargs):\n", | |
" \"\"\"\n", | |
" Call the corresponding method based on type of argument.\n", | |
" \"\"\"\n", | |
" if self.on_literal:\n", | |
" meth = self.dispatch(arg)\n", | |
" else:\n", | |
" meth = self.dispatch(type(arg))\n", | |
" return meth(arg, *args, **kwargs)" | |
] | |
}, | |
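{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "on-literal-dispatch-sketch", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Minimal sketch: with `on_literal=True`, handlers are registered and\n", | |
"# dispatched on the literal callable itself (e.g. `operator.gt`)\n", | |
"# rather than on the type of the argument. (`demo_dispatch` and\n", | |
"# `_gt_symbol` are hypothetical names, for illustration only.)\n", | |
"demo_dispatch = Dispatch(\"demo_dispatch\", on_literal=True)\n", | |
"\n", | |
"@demo_dispatch.register(operator.gt)\n", | |
"def _gt_symbol(op):\n", | |
"    return \">\"\n", | |
"\n", | |
"demo_dispatch(operator.gt)  # -> \">\"" | |
] | |
}, | |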
{ | |
"cell_type": "markdown", | |
"id": "2a522953-dacc-45ce-b677-6de0852ad31d", | |
"metadata": {}, | |
"source": [ | |
"##### Relevant `Dispatch` Implementation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "875c497a-f7b2-4895-8229-9af4f8c3e46a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dnf_filter_dispatch = Dispatch(\"dnf_filter_dispatch\", on_literal=True)\n", | |
"\n", | |
"_comparison_symbols = {\n", | |
" operator.eq: \"==\",\n", | |
" operator.ne: \"!=\",\n", | |
" operator.lt: \"<\",\n", | |
" operator.le: \"<=\",\n", | |
" operator.gt: \">\",\n", | |
" operator.ge: \">=\",\n", | |
"}\n", | |
"\n", | |
"def _get_blockwise_input(input_index, indices, dsk):\n", | |
" key = indices[input_index][0]\n", | |
" if indices[input_index][1] is None:\n", | |
" return key\n", | |
" return dsk.layers[key]._dnf_filter_expression(dsk)\n", | |
"\n", | |
"@dnf_filter_dispatch.register(tuple(_comparison_symbols.keys()))\n", | |
"def comparison_dnf(op, indices: list, dsk: dask.highlevelgraph.HighLevelGraph):\n", | |
" left = _get_blockwise_input(0, indices, dsk)\n", | |
" right = _get_blockwise_input(1, indices, dsk)\n", | |
" return (left, _comparison_symbols[op], right)\n", | |
"\n", | |
"@dnf_filter_dispatch.register((operator.and_, operator.or_))\n", | |
"def logical_dnf(op, indices: list, dsk: dask.highlevelgraph.HighLevelGraph):\n", | |
" left = _get_blockwise_input(0, indices, dsk)\n", | |
" right = _get_blockwise_input(1, indices, dsk)\n", | |
" if op == operator.or_:\n", | |
" return [left], [right]\n", | |
" elif op == operator.and_:\n", | |
" return (left, right)\n", | |
" else:\n", | |
" raise ValueError\n", | |
"\n", | |
"@dnf_filter_dispatch.register(operator.getitem)\n", | |
"def getitem_dnf(op, indices: list, dsk: dask.highlevelgraph.HighLevelGraph):\n", | |
" # Return dnf of key (selected by getitem)\n", | |
" key = _get_blockwise_input(1, indices, dsk)\n", | |
" return key\n", | |
"\n", | |
"@dnf_filter_dispatch.register(dd.Series.fillna)\n", | |
"def fillna_dnf(op, indices: list, dsk: dask.highlevelgraph.HighLevelGraph):\n", | |
" # Return dnf of input collection\n", | |
" return _get_blockwise_input(0, indices, dsk)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c421bbee-e6bd-47d8-9a30-54abc38ce6a0", | |
"metadata": {}, | |
"source": [ | |
"##### `_dnf_filter_expression` Implementations" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "98d0b3ac-2e4c-45a0-badd-21ad46cc4697", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def _layer_dnf_filter_expression(self, dsk: dask.highlevelgraph.HighLevelGraph):\n", | |
" \"\"\"Return a DNF-formatted filter expression for the\n", | |
" graph terminating at this layer\n", | |
" \"\"\"\n", | |
" # Default implementation raises TypeError\n", | |
" raise TypeError(f\"No DNF dispatching implemented for {self.__class__}\")\n", | |
"\n", | |
"dask.highlevelgraph.Layer._dnf_filter_expression = _layer_dnf_filter_expression" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "44b9f077-355a-403d-8f53-c6d6e2743522", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def _blockwise_dnf_filter_expression(self, dsk: dask.highlevelgraph.HighLevelGraph):\n", | |
" \"\"\"Return a DNF-formatted filter expression for the\n", | |
" graph terminating at this layer\n", | |
" \"\"\"\n", | |
" if self._regenerable:\n", | |
" return dnf_filter_dispatch(\n", | |
" self.creation_info[\"func\"],\n", | |
" self.indices,\n", | |
" dsk,\n", | |
" )\n", | |
" raise ValueError(\"DNF dispatching requires `_regenerable==True`\")\n", | |
"\n", | |
"dask.blockwise.Blockwise._dnf_filter_expression = _blockwise_dnf_filter_expression" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "d1c0372a-1a03-4afc-9b74-e292219aac55", | |
"metadata": {}, | |
"source": [ | |
"### Example Workflow" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "9919e34c-7358-4ca8-8f38-2b8ad1b5d054", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[('b', '>', 5), ('b', '<', 10)]" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Extract full graph and target layer for the filter\n", | |
"dsk = ddf.dask\n", | |
"layer = dsk.layers[ddf._name]\n", | |
"\n", | |
"# Extract the DNF filter expression\n", | |
"filters = layer._dnf_filter_expression(dsk)\n", | |
"list(filters)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8e6f2e75-eb14-4817-a4cd-6f6d895313ab", | |
"metadata": {}, | |
"source": [ | |
"## Proposal 3: Implement Eager Predicate Pushdown" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "2e9f79cd-7df4-48d4-a6f0-328d29f96835", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def predicate_pushdown(ddf):\n", | |
" \n", | |
" # Get output layer name and HLG\n", | |
" name = ddf._name\n", | |
" dsk = ddf.dask\n", | |
"\n", | |
" # Extract filters\n", | |
" try:\n", | |
" filters = list(dsk.layers[name]._dnf_filter_expression(dsk))\n", | |
" except (TypeError, ValueError):\n", | |
" # DNF dispatching failed for 1+ layers\n", | |
" return ddf\n", | |
" \n", | |
" # We were able to extract a DNF filter expression.\n", | |
" # Check that all layers are regenerable, and that\n", | |
" # the graph contains an IO layer with filters support.\n", | |
" # All layers besides the root IO layer should also\n", | |
" # support DNF dispatching. Otherwise, there could be\n", | |
" # something like column-assignment or data manipulation\n", | |
" # between the IO layer and the filter.\n", | |
" io_layer = []\n", | |
" for k, v in dsk.layers.items():\n", | |
" if not v._regenerable:\n", | |
" return ddf\n", | |
" if (\n", | |
" # Real Logic should check: isinstance(v, DataFrameIOLayer)\n", | |
" v.creation_info[\"func\"] == dd.read_parquet and\n", | |
" \"filters\" in v.creation_info.get(\"kwargs\", {}) and\n", | |
" v.creation_info[\"kwargs\"][\"filters\"] is None\n", | |
" ):\n", | |
" io_layer.append(k)\n", | |
" else:\n", | |
" try:\n", | |
" dnf_filter_dispatch.dispatch(v.creation_info[\"func\"])\n", | |
" except (TypeError, ValueError):\n", | |
" # This is NOT an IO layer OR a filter-safe layer\n", | |
" return ddf\n", | |
" if len(io_layer) != 1:\n", | |
" return ddf\n", | |
" io_layer = io_layer.pop()\n", | |
" \n", | |
" # Regenerate collection with filtered IO layer\n", | |
" return dsk.layers[name]._regenerate_collection(\n", | |
" dsk,\n", | |
" new_kwargs={io_layer: {\"filters\": filters}},\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "90507895-780f-475a-8203-c1fda80c09b7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>a</th>\n", | |
" <th>b</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>1</td>\n", | |
" <td>6</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>2</td>\n", | |
" <td>7</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>3</td>\n", | |
" <td>8</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>1</td>\n", | |
" <td>9</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" a b\n", | |
"6 1 6\n", | |
"7 2 7\n", | |
"8 3 8\n", | |
"9 1 9" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Regenerate the collection\n", | |
"new = predicate_pushdown(ddf)\n", | |
"new.compute()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "f7ce796f-8b6c-443d-9d4e-a76268ff98fc", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Original filters: None\n" | |
] | |
} | |
], | |
"source": [ | |
"# Check that the old `ddf` DataFrame has `filters=None`\n", | |
"print(\n", | |
" \"Original filters:\",\n", | |
" hlg_layer(ddf.dask, \"read-parquet\").creation_info[\"kwargs\"][\"filters\"],\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "9c7bc6ed-ffd1-46f3-a27d-4cb25a726904", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"New filters: [('b', '>', 5), ('b', '<', 10)]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Check that the `new` DataFrame has the expected `filters`\n", | |
"print(\n", | |
" \"New filters:\",\n", | |
" hlg_layer(new.dask, \"read-parquet\").creation_info[\"kwargs\"][\"filters\"],\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "11942dea-04df-4583-a46d-97fe9cdf0d46", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |