Skip to content

Instantly share code, notes, and snippets.

@jcrist
Created March 15, 2017 22:51
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jcrist/dc5b7cedfddff123f2177e5238e566e5 to your computer and use it in GitHub Desktop.
Save jcrist/dc5b7cedfddff123f2177e5238e566e5 to your computer and use it in GitHub Desktop.
Simple visualization of dask graph pipelines.
import os
import graphviz
from dask.optimize import key_split
from dask.dot import _get_display_cls
from dask.core import get_dependencies
def node_key(s):
    """Return the graphviz node name under which the graph key ``s`` is drawn.

    A non-empty tuple key is represented by its first element, so all tuple
    keys that share a first element end up on the same node. Any other key
    is represented by its own string form.

    The result is always a ``str``: the non-tuple branch already guaranteed
    that, and the tuple branch now coerces as well so that keys such as
    ``(1, 2)`` do not yield a non-string node name. An empty tuple has no
    first element and falls back to ``str(s)`` instead of raising
    ``IndexError``.
    """
    if isinstance(s, tuple) and s:
        return str(s[0])
    return str(s)
def simple_vis(x, filename='simple', format=None, **kwargs):
    """Render a condensed, left-to-right picture of a dask graph.

    Every key of the graph is mapped to a node name with ``node_key``; keys
    that map to the same name share a single rectangle, and at most one edge
    is drawn between any two rectangles.

    Parameters
    ----------
    x
        Either an object exposing a ``dask`` attribute (its graph is
        optimized first) or a graph mapping that is used as-is.
    filename
        Output path. If it ends in one of the recognised image suffixes and
        ``format`` is not given, the suffix is stripped and used as the
        format. A falsy value (``None`` or ``''``) writes no file and the
        rendered bytes are handed to the display object instead.
    format
        Graphviz output format; defaults to ``'png'`` when it cannot be
        taken from ``filename``.
    **kwargs
        Accepted for call compatibility; not used by this function.

    Returns
    -------
    The object built by the class that ``_get_display_cls(format)`` returns,
    constructed from the written file or from the raw rendered data.

    Raises
    ------
    RuntimeError
        If graphviz returns no data for the requested format.
    """
    if hasattr(x, 'dask'):
        if hasattr(x, '__dask_optimize__'):
            # Newer collection protocol.
            dsk = x.__dask_optimize__(x.dask, x.__dask_keys__())
        else:
            # Legacy private API used by older dask collections.
            dsk = x._optimize(x.dask, x._keys())
    else:
        dsk = x

    deps = {k: get_dependencies(dsk, k) for k in dsk}

    g = graphviz.Digraph(graph_attr={'rankdir': 'LR'})

    nodes = set()
    edges = set()
    for k in dsk:
        key = node_key(k)
        if key not in nodes:
            g.node(key, label=key_split(k), shape='rectangle')
            nodes.add(key)
        for dep in deps[k]:
            dep_key = node_key(dep)
            if dep_key not in nodes:
                g.node(dep_key, label=key_split(dep), shape='rectangle')
                nodes.add(dep_key)
            # Skip self-loops (both keys collapsed onto one node) and edges
            # that were already drawn between this pair of nodes.
            if dep_key != key and (dep_key, key) not in edges:
                g.edge(dep_key, key)
                edges.add((dep_key, key))

    fmts = ('.png', '.pdf', '.dot', '.svg', '.jpeg', '.jpg')
    # Only inspect the suffix when there is a filename at all: filename=None
    # is a supported way to ask for inline display and must not crash here.
    if format is None and filename:
        root, ext = os.path.splitext(filename)
        if ext.lower() in fmts:
            filename, format = root, ext[1:].lower()

    if format is None:
        format = 'png'

    data = g.pipe(format=format)
    if not data:
        raise RuntimeError("Graphviz failed to properly produce an image. "
                           "This probably means your installation of graphviz "
                           "is missing png support. See: "
                           "https://github.com/ContinuumIO/anaconda-issues/"
                           "issues/485 for more information.")

    display_cls = _get_display_cls(format)

    if not filename:
        return display_cls(data=data)

    full_filename = '.'.join([filename, format])
    with open(full_filename, 'wb') as f:
        f.write(data)

    return display_cls(filename=full_filename)
@jcrist
Copy link
Author

jcrist commented Mar 15, 2017

Example:

In [1]: from vis import simple_vis

In [2]: import numpy as np

In [3]: import dask.array as da

In [4]: x = np.arange(100).reshape((10, 10))

In [5]: dx = da.from_array(x, chunks=(5, 5))

In [6]: res = dx.dot(dx.T).sum(axis=1).mean() + dx.mean(axis=0) * 4

In [7]: simple_vis(res)
Out[7]: <IPython.core.display.Image object>

In [8]: res.visualize()
Out[8]: <IPython.core.display.Image object>

The full graph looks like:

mydask

While the simplified graph looks like:
simple

@sfrodrigues
Copy link

sfrodrigues commented May 8, 2018

Updated version:

import os
import graphviz

from dask.optimization import key_split
from dask.dot import _get_display_cls
from dask.core import get_dependencies


class SimpleComputationGraph:
    """Draw a condensed, left-to-right picture of a dask graph with graphviz.

    The class holds no state; it only groups the drawing routine with its
    helpers. It is constructed without arguments.
    """

    # Filename suffixes that are interpreted as an output format when no
    # explicit ``format`` is passed.
    _KNOWN_SUFFIXES = ('.png', '.pdf', '.dot', '.svg', '.jpeg', '.jpg')

    @staticmethod
    def _node_key(s):
        """Return the graphviz node name under which graph key ``s`` is drawn.

        A non-empty tuple key is represented by its first element; anything
        else by its own string form. The result is always a ``str`` so it
        can be used as a node name even for keys like ``(1, 2)``, and an
        empty tuple falls back to ``str(s)`` instead of raising.
        """
        if isinstance(s, tuple) and s:
            return str(s[0])
        return str(s)

    @staticmethod
    def _resolve_output(filename, format):
        """Work out the base filename and graphviz format to use.

        When ``format`` is ``None`` and ``filename`` ends in a recognised
        suffix, that suffix is removed from the filename and becomes the
        format (lower-cased). Otherwise the filename is returned unchanged
        and the format defaults to ``'png'``. A falsy ``filename`` (``None``
        or ``''``) is passed through untouched rather than being inspected.
        """
        if format is None and filename:
            root, ext = os.path.splitext(filename)
            if ext.lower() in SimpleComputationGraph._KNOWN_SUFFIXES:
                filename, format = root, ext[1:].lower()
        if format is None:
            format = 'png'
        return filename, format

    def simple_graph(self,
                     x,
                     filename='simple_computation_graph',
                     format=None):
        """Render ``x`` as a simplified graph and return a display object.

        Parameters
        ----------
        x
            Either an object exposing a ``dask`` attribute (its graph is
            optimized through ``__dask_optimize__`` first) or a graph
            mapping that is used as-is.
        filename
            Output path. A recognised image suffix is used as the format
            when ``format`` is not given. A falsy value writes no file and
            the rendered bytes are handed to the display object instead.
        format
            Graphviz output format; defaults to ``'png'``.

        Returns
        -------
        The object built by the class that ``_get_display_cls(format)``
        returns, constructed from the written file or from the raw data.

        Raises
        ------
        RuntimeError
            If graphviz returns no data for the requested format.
        """
        if hasattr(x, 'dask'):
            dsk = x.__dask_optimize__(x.dask, x.__dask_keys__())
        else:
            dsk = x

        deps = {k: get_dependencies(dsk, k) for k in dsk}

        g = graphviz.Digraph(graph_attr={'rankdir': 'LR'})

        nodes = set()
        edges = set()
        for k in dsk:
            key = self._node_key(k)
            if key not in nodes:
                g.node(key, label=key_split(k), shape='rectangle')
                nodes.add(key)
            for dep in deps[k]:
                dep_key = self._node_key(dep)
                if dep_key not in nodes:
                    g.node(dep_key, label=key_split(dep), shape='rectangle')
                    nodes.add(dep_key)
                # Skip self-loops (both keys collapsed onto one node) and
                # edges already drawn between this pair of nodes.
                if dep_key != key and (dep_key, key) not in edges:
                    g.edge(dep_key, key)
                    edges.add((dep_key, key))

        filename, format = self._resolve_output(filename, format)

        data = g.pipe(format=format)
        if not data:
            raise RuntimeError("Graphviz failed to properly produce an image. "
                               "This probably means your installation of graphviz "
                               "is missing png support. See: "
                               "https://github.com/ContinuumIO/anaconda-issues/"
                               "issues/485 for more information.")

        display_cls = _get_display_cls(format)

        if not filename:
            return display_cls(data=data)

        full_filename = '.'.join([filename, format])
        with open(full_filename, 'wb') as f:
            f.write(data)

        return display_cls(filename=full_filename)

@LucHermitte
Copy link

Hi.
Thanks @jcrist and @sfrodrigues for these gists. What licence(s) would you use for them?

@jcrist
Copy link
Author

jcrist commented Jul 22, 2020

🤷 BSD 3-clause I guess. Note that if you're using a dask collection (e.g. dask.array or dask.dataframe) you can do obj.dask.visualize() to get a high-level view of a dask collection's graph.

@LucHermitte
Copy link

Alas, I'm working directly with tasks, and dask.visualize() doesn't produce something as neat as your code does.
Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment