@GenevieveBuckley
Created December 10, 2021 09:28
Steps to implement a new high level graph in Dask

So, you want to make a new high level graph layer class?

Step 1:

A common starting point is something Dask already does that you want to convert to use a high level graph under the hood.

First, find the place in the code where the Dask task dictionary is created. Typically this is a variable called dsk or dsk_out: a dictionary mapping key names to individual tasks.

Found it? Great. This is the spot where you will insert an instance of your new (not yet written) high level graph layer class, e.g. dsk_out = MyNewLayer(input_args, ...)
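For illustration, the kind of code Step 1 is looking for might resemble the sketch below, assuming an element-wise operation with one task per partition. The function names (inc, build_graph) and the key names are hypothetical; the task format is Dask's classic tuple spec, where a task is a callable followed by its arguments and an argument can itself be a key.

```python
def inc(x):
    return x + 1


def build_graph(name, input_name, npartitions):
    # A plain dict mapping each output key to a task.  In Dask's classic
    # task spec a task is a tuple: a callable followed by its arguments,
    # where an argument can be another key such as (input_name, i).
    dsk_out = {}
    for i in range(npartitions):
        dsk_out[(name, i)] = (inc, (input_name, i))
    return dsk_out


dsk_out = build_graph("inc-abc123", "x-def456", 3)
# Step 1 replaces the dict-building code above with something like
# dsk_out = MyNewLayer("inc-abc123", "x-def456", 3)  (hypothetical class)
```

The point to notice is that this dict has one entry per task, so its size grows with the number of partitions; that is exactly what a high level layer lets you avoid building early.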

Step 2:

Now you need to write the new high level graph layer class. It should inherit from the abstract Layer class (e.g. class MyNewLayer(Layer):), which is defined in dask/highlevelgraph.py. Currently, the code for the concrete high level graph layer classes lives in the file dask/layers.py.

The most important method is _construct_graph, which generates the task dictionary. To start, you can copy across the logic from the spot you found in Step 1. Later you'll need to change it to actually get the benefits of high level graphs, but for now you just need the code to run.

To get the code to run, you'll need to add a fair amount of boilerplate: the methods is_materialized, get_output_keys and __len__, plus __getitem__ and __iter__ (Layer is a Mapping, so these are required too). Don't worry yet about whether these methods materialize the task graph; we'll tackle that in Step 4.
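A minimal sketch of a first, unoptimized port, assuming the Step 1 code is an element-wise map like the one above. To stay runnable without Dask, it uses a stand-in Layer class that, like dask.highlevelgraph.Layer, is a collections.abc.Mapping; in real code you would import Layer from dask.highlevelgraph instead. MyNewLayer, inc and the key names are hypothetical.

```python
from collections.abc import Mapping


class Layer(Mapping):
    """Stand-in for dask.highlevelgraph.Layer so this sketch runs without Dask.

    The real class is also a Mapping; import it from dask.highlevelgraph
    in real code instead of defining this.
    """


def inc(x):
    return x + 1


class MyNewLayer(Layer):
    """Hypothetical first port: logic copied verbatim from the old dsk code."""

    def __init__(self, name, input_name, npartitions):
        super().__init__()
        self.name = name
        self.input_name = input_name
        self.npartitions = npartitions

    def _construct_graph(self):
        # Copied from the spot found in Step 1: builds the full task dict.
        return {
            (self.name, i): (inc, (self.input_name, i))
            for i in range(self.npartitions)
        }

    # Boilerplate needed just to get things running.  Most of these
    # build the whole task dict on every call; Step 4 deals with that.
    def is_materialized(self):
        # Nothing is stored on the instance yet, so report False.
        return False

    def get_output_keys(self):
        return set(self._construct_graph())

    def __getitem__(self, key):
        return self._construct_graph()[key]

    def __iter__(self):
        return iter(self._construct_graph())

    def __len__(self):
        return len(self._construct_graph())


layer = MyNewLayer("inc-abc123", "x-def456", 3)
```

This version is deliberately naive: correctness first, so the test suite in Step 3 has something to check.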

Note: It's a good idea to cache the result of _construct_graph so that it doesn't recompute things unnecessarily. You can store the result in a private instance attribute (e.g. self._dict) that you initialize to None in __init__. At the start of _construct_graph, check whether a result is already stored there; if so, return it, otherwise generate the task dictionary from scratch and store it.
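The caching pattern from the note might look like this, again using a hypothetical element-wise layer and a stand-in Mapping base in place of dask.highlevelgraph.Layer. The construct_calls counter is for illustration only, so you can see that the dict is built once.

```python
from collections.abc import Mapping


class Layer(Mapping):
    """Stand-in for dask.highlevelgraph.Layer (also a Mapping in Dask)."""


def inc(x):
    return x + 1


class MyNewLayer(Layer):
    def __init__(self, name, input_name, npartitions):
        super().__init__()
        self.name = name
        self.input_name = input_name
        self.npartitions = npartitions
        self._dict = None  # cache for the materialized task dict
        self.construct_calls = 0  # illustration only: counts real builds

    def _construct_graph(self):
        # Return the cached result if the graph has already been built.
        if self._dict is not None:
            return self._dict
        self.construct_calls += 1
        self._dict = {
            (self.name, i): (inc, (self.input_name, i))
            for i in range(self.npartitions)
        }
        return self._dict

    def is_materialized(self):
        # Materialized exactly when the cache has been filled.
        return self._dict is not None

    def get_output_keys(self):
        return set(self._construct_graph())

    def __getitem__(self, key):
        return self._construct_graph()[key]

    def __iter__(self):
        return iter(self._construct_graph())

    def __len__(self):
        return len(self._construct_graph())


layer = MyNewLayer("inc-abc123", "x-def456", 4)
```

A side benefit of tying is_materialized to the cache is that it now reports something meaningful, which becomes useful when checking Step 4.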

Step 3:

Run the Dask test suite with pytest to check that you haven't broken anything. Things are almost certainly broken. Fix the broken things.
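Besides the existing suite, one cheap extra check (a suggestion, not something Dask requires) is a test asserting that the new layer, once materialized, produces exactly the task dict the old code did. The sketch keeps the old dict-building function as a reference; all names are hypothetical, and the plain Mapping base stands in for dask.highlevelgraph.Layer.

```python
from collections.abc import Mapping


def inc(x):
    return x + 1


def build_graph(name, input_name, npartitions):
    # The original dict-building code from Step 1, kept as a reference.
    return {(name, i): (inc, (input_name, i)) for i in range(npartitions)}


class MyNewLayer(Mapping):  # stand-in: subclass dask's Layer in real code
    def __init__(self, name, input_name, npartitions):
        self.name = name
        self.input_name = input_name
        self.npartitions = npartitions

    def _construct_graph(self):
        # The ported logic, written independently of build_graph.
        dsk = {}
        for i in range(self.npartitions):
            dsk[(self.name, i)] = (inc, (self.input_name, i))
        return dsk

    def __getitem__(self, key):
        return self._construct_graph()[key]

    def __iter__(self):
        return iter(self._construct_graph())

    def __len__(self):
        return len(self._construct_graph())


def test_layer_matches_old_graph():
    # Materializing the new layer must give exactly the old task dict,
    # including the edge cases of zero and one partition.
    for npartitions in (0, 1, 5):
        expected = build_graph("inc-abc123", "x-def456", npartitions)
        layer = MyNewLayer("inc-abc123", "x-def456", npartitions)
        assert dict(layer) == expected
        assert len(layer) == npartitions
```

pytest collects test_layer_matches_old_graph automatically because of its test_ prefix.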

Step 4:

To get the benefits of using a high level graph, you have to avoid materializing the task graph until the very last stage, when .compute() is called.

That means methods like get_output_keys, __len__, etc. must not materialize the task graph. This often requires rethinking how to implement the logic so that most of the work is still deferred.
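For the hypothetical element-wise layer used above, the output keys and the length follow directly from the layer's parameters, so neither needs the task dict. The sketch below assumes one task per partition (that is what makes __len__ equal to npartitions); layers with a different structure need their own arithmetic. A plain Mapping again stands in for dask.highlevelgraph.Layer.

```python
from collections.abc import Mapping


def inc(x):
    return x + 1


class MyNewLayer(Mapping):  # stand-in: subclass dask's Layer in real code
    def __init__(self, name, input_name, npartitions):
        self.name = name
        self.input_name = input_name
        self.npartitions = npartitions
        self._dict = None

    def _construct_graph(self):
        if self._dict is None:
            self._dict = {
                (self.name, i): (inc, (self.input_name, i))
                for i in range(self.npartitions)
            }
        return self._dict

    def is_materialized(self):
        return self._dict is not None

    # These only use the layer's parameters, so they never build tasks.
    def get_output_keys(self):
        return {(self.name, i) for i in range(self.npartitions)}

    def __len__(self):
        # Valid because this layer has exactly one task per partition.
        return self.npartitions

    def __iter__(self):
        return ((self.name, i) for i in range(self.npartitions))

    # Looking up an actual task is the one place that needs the full dict.
    def __getitem__(self, key):
        return self._construct_graph()[key]


layer = MyNewLayer("inc-abc123", "x-def456", 1000)
```

Only __getitem__ still materializes here, and it does so lazily, on first access.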

During development, it can be handy to raise a RuntimeError at the beginning of the _construct_graph method - that way you will be warned if you accidentally materialize the task graph.
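A sketch of that development trick, with the guard behind a module-level switch so it is easy to turn off later. _FORBID_MATERIALIZE and the layer itself are hypothetical, and a plain Mapping stands in for dask.highlevelgraph.Layer.

```python
from collections.abc import Mapping

# Development-only switch (hypothetical name): flip to False, or delete the
# check, once you are confident the layer no longer materializes by accident.
_FORBID_MATERIALIZE = True


def inc(x):
    return x + 1


class MyNewLayer(Mapping):  # stand-in: subclass dask's Layer in real code
    def __init__(self, name, input_name, npartitions):
        self.name = name
        self.input_name = input_name
        self.npartitions = npartitions

    def _construct_graph(self):
        if _FORBID_MATERIALIZE:
            raise RuntimeError(f"{type(self).__name__} materialized its task graph")
        return {
            (self.name, i): (inc, (self.input_name, i))
            for i in range(self.npartitions)
        }

    def get_output_keys(self):
        return {(self.name, i) for i in range(self.npartitions)}

    def __len__(self):
        return self.npartitions

    def __iter__(self):
        return ((self.name, i) for i in range(self.npartitions))

    def __getitem__(self, key):
        return self._construct_graph()[key]


layer = MyNewLayer("inc-abc123", "x-def456", 3)
n_tasks = len(layer)              # fine: computed from parameters
out_keys = layer.get_output_keys()  # fine: computed from parameters

materialize_error = None
try:
    layer[("inc-abc123", 0)]      # needs an actual task, so the guard fires
except RuntimeError as err:
    materialize_error = err
```

Note that Mapping's default __contains__ is implemented with __getitem__, so even `key in layer` would trip this guard: exactly the kind of hidden materialization the trick is meant to expose.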

Step 5:

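Implement the cull method. This is one of the most important parts: culling drops the tasks that aren't needed to compute the requested output keys, and a high level layer should be able to work out what to keep without materializing its full task graph.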
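A minimal sketch of culling for the hypothetical element-wise layer. It mirrors the shape of Dask's Layer.cull, which (in the versions I know of) takes the requested keys plus all keys in the high level graph and returns a culled layer together with a mapping from each remaining key to its dependencies; check the Layer docstring in your Dask version before relying on that. The parts parameter is an assumption of this sketch, and a plain Mapping stands in for Layer.

```python
from collections.abc import Mapping


def inc(x):
    return x + 1


class MyNewLayer(Mapping):  # stand-in: subclass dask's Layer in real code
    def __init__(self, name, input_name, npartitions, parts=None):
        self.name = name
        self.input_name = input_name
        self.npartitions = npartitions
        # Output partitions this (possibly culled) layer still produces.
        self.parts = set(range(npartitions)) if parts is None else set(parts)

    def _construct_graph(self):
        return {
            (self.name, i): (inc, (self.input_name, i)) for i in sorted(self.parts)
        }

    def get_output_keys(self):
        return {(self.name, i) for i in self.parts}

    def __len__(self):
        return len(self.parts)

    def __iter__(self):
        return ((self.name, i) for i in sorted(self.parts))

    def __getitem__(self, key):
        return self._construct_graph()[key]

    def cull(self, keys, all_hlg_keys):
        # Work out which partitions were requested from the key names alone,
        # without building any tasks.  all_hlg_keys is unused in this layer.
        parts = {
            key[1]
            for key in keys
            if isinstance(key, tuple)
            and len(key) == 2
            and key[0] == self.name
            and key[1] in self.parts
        }
        culled = MyNewLayer(self.name, self.input_name, self.npartitions, parts)
        # Each kept output task depends on the matching input partition.
        deps = {(self.name, i): {(self.input_name, i)} for i in parts}
        return culled, deps


layer = MyNewLayer("inc-abc123", "x-def456", 10)
all_keys = set(layer.get_output_keys()) | {("x-def456", i) for i in range(10)}
culled, deps = layer.cull({("inc-abc123", 2), ("inc-abc123", 7)}, all_keys)
```

Because the culled layer only stores a set of partition indices, culling a 10-partition layer down to 2 never builds the other 8 tasks.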

Step 6:

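Serialization: make sure the layer can be serialized and sent to the distributed scheduler, ideally without materializing its task graph first.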
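How layers are shipped to the distributed scheduler has changed across Dask versions (releases from around the time of this gist used Layer methods named __dask_distributed_pack__ and __dask_distributed_unpack__), so check the version you target. A simple, version-independent sanity check is that the layer round-trips through pickle and does not drag a cached task dict along with it. The sketch below assumes the self._dict cache from Step 2 and uses a plain Mapping in place of dask.highlevelgraph.Layer; the names are hypothetical.

```python
import pickle
from collections.abc import Mapping


def inc(x):
    return x + 1


class MyNewLayer(Mapping):  # stand-in: subclass dask's Layer in real code
    def __init__(self, name, input_name, npartitions):
        self.name = name
        self.input_name = input_name
        self.npartitions = npartitions
        self._dict = None  # cache, filled on first materialization

    def _construct_graph(self):
        if self._dict is None:
            self._dict = {
                (self.name, i): (inc, (self.input_name, i))
                for i in range(self.npartitions)
            }
        return self._dict

    def is_materialized(self):
        return self._dict is not None

    def __len__(self):
        return self.npartitions

    def __iter__(self):
        return ((self.name, i) for i in range(self.npartitions))

    def __getitem__(self, key):
        return self._construct_graph()[key]

    def __getstate__(self):
        # Ship only the small parameters, never the (possibly huge) cached
        # task dict; the receiving side can rebuild it on demand.
        state = self.__dict__.copy()
        state["_dict"] = None
        return state


layer = MyNewLayer("inc-abc123", "x-def456", 100_000)
layer[("inc-abc123", 0)]  # materializes and fills the cache locally
payload = pickle.dumps(layer)
restored = pickle.loads(payload)
```

Dropping the cache in __getstate__ is a design choice: the payload stays proportional to the layer's parameters rather than to the number of tasks.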
