Skip to content

Instantly share code, notes, and snippets.

@mrtzh
Created March 1, 2016 06:01
Show Gist options
  • Save mrtzh/266c37d3a274376134a6 to your computer and use it in GitHub Desktop.
Save mrtzh/266c37d3a274376134a6 to your computer and use it in GitHub Desktop.
"""Visualize stability of stochastic gradient descent for finding a linear
separator."""
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
np.random.seed(1)
mpl.rcParams['axes.linewidth'] = 0.0
# create point clouds
num_points = 30
bluepts = np.random.normal(0,1,(num_points, 2))
redpts = np.random.normal(0,1,(num_points,2))
bluepts[:,0] -= 1.1*np.ones((num_points,))
bluepts[0,:] = -np.ones(2)
redpts[:,0] += 1.1*np.ones((num_points,))
bluepts2 = np.copy(bluepts)
bluepts2[0,:] = np.ones(2)
# compute one pass of SGD over the data
sgdpts = np.ones((num_points+1,2))
sgdpts[0,:] = np.array([0,1.0])
sgdpts2 = np.copy(sgdpts)
step_size = 1.0/num_points
for i in xrange(num_points):
grad = np.multiply(sgdpts[i, :], bluepts[i,:]**2) + bluepts[i,:]
grad2 = np.multiply(sgdpts2[i, :], bluepts2[i,:]**2) + bluepts2[i,:]
sgdpts[i+1,:] = sgdpts[i, :] - step_size * grad
sgdpts2[i+1,:] = sgdpts2[i, :] - step_size * grad2
grad = np.multiply(sgdpts[i+1, :], redpts[i,:]**2) - redpts[i,:]
sgdpts[i+1,:] = sgdpts[i+1, :] - step_size * grad
sgdpts2[i+1,:] = sgdpts2[i+1, :] - step_size * grad
def hyperplane(w):
w /= np.linalg.norm(w)
x = np.linspace(-200,100,100)
y = -1.0 * w[0] * x / w[1]
return x, y
for i in xrange(num_points):
fig = plt.figure(figsize=(6,6), facecolor='white', frameon=False, dpi=300)
plt.plot(bluepts[1:,0],bluepts[1:,1],'bo', ms=8)
plt.plot(redpts[:,0],redpts[:,1],'r^', ms=8)
ax = plt.axes()
ax.arrow(-0.95, -0.95, 1.82, 1.82, head_width=0.05, fc='#cccccc', ec='#cccccc', lw=2)
plt.plot(bluepts[0,0],bluepts[0,0], marker='o', color='green', ms=8)
plt.plot(bluepts2[0,0],bluepts2[0,0], marker='o', color='purple', ms=8)
plt.xlim(-2,2)
plt.ylim(-2,2)
x, y = hyperplane(sgdpts[i,:])
plt.plot(x, y, 'g-')
x2, y2 = hyperplane(sgdpts2[i,:])
plt.plot(x2, y2, color='purple')
ax.fill_between(x, y, y+100, color=(0.9, 0.9, 0.9))
ax.fill_between(x, y, y2, color=(0.1, 0.5, 0.5))
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
plt.savefig(str(i).zfill(2)+'.png', bbox_inches='tight', pad_inches=0)
plt.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment