addressing the donut problem in logistic regression
import numpy as np
import matplotlib.pyplot as plt
N = 1000
H = N/2
D = 2
R_inner = 5
R_outer = 10
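# N samples in D dimensions, split into two classes of H points each.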
# Half the data spread around some inner radius, the other half at another
# radius.
R1 = np.random.randn(H) + R_inner
theta = 2*np.pi * np.random.randn(H)
# Convert polar to x/y.
X_inner = np.concatenate([[R1 * np.cos(theta)], [R1 * np.sin(theta)]]).T
R2 = np.random.randn(H) + R_outer
theta = 2*np.pi * np.random.randn(H)
X_outer = np.concatenate([[R2 * np.cos(theta)], [R2 * np.sin(theta)]]).T
# Create the full input dataset and apply some class labels.
X = np.concatenate([X_inner, X_outer])
T = np.array([0]*(H) + [1]*(H))
# No line can separate these two classes... so maybe not a good candidate for
# logistic regression.
plt.scatter(X[:,0], X[:,1], c=T)
plt.show()
# Now let's try to handle the donut problem.
# Create a col of ones for the bias term.
ones = np.array([[1] * N]).T
# And create another col which is the radius of the point.
r = np.zeros((N, 1))
for i in xrange(N):
    r[i] = np.sqrt(X[i,:].dot(X[i,:]))
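# (Equivalently: r[i] = np.linalg.norm(X[i,:]) -- the distance of each point
# from the origin.)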
# Now the inputs have more dimensions (so this should be reflected when we
# initialize the weights).
Xi = np.concatenate((ones, r, X), axis=1)
w = np.random.rand(D + 2)
z = Xi.dot(w)
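# z holds the raw linear score for each sample under the (random) initial weights.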
def sigmoid(z):
    return 1 / (1 + np.exp(-z))
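# The sigmoid squashes each score into (0, 1), so Y below can be read as the
# model's estimate of P(t = 1 | x).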
Y = sigmoid(z)
def cross_entropy(T, Y):
    E = 0
    for i in xrange(N):
        if T[i] == 1:
            E -= np.log(Y[i])
        else:
            E -= np.log(1 - Y[i])
    return E
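# (A vectorized equivalent of the loop above, for reference:
#  E = -np.sum(T*np.log(Y) + (1 - T)*np.log(1 - Y)).)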
learning_rate = 0.0001
regularization = 0.01
error = []
for i in xrange(5000):
    e = cross_entropy(T, Y)
    error.append(e)
    if i % 100 == 0:
        print e
    w += learning_rate * (np.dot((T - Y).T, Xi) - regularization*w)
    Y = sigmoid(Xi.dot(w))
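# The weight update above is plain batch gradient ascent on the log-likelihood
# (equivalently, descent on the cross entropy): np.dot((T - Y).T, Xi) is the
# gradient, and -regularization*w adds an L2 penalty.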
plt.plot(error)
plt.title('cross entropy')
plt.show()
print 'final w:', w
print 'final classification rate', 1 - np.abs(T - np.round(Y)).sum() / N
# This will reveal that, as constructed, the raw x/y coordinates don't much
# matter in the classification (their weights will be near zero), whereas the
# radius will, naturally, be the determiner.
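# A quick sanity check one might add (a sketch, not part of the original gist):
# with the x/y weights near zero, the learned boundary sits roughly at the
# radius where the bias and radius terms cancel, i.e. where w[0] + w[1]*r = 0.
print 'approximate boundary radius:', -w[0] / w[1]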