import matplotlib.pyplot as plt


def draw_neural_net(ax, left, right, bottom, top, layer_sizes):
    '''
    Draw a neural network cartoon using matplotlib.

    :usage:
        >>> fig = plt.figure(figsize=(12, 12))
        >>> draw_neural_net(fig.gca(), .1, .9, .1, .9, [4, 7, 2])

    :parameters:
        - ax : matplotlib.axes.Axes
            The axes on which to plot the cartoon (get e.g. by plt.gca())
        - left : float
            The center of the leftmost node(s) will be placed here
        - right : float
            The center of the rightmost node(s) will be placed here
        - bottom : float
            The center of the bottommost node(s) will be placed here
        - top : float
            The center of the topmost node(s) will be placed here
        - layer_sizes : list of int
            List of layer sizes, including input and output dimensionality
    '''
    n_layers = len(layer_sizes)
    # Vertical gap between nodes is set by the largest layer;
    # horizontal gap spreads the layers evenly from left to right.
    v_spacing = (top - bottom)/float(max(layer_sizes))
    h_spacing = (right - left)/float(n_layers - 1)
    # Nodes
    for n, layer_size in enumerate(layer_sizes):
        # y-coordinate of the top node, so the layer is centered vertically
        layer_top = v_spacing*(layer_size - 1)/2. + (top + bottom)/2.
        for m in range(layer_size):  # range works in Python 2 and 3
            circle = plt.Circle((n*h_spacing + left, layer_top - m*v_spacing), v_spacing/4.,
                                color='w', ec='k', zorder=4)
            ax.add_artist(circle)
    # Edges
    for n, (layer_size_a, layer_size_b) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        layer_top_a = v_spacing*(layer_size_a - 1)/2. + (top + bottom)/2.
        layer_top_b = v_spacing*(layer_size_b - 1)/2. + (top + bottom)/2.
        # Connect every node in layer n to every node in layer n + 1
        for m in range(layer_size_a):
            for o in range(layer_size_b):
                line = plt.Line2D([n*h_spacing + left, (n + 1)*h_spacing + left],
                                  [layer_top_a - m*v_spacing, layer_top_b - o*v_spacing], c='k')
                ax.add_artist(line)
@bamos and thanks to you for the example :)!
How can I write something into a neuron?
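One possible approach, as a rough sketch: recompute a node's center with the same layout math the gist uses and call `ax.text` at that point. `node_center` and `label_nodes` are hypothetical helper names, not part of the gist, and `zorder=5` assumes the circles are drawn at `zorder=4` as in the code above.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def node_center(n, m, left, right, bottom, top, layer_sizes):
    """Return (x, y) of node m in layer n, mirroring the gist's layout math."""
    v_spacing = (top - bottom) / float(max(layer_sizes))
    h_spacing = (right - left) / float(len(layer_sizes) - 1)
    layer_top = v_spacing * (layer_sizes[n] - 1) / 2. + (top + bottom) / 2.
    return n * h_spacing + left, layer_top - m * v_spacing


def label_nodes(ax, left, right, bottom, top, layer_sizes, labels):
    """Write labels[n][m] centered on node m of layer n (hypothetical helper)."""
    texts = []
    for n, layer_labels in enumerate(labels):
        for m, label in enumerate(layer_labels):
            x, y = node_center(n, m, left, right, bottom, top, layer_sizes)
            # Draw above the circles (zorder=4 in the gist) so the text is visible
            texts.append(ax.text(x, y, label, ha='center', va='center', zorder=5))
    return texts


# Usage: call after draw_neural_net(ax, .1, .9, .1, .9, [2, 2, 1])
fig = plt.figure(figsize=(12, 12))
ax = fig.gca()
labels = [['X_1', 'X_2'], ['H_1', 'H_2'], ['y_1']]
texts = label_nodes(ax, .1, .9, .1, .9, [2, 2, 1], labels)
```

The node radius is `v_spacing/4.`, so short labels fit best; longer names may need a smaller `fontsize` passed to `ax.text`.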
@craffel Very useful, thank you!
Great job. Thanks.
Thanks a lot. Building on your code, I tried adding text labels for the nodes (X_1, X_2, ..., H_1, H_2, ..., y_1, y_2, ...) and bias nodes (labeled '1'). The function for plotting the network is as follows (filename: draw_neural_net_.py):
plt.plot(n*h_spacing + left, layer_top - m*v_spacing, 'o', mfc='w', mec='k', ls='-', markersize=40)
for m in xrange(layer_size_a):
The testing main program is (filename: test_XOR_Classification.py):
# [1] Input data (you should define X_train and y_train)
# Testing different topology
XOR_MLP = MLP(
# Plot the neural network
fig66 = plt.figure(figsize=(12, 12))
draw_neural_net(ax, .1, .9, .1, .9, [2, 2, 1],
The result has not been tuned. It is only meant to demonstrate plotting the network topology using sklearn and matplotlib in Python.
Thanks for the script!
It should work in both Python 2 and Python 3; I tested it with Python 2.7. You can try running it. You can also change the labels to real names (for example, from 'X_1', 'X_2' to 'Sepal.Length', 'Sepal.Width', and from y_1, y_2, y_3 to 'Setosa', 'Versicolor', and 'Virginica') by adding some plt.text() commands.
@ljhuang2017 do you have that file available somewhere else? looks like the formatting got kind of borked.
@ljhuang2017 this looks very nice. Could you just make another gist out of it? Just edit your comment, copy your code, and paste it into a new gist. This way people can use your code without needing to reformat it.
I don't think people are notified when you @ them on Gists, and @ljhuang2017 doesn't have any contact information. Did anyone get their code working? Can you post it as your own Gist if so?
Many thanks for the script!
I was able to run @ljhuang2017's code successfully and posted it as a new gist.
Thank you for the code, saved me a lot of time drawing it myself.
How can I add labels for the coefficients and intercepts? I need them for a demonstration, with labels such as a1, b1, c1, d1, etc.
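A minimal sketch of one way to label the edges, assuming the same layout math as the gist: compute a point part-way along each edge and place text there with `ax.text`. `edge_label_position`, `label_edges`, and the `frac` parameter are hypothetical names introduced here, and the nested-list format for `edge_labels` (one matrix per pair of adjacent layers, rows for source nodes, columns for target nodes) is an assumption. Intercept labels would additionally need bias nodes, which this sketch does not draw.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def edge_label_position(n, m, o, left, right, bottom, top, layer_sizes, frac=0.25):
    """Point a fraction `frac` along the edge from node m (layer n) to node o (layer n + 1)."""
    v_spacing = (top - bottom) / float(max(layer_sizes))
    h_spacing = (right - left) / float(len(layer_sizes) - 1)
    mid = (top + bottom) / 2.
    x_a = n * h_spacing + left
    x_b = (n + 1) * h_spacing + left
    y_a = v_spacing * (layer_sizes[n] - 1) / 2. + mid - m * v_spacing
    y_b = v_spacing * (layer_sizes[n + 1] - 1) / 2. + mid - o * v_spacing
    return x_a + frac * (x_b - x_a), y_a + frac * (y_b - y_a)


def label_edges(ax, left, right, bottom, top, layer_sizes, edge_labels, frac=0.25):
    """Place edge_labels[n][m][o] on the edge between node m of layer n and node o of layer n + 1."""
    texts = []
    for n, matrix in enumerate(edge_labels):
        for m, row in enumerate(matrix):
            for o, label in enumerate(row):
                x, y = edge_label_position(n, m, o, left, right, bottom, top,
                                           layer_sizes, frac)
                # White background box keeps the label readable over the line
                texts.append(ax.text(x, y, str(label), ha='center', va='center',
                                     zorder=5,
                                     bbox=dict(facecolor='w', edgecolor='none', pad=1)))
    return texts


# Usage: call after draw_neural_net(ax, .1, .9, .1, .9, [2, 2, 1])
fig = plt.figure(figsize=(12, 12))
ax = fig.gca()
edge_labels = [[['a1', 'b1'], ['c1', 'd1']],  # layer 0 -> layer 1
               [['e1'], ['f1']]]              # layer 1 -> layer 2
texts = label_edges(ax, .1, .9, .1, .9, [2, 2, 1], edge_labels)
```

Placing labels a quarter of the way along each edge (rather than at the midpoint) keeps them away from the point where edges cross, which tends to be less cluttered.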
Thanks for the example @craffel!
In case anybody else wants a quick preview, here's the image the code produces.
I removed the axis with:
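The snippet that followed is not preserved here. As an assumption about what was meant, one common way to hide the axis frame, ticks, and labels in matplotlib is `ax.axis('off')`:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(12, 12))
ax = fig.gca()
ax.axis('off')  # hide the axis lines, ticks, and tick labels
# Then draw as usual with the gist's function, e.g.:
# draw_neural_net(ax, .1, .9, .1, .9, [4, 7, 2])
```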