Skip to content

Instantly share code, notes, and snippets.

@cashlo
Created July 25, 2019 01:40
Show Gist options
  • Save cashlo/281f62e7eab3450ae1f74323e199e0c3 to your computer and use it in GitHub Desktop.
Save cashlo/281f62e7eab3450ae1f74323e199e0c3 to your computer and use it in GitHub Desktop.
from scipy.io import loadmat
import numpy as np
import scipy.optimize as opt
import matplotlib.pyplot as plt
def sigmoid(z):
    """Return the logistic sigmoid 1 / (1 + exp(-z)), applied element-wise.

    Parameters
    ----------
    z : scalar or array_like
        Input values.

    Returns
    -------
    Same shape as ``z``: a scalar for scalar input, an ndarray otherwise.
    """
    # exp(-|z|) lies in (0, 1], so unlike np.exp(-z) in the naive formula it
    # can never overflow for large negative z; it can only underflow to 0,
    # which is harmless here.
    e = np.exp(-np.abs(z))
    # For z >= 0: 1 / (1 + exp(-z)).  For z < 0: exp(z) / (1 + exp(z)), the
    # algebraically equal form.  [()] unwraps a 0-d result so scalar input
    # still yields a scalar.
    return np.where(z >= 0, 1 / (1 + e), e / (1 + e))[()]
def costFunctionReg(theta, X, y, lmbda):
    """Return the L2-regularised logistic-regression cost.

    Parameters
    ----------
    theta : 1-D ndarray of length n
        Parameter vector; theta[0] is the intercept weight and is not
        penalised.
    X : ndarray of shape (m, n)
        One example per row (with the intercept column already added).
    y : 1-D ndarray of length m
        0/1 (or boolean) labels.  Must be 1-D: a column vector would
        broadcast against the (m,) predictions into an (m, m) array.
    lmbda : float
        Regularisation strength.

    Returns
    -------
    float
        Mean cross-entropy over the m examples plus
        lmbda / (2 m) * sum(theta[1:] ** 2).
    """
    m = len(y)
    z = np.dot(X, theta)
    # -log(sigmoid(z)) == logaddexp(0, -z) and -log(1 - sigmoid(z)) ==
    # logaddexp(0, z).  These stay finite when sigmoid(z) saturates to exactly
    # 0 or 1, where np.log would give -inf and 0 * -inf would give nan.
    data_term = np.sum(y * np.logaddexp(0, -z) + (1 - y) * np.logaddexp(0, z)) / m
    # The intercept theta[0] is excluded from the penalty.
    reg_term = np.sum(theta[1:] ** 2) * lmbda / (2 * m)
    return data_term + reg_term
def gradRegularization(theta, X, y, lmbda):
    """Return the gradient of the regularised logistic-regression cost.

    Computes (1/m) * X^T (sigmoid(X theta) - y) + (lmbda/m) * theta, with
    the penalty term omitted for the intercept theta[0].

    Parameters
    ----------
    theta : 1-D ndarray of length n
        Parameter vector; theta[0] is the (unpenalised) intercept weight.
    X : ndarray of shape (m, n)
        One example per row.
    y : 1-D ndarray of length m
        0/1 (or boolean) labels.
    lmbda : float
        Regularisation strength.

    Returns
    -------
    ndarray of shape (n,)
    """
    m = len(y)
    error = sigmoid(np.dot(X, theta)) - y
    grad = np.dot(error, X) / m
    # Penalise every weight except the intercept.  Adding the term only to
    # grad[1:] avoids the add-then-subtract round-off the old code introduced
    # into grad[0].
    grad[1:] += theta[1:] * lmbda / m
    return grad
# Split sizes for the random train / held-out test partition.
N_TRAIN = 100  # number of training examples
N_TEST = 500  # number of held-out test examples


def _add_intercept(features):
    """Return *features* with a leading column of ones for the bias weight."""
    return np.hstack((np.ones((features.shape[0], 1)), features))


data = loadmat('ex3/ex3data1.mat')
X = data['X']
y = data['y']
if N_TRAIN + N_TEST > X.shape[0]:
    raise ValueError(
        f"need at least {N_TRAIN + N_TEST} examples for the split, got {X.shape[0]}"
    )
indices = np.random.permutation(X.shape[0])
# Training rows come from the front of the permutation and test rows from the
# back, so the two sets never overlap.  For a 5000-row dataset this is exactly
# indices[:100] / indices[4500:].
training_indices = indices[:N_TRAIN]
test_indices = indices[X.shape[0] - N_TEST:]
train_X, test_X = X[training_indices], X[test_indices]
train_y, test_y = y[training_indices], y[test_indices]
train_X = _add_intercept(train_X)
test_X = _add_intercept(test_X)
# One-vs-all training: fit one regularised logistic-regression classifier per
# class and store its parameters as a row of theta.
m, n = train_X.shape
lmbda = 0.1  # regularisation strength passed to the cost and gradient
k = 10  # number of classes, i.e. rows of theta
theta = np.zeros((k, n))  # initial parameters, one row per classifier

for row in range(k):
    # Row 0 is trained for label 10; rows 1..9 for labels 1..9.
    label = 10 if row == 0 else row
    is_label = (train_y == label).flatten()
    theta[row] = opt.fmin_cg(
        costFunctionReg,
        theta[row],
        fprime=gradRegularization,
        args=(train_X, is_label, lmbda),
        maxiter=100,
    )
# Show each classifier's weights (intercept dropped) as a 20x20 image,
# unflattened in column-major order.
_, axarr = plt.subplots(1, 10, figsize=(10, 10))
for ax, weights in zip(axarr, theta):
    ax.imshow(weights[1:].reshape((20, 20), order='F'))
    ax.axis('off')
plt.show()
def _accuracy(features, labels, weights):
    """Return the percentage of rows of *features* classified as *labels*.

    Each row is assigned the index of the row of *weights* with the highest
    score.  Index 0 is mapped back to label 10 to mirror the training loop,
    which trains row 0 for label 10.
    """
    pred = np.argmax(features @ weights.T, axis=1)
    pred = np.where(pred == 0, 10, pred)
    return np.mean(pred == labels.flatten()) * 100


print(_accuracy(train_X, train_y, theta))
print(_accuracy(test_X, test_y, theta))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment