looking at gradient descent with logistic regression
import numpy as np

N = 100  # number of samples
H = N // 2  # half the samples, i.e. the size of each class
D = 2  # dimensionality of the data

# Generate a random point cloud.
X = np.random.randn(N, D)

# Create two classes with this data, centering the second half of the points
# at (2, 2) and the first half at (-2, -2).
X[H:, :] = X[H:, :] + 2*np.ones((H, D))
X[:H, :] = X[:H, :] - 2*np.ones((H, D))
# Initialize the targets (the actual class labels). We are basically saying
# that the "0" class is the data centered around (-2, -2) and the "1" class is
# the data centered around (2, 2).
T = np.array([0]*H + [1]*H)
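
# Note (an addition, not in the original gist): the class means sit eight
# units apart, so the two clouds are well separated and a linear decision
# boundary should fit this data cleanly.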
# Set up the input data in matrix form, prepending a column of ones so the
# bias term can be absorbed into the weight vector.
ones = np.array([[1] * N]).T
Xi = np.concatenate((ones, X), axis=1)
# Initialize the weights randomly, one per dimension plus one for the bias.
w = np.random.randn(D + 1)

# Calculate the model output.
z = Xi.dot(w)

def sigmoid(z):
    # Squash the logit z into a probability in (0, 1).
    return 1 / (1 + np.exp(-z))

Y = sigmoid(z)
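
# Y[i] is the model's estimated probability that sample i belongs to class 1;
# with random weights these estimates are arbitrary, so the cross entropy
# computed below should be large.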
def cross_entropy(T, Y):
    # Sum the negative log likelihood of each target under the model's
    # predicted probabilities.
    E = 0
    for i in range(N):
        if T[i] == 1:
            E -= np.log(Y[i])
        else:
            E -= np.log(1 - Y[i])
    return E

print('cross entropy with random weights: %s' % cross_entropy(T, Y))
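
# Side note (an addition, not in the original gist): the loop above is
# equivalent to the usual vectorized numpy expression for the cross entropy.
E_vectorized = -np.sum(T*np.log(Y) + (1 - T)*np.log(1 - Y))
print('vectorized cross entropy: %s' % E_vectorized)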
# We can also look at the closed form solution, where the weights can be
# analytically calculated (see notes).
'''
w = np.array([0, 4, 4])
z = Xi.dot(w)
Y = sigmoid(z)
print('cross entropy with calculated weights: %s' % cross_entropy(T, Y))
'''
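
# Why w = [0, 4, 4] is the analytic answer (a sketch of the reasoning, not
# part of the original gist): for two identity-covariance Gaussians with equal
# priors and means mu_0 = (-2, -2) and mu_1 = (2, 2), the Bayes decision
# boundary is linear with weights mu_1 - mu_0 = (4, 4) and bias
# -0.5*(mu_1.mu_1 - mu_0.mu_0) = 0.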
# Now apply gradient descent with L2 regularization. The gradient of the cross
# entropy with respect to w is Xi.T.dot(Y - T), so we step in the opposite
# direction; the regularization term pulls the weights toward zero.
learning_rate = 0.1
regularization = 0.1
for i in range(1000):
    if i % 10 == 0:
        print(cross_entropy(T, Y))
    w += learning_rate * np.dot((T - Y).T, Xi) - regularization * w
    Y = sigmoid(Xi.dot(w))

print('final weights: %s' % w)
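
# Quick sanity check (an addition, not in the original gist): report the final
# classification rate, thresholding the predicted probabilities at 0.5.
predictions = np.round(Y)
print('classification rate: %s' % np.mean(predictions == T))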