Skip to content

Instantly share code, notes, and snippets.

@ebenolson
Created March 27, 2015 23:01
Show Gist options
  • Save ebenolson/1682625dc9823e27d771 to your computer and use it in GitHub Desktop.
Save ebenolson/1682625dc9823e27d771 to your computer and use it in GitHub Desktop.
Functions to draw Lasagne networks with graphviz, like Caffe's draw_net.py
"""
Functions to create network diagrams from a list of Layers.
Examples:
Draw a minimal diagram to a pdf file:
layers = lasagne.layers.get_all_layers(output_layer)
draw_to_file(layers, 'network.pdf', output_shape=False)
Draw a verbose diagram in an IPython notebook:
from IPython.display import Image #needed to render in notebook
layers = lasagne.layers.get_all_layers(output_layer)
dot = get_pydot_graph(layers, verbose=True)
return Image(dot.create_png())
"""
import os
import zlib

import pydot
def get_hex_color(layer_type):
    """
    Determines the hex color for a layer.

    A few layer families get fixed colors. They are matched by substring
    and checked in the order Input, Conv, Dense, Pool. Every other class
    name is mapped to a color derived from a CRC-32 checksum of the name,
    so the same name always gets the same color, in every process.

    :parameters:
        - layer_type : string
            Class name of the layer

    :returns:
        - color : string containing a 7-character hex color ('#RRGGBB').

    :usage:
        >>> get_hex_color('MaxPool2DDNN')
        '#9D9DD2'
    """
    if 'Input' in layer_type:
        return '#A2CECE'
    if 'Conv' in layer_type:
        return '#7C9ABB'
    if 'Dense' in layer_type:
        return '#6CCF8D'
    if 'Pool' in layer_type:
        return '#9D9DD2'
    # Built-in hash() of a str is salted per process on Python 3
    # (PYTHONHASHSEED), so colors would change between runs. Use a stable
    # checksum instead. Keep 24 bits and zero-pad to six hex digits so small
    # values still form a valid '#RRGGBB' color.
    checksum = zlib.crc32(layer_type.encode('utf-8')) & 0xFFFFFF
    return '#{0:06X}'.format(checksum)
def _layer_output_shape(layer):
    """
    Returns the output shape of a Lasagne layer.

    Newer Lasagne versions replace the ``get_output_shape()`` method with an
    ``output_shape`` attribute. Prefer the attribute and fall back to the
    method for older versions.
    """
    if hasattr(layer, 'output_shape'):
        return layer.output_shape
    return layer.get_output_shape()


def _layer_label(layer, output_shape, verbose):
    """Builds the newline-separated node label text for one layer."""
    lines = [layer.__class__.__name__]
    if verbose:
        for attr in ('num_filters', 'num_units', 'ds',
                     'filter_shape', 'stride', 'strides', 'p'):
            if hasattr(layer, attr):
                lines.append('{0}: {1}'.format(attr, getattr(layer, attr)))
        if hasattr(layer, 'nonlinearity'):
            # Plain functions have __name__; callable objects fall back to
            # their class name.
            try:
                nonlinearity = layer.nonlinearity.__name__
            except AttributeError:
                nonlinearity = layer.nonlinearity.__class__.__name__
            lines.append('nonlinearity: {0}'.format(nonlinearity))
    if output_shape:
        lines.append('Output shape: {0}'.format(_layer_output_shape(layer)))
    return '\n'.join(lines)


def _input_layers(layer):
    """Returns the non-None layers feeding into ``layer``."""
    # Copy so the layer's own input_layers list is never mutated.
    inputs = list(getattr(layer, 'input_layers', None) or [])
    if hasattr(layer, 'input_layer'):
        inputs.append(layer.input_layer)
    # NOTE(review): Lasagne may store None for an input given as a shape
    # tuple instead of a layer; such inputs have no node to connect to.
    return [inp for inp in inputs if inp is not None]


def get_pydot_graph(layers, output_shape=True, verbose=False):
    """
    Creates a PyDot graph of the network defined by the given layers.

    :parameters:
        - layers : list
            List of the layers, as obtained from
            lasagne.layers.get_all_layers
        - output_shape: (default `True`)
            If `True`, the output shape of each layer will be displayed.
        - verbose: (default `False`)
            If `True`, layer attributes like filter shape, stride, etc.
            will be displayed.

    :returns:
        - pydot_graph : PyDot object containing the graph

    Nodes are added in the order of ``layers``; a layer listed twice is
    drawn once. Inputs that are ``None`` or not contained in ``layers``
    get no edge, so a sub-list of a network can be drawn without errors.
    """
    pydot_graph = pydot.Dot('Network', graph_type='digraph')
    pydot_nodes = {}
    pydot_edges = []
    for layer in layers:
        key = repr(layer)
        if key in pydot_nodes:
            continue
        node = pydot.Node(key,
                          label=_layer_label(layer, output_shape, verbose),
                          shape='record',
                          fillcolor=get_hex_color(layer.__class__.__name__),
                          style='filled',
                          )
        pydot_nodes[key] = node
        # Adding in list order keeps the generated DOT source deterministic.
        pydot_graph.add_node(node)
        for input_layer in _input_layers(layer):
            pydot_edges.append((repr(input_layer), key))
    # Edges are resolved only after every node exists, because an input
    # may appear later in ``layers`` than the layer consuming it.
    for src, dst in pydot_edges:
        if src in pydot_nodes:
            pydot_graph.add_edge(
                pydot.Edge(pydot_nodes[src], pydot_nodes[dst]))
    return pydot_graph
def draw_to_file(layers, filename, **kwargs):
    """
    Draws a network diagram to a file.

    The output format is the lower-cased filename extension (e.g. ``pdf``,
    ``png``, ``svg``), passed as ``format`` to pydot's ``create``.

    :parameters:
        - layers : list
            List of the layers, as obtained from
            lasagne.layers.get_all_layers
        - filename: string
            The filename to save output to. Must end in an extension
            naming the output format.
        - **kwargs: see docstring of get_pydot_graph for other options

    :raises:
        - ValueError: if ``filename`` has no extension to take the format
          from.
    """
    # splitext ignores dots in directory names and handles missing
    # extensions, which slicing at the last '.' did not.
    fmt = os.path.splitext(filename)[1][1:].lower()
    if not fmt:
        raise ValueError(
            'cannot infer output format from filename {0!r}; add an '
            'extension such as .pdf or .png'.format(filename))
    dot = get_pydot_graph(layers, **kwargs)
    # Render before opening the file so a rendering failure does not leave
    # an empty file behind.
    data = dot.create(format=fmt)
    # Binary mode: the rendered output is raw bytes. Writing it through a
    # text-mode handle fails on Python 3 and can corrupt formats such as
    # PNG/PDF on Windows.
    with open(filename, 'wb') as fid:
        fid.write(data)
def draw_to_notebook(layers, **kwargs):
    """
    Renders a network diagram as a PNG image for display in an IPython
    notebook.

    :parameters:
        - layers : list
            List of the layers, as obtained from
            lasagne.layers.get_all_layers
        - **kwargs: forwarded to get_pydot_graph (output_shape, verbose)

    :returns:
        - an IPython.display.Image wrapping the PNG rendering; it is shown
          when it is the last expression of a cell.
    """
    # Imported inside the function, so IPython is only required when this
    # function is actually called.
    from IPython.display import Image
    graph = get_pydot_graph(layers, **kwargs)
    png_data = graph.create_png()
    return Image(png_data)
@NSavov
Copy link

NSavov commented Jul 18, 2018

Many thanks to the author! For newer versions of Lasagne, use the `output_shape` attribute instead of `get_output_shape()`.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment