Skip to content

Instantly share code, notes, and snippets.

@nbeuchat
Forked from craffel/draw_neural_net.py
Last active January 22, 2018 17:04
Show Gist options
  • Save nbeuchat/091c458327c39a84ba06e8686c76dfd5 to your computer and use it in GitHub Desktop.
Save nbeuchat/091c458327c39a84ba06e8686c76dfd5 to your computer and use it in GitHub Desktop.
Draw a neural network diagram with matplotlib!
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Modified on Mon Oct 10 23:29:41 2016
@author: craffel, edited by nbeuchat (Nicolas Beuchat)
"""
import numbers

import matplotlib.pyplot as plt
def draw_neural_net(layer_sizes, ax=None, left=.1, right=.9, bottom=.1, top=.9,color='w'):
'''
Draw a neural network cartoon using matplotilb.
:usage:
>>> fig = plt.figure(figsize=(12, 12))
>>> draw_neural_net(fig.gca(), .1, .9, .1, .9, [4, 7, 2])
:parameters:
- layer_sizes : list of int
List of layer sizes, including input and output dimensionality
- ax : matplotlib.axes.AxesSubplot
The axes on which to plot the cartoon (get e.g. by plt.gca()). Default: gca
- left : float
The center of the leftmost node(s) will be placed here. Default = 0.1
- right : float
The center of the rightmost node(s) will be placed here. Default = 0.9
- bottom : float
The center of the bottommost node(s) will be placed here. Default = 0.1
- top : float
The center of the topmost node(s) will be placed here. Default = 0.9
- color: string or array of string or array of array of int
The color of the nodes (layer by layer)
Example:
color='k' -> black neurons
color=['r','k','b'] -> red input layer, black hidden layer, blue output layer
color=['r',[0.3,0.23,0.6],'g'] -> can use RGB value as well
'''
n_layers = len(layer_sizes)
v_spacing = (top - bottom)/float(max(layer_sizes))
h_spacing = (right - left)/float(len(layer_sizes) - 1)
c = color
if ax is None:
ax = plt.gca()
# Nodes
for n, layer_size in enumerate(layer_sizes):
layer_top = v_spacing*(layer_size - 1)/2. + (top + bottom)/2.
for m in range(layer_size):
if len(color) > 1:
c = color[n]
circle = plt.Circle((n*h_spacing + left, layer_top - m*v_spacing), v_spacing/4.,
color=c, ec='k', zorder=4)
ax.add_artist(circle)
# Edges
for n, (layer_size_a, layer_size_b) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
layer_top_a = v_spacing*(layer_size_a - 1)/2. + (top + bottom)/2.
layer_top_b = v_spacing*(layer_size_b - 1)/2. + (top + bottom)/2.
for m in range(layer_size_a):
for o in range(layer_size_b):
line = plt.Line2D([n*h_spacing + left, (n + 1)*h_spacing + left],
[layer_top_a - m*v_spacing, layer_top_b - o*v_spacing], c='k')
ax.add_artist(line)
# Beautify the axes
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment