It's common to want to take something Dask already does and convert it to use high level graphs under the hood.
First you need to find the place in the code where the dask task dictionary is created.
Typically this looks like a variable called dsk or dsk_out that is a dictionary mapping the key names to individual tasks.
Found it? Great, this is the spot where we're going to insert an instance of your (new, not yet created) high level graph class, e.g. dsk_out = MyNewLayer(input_args, ...)
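For orientation, a hand-built task dictionary of the kind described above might look like the following minimal sketch. The function inc, the key name "x", and the value of npartitions are made up for this example and are not taken from Dask's code.

```python
def inc(i):
    return i + 1


npartitions = 3  # hypothetical number of output chunks

# Keys are (name, index) tuples; values are tasks written as
# (callable, *args) tuples, which is Dask's classic task convention.
dsk = {("x", i): (inc, i) for i in range(npartitions)}

print(dsk)
```

A comprehension like this one is the sort of code that the new layer class is meant to replace.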
Now you need to make the new high level graph layer class. It should inherit from the abstract Layer class (e.g. class MyNewLayer(Layer):). Currently, the code for high level graph layer classes lives in the file dask/layers.py.
The most important method is _construct_graph. This is the method that generates the task dictionary. As a start, you can copy across the logic from the spot you found in step 1. Later you'll need to make changes to actually get the benefits of using high level graphs, but for now you just need the code to run.
To get the code to run, you'll need to add a lot more boilerplate code. This includes the methods is_materialized, get_output_keys and __len__. Right now we're not going to worry about whether they materialize the task graph or not; we can tackle that problem in step 4.
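A first, deliberately naive version of such a class could look like the sketch below. To keep it runnable without Dask installed, it subclasses collections.abc.Mapping as a stand-in; in Dask itself you would inherit from Layer instead. The constructor arguments name and npartitions and the function inc are hypothetical.

```python
from collections.abc import Mapping


def inc(i):
    return i + 1


class MyNewLayer(Mapping):  # stand-in base; in Dask, inherit from Layer
    def __init__(self, name, npartitions):
        self.name = name
        self.npartitions = npartitions

    def _construct_graph(self):
        # Logic copied from the spot where the task dictionary was built.
        return {(self.name, i): (inc, i) for i in range(self.npartitions)}

    # Boilerplate: for now each method simply builds the full graph.
    def __getitem__(self, key):
        return self._construct_graph()[key]

    def __iter__(self):
        return iter(self._construct_graph())

    def __len__(self):
        return len(self._construct_graph())

    def is_materialized(self):
        # This naive version never stores the graph, so it reports False.
        return False

    def get_output_keys(self):
        return set(self._construct_graph())


dsk_out = MyNewLayer("x", 3)
print(dict(dsk_out))
```

Every method here rebuilds the whole task dictionary, which is wasteful but enough to get the code running.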
Note: It is encouraged to cache the result from _construct_graph, so that it doesn't recompute things unnecessarily. This can be done by storing the result in a private attribute (e.g. self._dict) that you have initialized as None in the __init__ method. Then, at the start of the _construct_graph method, check whether a result is already stored there and return it if so; otherwise generate the task dictionary from scratch.
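The caching pattern from the note can be sketched in isolation as follows. The class is stripped down to the parts relevant to caching, and the build_count attribute is a hypothetical counter added only to show that the dictionary is built once.

```python
class MyNewLayer:
    """Stripped-down sketch showing only the caching pattern."""

    def __init__(self, name, npartitions):
        self.name = name
        self.npartitions = npartitions
        self._dict = None  # filled in the first time the graph is built
        self.build_count = 0  # demonstration only

    def _construct_graph(self):
        # Return the stored result if the graph was already generated.
        if self._dict is not None:
            return self._dict
        self.build_count += 1
        self._dict = {
            (self.name, i): (sum, [i, 1]) for i in range(self.npartitions)
        }
        return self._dict


layer = MyNewLayer("x", 3)
first = layer._construct_graph()
second = layer._construct_graph()
print(first is second, layer.build_count)  # True 1
```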
Run the Dask tests with pytest to check you haven't broken things. Things are almost definitely broken. Fix the broken things.
To get the benefits of using high level graphs, you have to avoid materializing the task graph until the very last stage, when .compute() is called.
That means methods like get_output_keys, __len__, etc. must not materialize the task graph. This often requires rethinking how to implement the logic so that most of the work is still delayed.
During development, it can be handy to raise a RuntimeError at the beginning of the _construct_graph method. That way you will be warned if you accidentally materialize the task graph.
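One way to combine these two ideas is sketched below: get_output_keys and __len__ are derived from the constructor arguments alone, and a development-only guard raises as soon as anything asks for the real task dictionary. The debug flag, the argument names, and the Mapping base class (a stand-in for Layer so the sketch runs without Dask) are all assumptions made for this example.

```python
from collections.abc import Mapping


def inc(i):
    return i + 1


class MyNewLayer(Mapping):  # stand-in base; in Dask, inherit from Layer
    def __init__(self, name, npartitions, debug=False):
        self.name = name
        self.npartitions = npartitions
        self.debug = debug  # hypothetical development-only switch
        self._dict = None

    def _construct_graph(self):
        if self.debug:
            raise RuntimeError("task graph materialized unexpectedly")
        if self._dict is None:
            self._dict = {
                (self.name, i): (inc, i) for i in range(self.npartitions)
            }
        return self._dict

    def is_materialized(self):
        return self._dict is not None

    def get_output_keys(self):
        # Derived from the inputs alone; no task dictionary is built.
        return {(self.name, i) for i in range(self.npartitions)}

    def __len__(self):
        return self.npartitions

    def __iter__(self):
        return ((self.name, i) for i in range(self.npartitions))

    def __getitem__(self, key):
        # Looking up an actual task is the point where materializing is fine.
        return self._construct_graph()[key]


layer = MyNewLayer("x", 3, debug=True)
print(len(layer), layer.is_materialized())  # 3 False
try:
    layer[("x", 0)]
except RuntimeError as err:
    print("caught:", err)
```

With the guard switched on, any test that reaches _construct_graph too early fails loudly instead of silently building the graph.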
Implement the cull method. This is one of the most important parts.
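As a rough illustration, the sketch below assumes that cull receives the set of requested keys and returns a reduced layer together with a mapping from each kept key to the set of keys it depends on. Check the Layer base class in your Dask version for the exact signature; the parts argument, the optional all_hlg_keys parameter, and the Mapping stand-in base class are assumptions made for this example.

```python
from collections.abc import Mapping


def inc(i):
    return i + 1


class MyNewLayer(Mapping):  # stand-in base; in Dask, inherit from Layer
    def __init__(self, name, npartitions, parts=None):
        self.name = name
        self.npartitions = npartitions
        # Partition indices this layer should produce (all by default).
        self.parts = set(range(npartitions)) if parts is None else set(parts)

    def _construct_graph(self):
        return {(self.name, i): (inc, i) for i in sorted(self.parts)}

    def get_output_keys(self):
        return {(self.name, i) for i in self.parts}

    def cull(self, keys, all_hlg_keys=None):
        # Keep only the partitions that the requested keys refer to.
        output_keys = self.get_output_keys()
        wanted = {key[1] for key in keys if key in output_keys}
        culled = MyNewLayer(self.name, self.npartitions, parts=wanted)
        # These tasks read no other keys, so every dependency set is empty.
        deps = {(self.name, i): set() for i in wanted}
        return culled, deps

    def __getitem__(self, key):
        return self._construct_graph()[key]

    def __iter__(self):
        return iter(self._construct_graph())

    def __len__(self):
        return len(self.parts)


layer = MyNewLayer("x", 4)
culled, deps = layer.cull({("x", 1), ("x", 3), ("y", 0)})
print(sorted(culled))  # [('x', 1), ('x', 3)]
```

Note that the culled layer is created from the constructor arguments, so culling itself does not materialize the task graph.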
Serialization