Skip to content

Instantly share code, notes, and snippets.

@aminnj
Created February 20, 2018 00:00
Show Gist options
  • Save aminnj/9d185689bb835029cb4a46709ea77bea to your computer and use it in GitHub Desktop.
Save aminnj/9d185689bb835029cb4a46709ea77bea to your computer and use it in GitHub Desktop.
Less verbose keras network architecture drawing
import os
import pickle
import pydot
import keras
from keras.models import load_model
"""
brew install graphviz
pip install pydot
The keras format for drawing network architectures was too
verbose for me. I wanted to group dropouts and activations
with their parent layers, and make output shapes look nicer
"""
# Load the saved Keras model whose architecture will be drawn.
# The path is relative to the current working directory.
model = load_model("model_Feb18.h5")
def get_parents_names(layer):
    """Return the names of every layer that feeds into ``layer``.

    Walks each inbound node of ``layer`` and collects the ``name`` of each
    inbound layer, in node order. A layer that appears in several nodes is
    listed once per appearance.

    ``node.inbound_layers`` is treated as a sequence when it is a list or
    tuple and as a single layer otherwise, so a node that exposes one bare
    layer object does not break the iteration.
    """
    inbound_layer_names = []
    for node in layer._inbound_nodes:
        inbound = node.inbound_layers
        # Normalise a single layer object to a one-element list.
        if not isinstance(inbound, (list, tuple)):
            inbound = [inbound]
        inbound_layer_names.extend(parent.name for parent in inbound)
    return inbound_layer_names
# Copy the attributes we need off each keras layer object into a plain
# dict, producing one dict per layer in the order of model.layers.
layers = []
for keras_layer in model.layers:
    kind = keras_layer.__class__.__name__
    info = {
        "name": keras_layer.name,
        "type": kind,
        "input_shape": keras_layer.input_shape,
        "output_shape": keras_layer.output_shape,
        "parent_names": get_parents_names(keras_layer),
    }
    # Type-specific extras; a layer has exactly one class name, so at most
    # one of these branches applies.
    if kind == "Dropout":
        info["f_dropout"] = keras_layer.rate
    elif kind == "Conv2D":
        info["kernel_size"] = keras_layer.kernel_size
    elif kind == "MaxPooling2D":
        info["pool_size"] = keras_layer.pool_size
    layers.append(info)
# Index the per-layer dicts by layer name, and record on each one the names
# of its children, i.e. the layers that list it among their parents.
d_layers = {}
for record in layers:
    this_name = record["name"]
    d_layers[this_name] = record
    children = []
    for other in layers:
        if this_name in other["parent_names"]:
            children.append(other["name"])
    record["children_names"] = children
# Get rid of intermediate layers that we don't care about (Dropout and
# LeakyReLU) by attaching them to their parents, then relinking
# parents/children around them.
#
# The relinking is order-dependent: a layer must be visited after its
# parents so that, by the time it is folded, its parent list already points
# at layers that will actually be drawn. Iterate in the order of `layers`
# rather than over the dict's keys, so this does not rely on dict ordering.
# NOTE(review): assumes `layers` (built from model.layers) lists parents
# before their children -- confirm for non-sequential models.
for layer_name in [entry["name"] for entry in layers]:
    layer = d_layers[layer_name]
    layer_type = layer["type"]
    layer["skip"] = False
    # Only Dropout / activation layers get folded away.
    if layer_type not in ("Dropout", "LeakyReLU"):
        continue
    for pname in layer["parent_names"]:
        parent = d_layers[pname]
        # 1) stick this layer's information in the parent layer, keyed by
        #    this layer's type
        parent[layer_type] = layer
        # 2) remove this layer from the parent's children names, and add
        #    this layer's children to the parent's children list
        parent["children_names"].remove(layer_name)
        parent["children_names"].extend(layer["children_names"])
    for cname in layer["children_names"]:
        child = d_layers[cname]
        # 3) remove this layer from the child's parent names, and add this
        #    layer's parent names to the child's parent list
        child["parent_names"].remove(layer_name)
        child["parent_names"].extend(layer["parent_names"])
    # 4) flag this layer so the drawing step skips it
    layer["skip"] = True
# Render every layer that was not flagged as skipped as a record-shaped
# pydot node, connect it to its parents, and write the graph to a PNG.
dot = pydot.Dot()
dot.set('concentrate', False)
dot.set_node_defaults(shape='record')
for layer_name, layer in d_layers.items():
    if layer["skip"]:
        continue
    # Drop unknown (None) dimensions, e.g. the batch axis, from the shape.
    dims = [str(size) for size in layer["output_shape"] if size is not None]
    output_shape_str = ",".join(dims)
    label = "{} ({})".format(layer["type"], output_shape_str)
    # Append the annotations folded in from Dropout / LeakyReLU layers.
    if "Dropout" in layer:
        label += "\nDropout({})".format(layer["Dropout"]["f_dropout"])
    if "LeakyReLU" in layer:
        label += "\nLeakyReLU"
    node = pydot.Node(layer_name, label=label)
    for pname in layer["parent_names"]:
        edge = pydot.Edge(pname, layer_name)
        dot.add_edge(edge)
    dot.add_node(node)
dot.write("output.png", format="png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment