Teaching a Machine to XOR (visualize learning process)
# Training a neural XOR circuit
# In Marvin Minsky and Seymour Papert's in/famous critique of perceptrons
# (Perceptrons, 1969), they argued that neural networks had extremely limited
# utility, proving that the single-layer perceptrons of the time could not even
# learn the exclusive OR function. This played some role in the decline of
# neural network research in the decade that followed.
# Today we can easily teach a neural network an XOR function by incorporating
# more layers (see the short check just after these introductory notes).
# Truth table:
# Input | Output
# 00 | 0
# 01 | 1
# 10 | 1
# 11 | 0
#
# I used craffel's draw_neural_net.py at https://gist.github.com/craffel/2d727968c3aaebd10359
#
#
# Date 2017/01/22
# www.thescinder.com
# Blog post https://thescinder.com/2017/01/24/teaching-a-machine-to-love-xor/
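# Why a hidden layer helps (an illustrative aside, not part of the original
# gist): XOR is not linearly separable, but it decomposes into linearly
# separable operations, e.g. XOR(a, b) = AND(OR(a, b), NOT(AND(a, b))),
# which is exactly the kind of composition one hidden layer can encode.
for a, b in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    assert (a ^ b) == ((a | b) & (1 - (a & b)))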
# Imports
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import os
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Make sure the output directory for the saved animation frames exists
os.makedirs('./trainingTLXOR', exist_ok=True)
# Input vector based on the truth table above
a0 = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
# Target output
y = np.array([[0], [1], [1], [0]])
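# Quick sanity check (illustrative): the target vector equals the element-wise
# XOR of the two input columns, matching the truth table above.
assert np.array_equal(y.ravel(), a0[:, 0] ^ a0[:, 1])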
def draw_neural_net(ax, left, right, bottom, top, layer_sizes, Theta0, Theta1):
    '''
    Public Gist from craffel
    https://gist.github.com/craffel/2d727968c3aaebd10359
    Draw a neural network cartoon using matplotlib.

    :usage:
        >>> fig = plt.figure(figsize=(12, 12))
        >>> draw_neural_net(fig.gca(), .1, .9, .1, .9, [2, 3, 1], Theta0, Theta1)

    :parameters:
        - ax : matplotlib.axes.AxesSubplot
            The axes on which to plot the cartoon (get e.g. by plt.gca())
        - left : float
            The center of the leftmost node(s) will be placed here
        - right : float
            The center of the rightmost node(s) will be placed here
        - bottom : float
            The center of the bottommost node(s) will be placed here
        - top : float
            The center of the topmost node(s) will be placed here
        - layer_sizes : list of int
            List of layer sizes, including input and output dimensionality
        - Theta0, Theta1 : np.ndarray
            Weight matrices; edge line widths scale with the weight magnitudes
    '''
    n_layers = len(layer_sizes)
    v_spacing = (top - bottom)/float(max(layer_sizes))
    h_spacing = (right - left)/float(len(layer_sizes) - 1)
    # Nodes
    for n, layer_size in enumerate(layer_sizes):
        layer_top = v_spacing*(layer_size - 1)/2. + (top + bottom)/2.
        for m in range(layer_size):
            circle = plt.Circle((n*h_spacing + left, layer_top - m*v_spacing), v_spacing/4.,
                                color='#999999', ec='k', zorder=4)
            ax.add_artist(circle)
    # Edges (line widths must be non-negative, so use the weight magnitude)
    for n, (layer_size_a, layer_size_b) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        layer_top_a = v_spacing*(layer_size_a - 1)/2. + (top + bottom)/2.
        layer_top_b = v_spacing*(layer_size_b - 1)/2. + (top + bottom)/2.
        for m in range(layer_size_a):
            for o in range(layer_size_b):
                if n == 0:
                    line = plt.Line2D([n*h_spacing + left, (n + 1)*h_spacing + left],
                                      [layer_top_a - m*v_spacing, layer_top_b - o*v_spacing],
                                      c='#8888dd', lw=np.abs(Theta0[m, o]))
                elif n == 1:
                    line = plt.Line2D([n*h_spacing + left, (n + 1)*h_spacing + left],
                                      [layer_top_a - m*v_spacing, layer_top_b - o*v_spacing],
                                      c='#8888cc', lw=np.abs(Theta1[m, o]))
                else:
                    line = plt.Line2D([n*h_spacing + left, (n + 1)*h_spacing + left],
                                      [layer_top_a - m*v_spacing, layer_top_b - o*v_spacing],
                                      c='r')
                ax.add_artist(line)
# Neuron functions
def sigmoid(x):
    # The sigmoid function is 0.5 at 0, ~1 at +infinity, and ~0 at -infinity.
    # This is a good activation function for neural networks.
    mySig = 1/(1 + np.exp(-x))
    return mySig

def sigmoidGradient(x):
    # Used for calculating the gradient at NN nodes to back-propagate during
    # NN learning. The sigmoid's derivative is sigmoid(x)*(1 - sigmoid(x)).
    myDer = sigmoid(x)*(1 - sigmoid(x))
    return myDer
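# A quick numerical check (illustrative, not needed for training): the analytic
# sigmoidGradient should agree with a central finite-difference estimate of the
# sigmoid's slope.
xCheck = np.linspace(-4, 4, 9)
eps = 1e-5
numericalGrad = (sigmoid(xCheck + eps) - sigmoid(xCheck - eps))/(2*eps)
assert np.allclose(sigmoidGradient(xCheck), numericalGrad, atol=1e-8)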
# Initialize neural network connections with random values in [-1, 1].
# This breaks symmetry so the network can learn.
np.random.seed(3)
Theta0 = 2*np.random.random((2, 3)) - 1  # 2 inputs -> 3 hidden units
Theta1 = 2*np.random.random((3, 1)) - 1  # 3 hidden units -> 1 output
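# Illustrative use of draw_neural_net before any training happens: render the
# freshly initialized 2-3-1 network (figInit is just a throwaway figure here).
figInit = plt.figure(figsize=(6, 6))
draw_neural_net(figInit.gca(), .1, .9, .1, .9, [2, 3, 1], Theta0, Theta1)
plt.show()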
# Train the network
myEpochs = 25000
m = np.shape(a0)[0]  # number of training examples
# J is a vector we'll use to keep track of the error function as we learn
J = []
# Set the learning rate
lR = 1e-1
# Weight penalty (L2 regularization) that keeps the weights small; set to 0
# here, so it has no effect on this run (try e.g. 3e-2 to see its influence)
myLambda = 0
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
fig2, ax2 = plt.subplots(1, 2, figsize=(12, 4))
for j in range(myEpochs):
    # Forward propagation
    z1 = np.dot(a0, Theta0)
    a1 = sigmoid(z1)
    z2 = np.dot(a1, Theta1)
    a2 = sigmoid(z2)
    # The error at the output
    E = (y - a2)
    J.append(np.mean(np.abs(E)))
    # Back propagation: d2 is the output-layer delta, d1 the hidden-layer delta
    d2 = E.T
    d1 = np.dot(Theta1, d2) * sigmoidGradient(z1.T)
    # Accumulate the weight gradients over all m training examples
    Delta1 = 0*Theta1
    Delta0 = 0*Theta0
    for c in range(m):
        Delta1 = Delta1 + np.dot(np.array([a1[c, :]]).T, np.array([d2[:, c]]))
        Delta0 = Delta0 + np.dot(np.array([a0[c, :]]).T, np.array([d1[:, c]]))
    # L2 weight penalty (zero unless myLambda > 0); subtracted here so that a
    # positive myLambda shrinks the weights under the additive update below
    w8Loss1 = myLambda * Theta1
    w8Loss0 = myLambda * Theta0
    Theta1Grad = Delta1/m - w8Loss1
    Theta0Grad = Delta0/m - w8Loss0
    Theta1 = Theta1 + Theta1Grad * lR
    Theta0 = Theta0 + Theta0Grad * lR
    if j % 250 == 0:
        # Save frames from the learning session
        matplotlib.rcParams.update({'figure.titlesize': 42})
        matplotlib.rcParams.update({'axes.titlesize': 24})
        draw_neural_net(fig.gca(), .1, .9, .1, .9,
                        [np.shape(a0)[1], np.shape(a1)[1], np.shape(a2)[1]],
                        Theta0, Theta1)
        fig.suptitle('Neural Network Iteration ' + str(j))
        plt.show()
        fig.savefig('./trainingTLXOR/XORNN' + str(j))
        fig.clf()
        # Learning curve so far, with the current iteration marked
        ax2[0].plot(J, ls='-')
        ax2[0].plot(j, J[j], 'o')
        ax2[0].axis([0, myEpochs, 0, 0.6])
        # Current outputs for the four inputs, drawn as dots sized by activation
        ax2[1].plot([1], [1], 'o', ms=32*a2[0, 0])
        ax2[1].plot([2], [1], 'o', ms=32*a2[1, 0])
        ax2[1].plot([3], [1], 'o', ms=32*a2[2, 0])
        ax2[1].plot([4], [1], 'o', ms=32*a2[3, 0])
        ax2[1].axis([0, 5, 0, 2])
        ax2[0].set_title('Mean Error')
        ax2[1].set_title('Outputs')
        plt.show()
        fig2.savefig('./trainingTLXOR/XORNer' + str(j))
        ax2[0].cla()
        ax2[1].cla()
plt.close(fig2)
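# Final check (illustrative): one more forward pass with the trained weights.
# If training converged, the rounded outputs should match the XOR targets.
a1 = sigmoid(np.dot(a0, Theta0))
a2 = sigmoid(np.dot(a1, Theta1))
print(np.hstack([a0, np.round(a2), y]))  # columns: input bits, prediction, target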