Skip to content

Instantly share code, notes, and snippets.

@DylanDmitri
Created February 4, 2018 20:56
Show Gist options
  • Save DylanDmitri/25efa74ca8c23cb4c047be4d12405f07 to your computer and use it in GitHub Desktop.
def plot_keras_weights(model, outfile='model.png'):
    """Render a Keras model's dense weights and biases as a PNG graph.

    One node is drawn per input feature and per neuron. Every edge into a
    neuron is labelled with its weight (or bias) rounded to three decimals.

    Parameters
    ----------
    model:
        A Keras model. ``model.input_shape[1]`` is used as the number of
        input features. Each layer is expected to be Dense-like: its
        ``get_weights()`` returns a 2-D kernel of shape ``(inputs, units)``,
        optionally followed by a bias vector. Layers that return no weights
        are skipped, and the next layer connects to the previous neurons.
    outfile:
        Path of the PNG file to write.

    Raises
    ------
    ValueError
        If a layer's weights are not a 2-D kernel plus an optional bias.

    Notes
    -----
    Uses the third-party ``pydot`` package, which is assumed to be imported
    at module level. ``write_png`` renders through the Graphviz ``dot``
    executable.
    """
    def _label(value):
        # Hand pydot a plain str (not a numpy scalar) so it applies its normal
        # DOT quoting; keeps the original three-decimal rounding.
        return str(round(float(value), 3))

    graph = pydot.Dot(graph_type='digraph')

    prev_layer = []
    for i in range(model.input_shape[1]):
        node = pydot.Node(f'Input[{i}]')
        prev_layer.append(node)
        graph.add_node(node)

    for l, layer in enumerate(model.layers):
        weights = layer.get_weights()
        if not weights:
            # Nothing to draw for a weightless layer (e.g. Dropout); keep
            # wiring the next layer to the current prev_layer.
            continue
        if len(weights) > 2 or weights[0].ndim != 2:
            raise ValueError(
                f'layer {l} is not Dense-like: expected a 2-D kernel and an '
                f'optional bias, got shapes {[w.shape for w in weights]}'
            )
        kernel = weights[0]
        biases = weights[1] if len(weights) == 2 else None

        # One bias source per layer. A single node named '1' shared by all
        # layers would be merged by Graphviz and stretch across the layout.
        bias = None
        if biases is not None:
            bias = pydot.Node(f'bias[{l}]', label='1', color='blue')
            graph.add_node(bias)

        this_layer = []
        for n, incoming in enumerate(kernel.T):
            # Row n of kernel.T (column n of the kernel) holds the weights
            # feeding neuron n from each node of the previous layer.
            new_node = pydot.Node(f'neuron[{l}][{n}]')
            this_layer.append(new_node)
            graph.add_node(new_node)
            if bias is not None:
                graph.add_edge(pydot.Edge(bias, new_node, label=_label(biases[n])))
            for weight, parent in zip(incoming, prev_layer):
                graph.add_edge(pydot.Edge(parent, new_node, label=_label(weight)))
        prev_layer = this_layer

    graph.write_png(outfile)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment