Created
February 20, 2018 00:00
-
-
Save aminnj/9d185689bb835029cb4a46709ea77bea to your computer and use it in GitHub Desktop.
Less verbose keras network architecture drawing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import pickle | |
import pydot | |
import keras | |
from keras.models import load_model | |
"""
brew install graphviz
pip install pydot

The keras format for drawing network architectures was too
verbose for me. I wanted to group dropouts and activations
with their parent layers, and make output shapes look nicer.
"""
# Trained keras model whose architecture gets drawn.
MODEL_FILE = "model_Feb18.h5"
model = load_model(MODEL_FILE)
def get_parents_names(layer):
    """Return the names of the layers feeding into ``layer``.

    Walks every entry of the layer's ``_inbound_nodes`` and collects the
    ``name`` of each layer in the node's ``inbound_layers``. Names are kept
    in first-seen order and each name appears once. Without deduplication,
    a layer fed the same tensor twice (e.g. ``Add()([x, x])``) would get a
    duplicated parent and hence duplicated edges in the drawing.

    Args:
        layer: keras layer exposing the private ``_inbound_nodes`` list.

    Returns:
        list of str: parent layer names, without duplicates.
    """
    parent_names = []
    for node in layer._inbound_nodes:
        inbound = node.inbound_layers
        # NOTE(review): some keras / tf.keras versions expose a lone inbound
        # layer directly instead of a one-element list; wrap it so both
        # shapes are handled the same way.
        if not isinstance(inbound, (list, tuple)):
            inbound = [inbound]
        for inbound_layer in inbound:
            if inbound_layer.name not in parent_names:
                parent_names.append(inbound_layer.name)
    return parent_names
# Snapshot every keras layer into a plain dict holding what the drawing
# needs: name, class name, input/output shapes and parent layer names.
# A few layer classes also carry one class-specific attribute, described
# by this table as: class name -> (dict key, layer attribute to read).
_EXTRA_LAYER_ATTRS = {
    "Dropout": ("f_dropout", "rate"),
    "Conv2D": ("kernel_size", "kernel_size"),
    "MaxPooling2D": ("pool_size", "pool_size"),
}
layers = []
for keras_layer in model.layers:
    layer_class = keras_layer.__class__.__name__
    info = {
        "name": keras_layer.name,
        "type": layer_class,
        "input_shape": keras_layer.input_shape,
        "output_shape": keras_layer.output_shape,
        "parent_names": get_parents_names(keras_layer),
    }
    extra = _EXTRA_LAYER_ATTRS.get(layer_class)
    if extra is not None:
        key, attr = extra
        info[key] = getattr(keras_layer, attr)
    layers.append(info)
# Index the layer dicts by name, then give each one the list of layers that
# name it as a parent (its children, in `layers` order). The dicts are the
# same objects as in `layers`, so those entries gain "children_names" too.
d_layers = {entry["name"]: entry for entry in layers}
for entry in layers:
    entry["children_names"] = [
        other["name"]
        for other in layers
        if entry["name"] in other["parent_names"]
    ]
# Fold Dropout and LeakyReLU layers into the layer(s) feeding them, so the
# drawing shows one node per "real" layer. A folded layer's dict is stored
# on each parent under its type name. The graph is re-wired so parents point
# straight at the folded layer's children (and vice versa), and the folded
# layer is flagged skip=True. Every other layer gets skip=False.
#
# Assumes d_layers lists parents before children (it follows model.layers
# order; TODO confirm for the models in use). That order lets a chain such
# as Dense -> LeakyReLU -> Dropout collapse completely onto the Dense layer.
for layer_name, layer in d_layers.items():
    layer["skip"] = False
    layer_type = layer["type"]
    if layer_type not in ("Dropout", "LeakyReLU"):
        continue
    for pname in layer["parent_names"]:
        parent = d_layers[pname]
        # 1) attach this layer's info to the parent
        parent[layer_type] = layer
        # 2) replace this layer in the parent's children by this layer's
        #    children, skipping any the parent already has (e.g. a residual
        #    connection) so no edge is drawn twice
        parent_children = parent["children_names"]
        parent_children.remove(layer_name)
        for cname in layer["children_names"]:
            if cname not in parent_children:
                parent_children.append(cname)
    for cname in layer["children_names"]:
        # 3) replace this layer in each child's parents by this layer's
        #    parents, again without duplicating an existing parent
        child_parents = d_layers[cname]["parent_names"]
        child_parents.remove(layer_name)
        for pname in layer["parent_names"]:
            if pname not in child_parents:
                child_parents.append(pname)
    # 4) flag this layer so the drawing step leaves it out
    layer["skip"] = True
# Render the condensed graph with pydot and write it to output.png.
# Each layer not flagged "skip" becomes a record-shaped node, and each of
# its parent links becomes an edge. The node label is the layer class with
# its output shape (None dimensions omitted), followed by one line per
# folded-in Dropout (with its rate) or LeakyReLU.
dot = pydot.Dot()
dot.set('concentrate', False)
dot.set_node_defaults(shape='record')
for layer_name, layer in d_layers.items():
    if layer["skip"]:
        continue
    dims = [str(dim) for dim in layer["output_shape"] if dim is not None]
    label_lines = ["{} ({})".format(layer["type"], ",".join(dims))]
    if "Dropout" in layer:
        label_lines.append("Dropout({})".format(layer["Dropout"]["f_dropout"]))
    if "LeakyReLU" in layer:
        label_lines.append("LeakyReLU")
    for pname in layer["parent_names"]:
        dot.add_edge(pydot.Edge(pname, layer_name))
    dot.add_node(pydot.Node(layer_name, label="\n".join(label_lines)))
dot.write("output.png", format="png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment