Skip to content

Instantly share code, notes, and snippets.

@OriaGr
Last active August 20, 2016 22:39
Show Gist options
  • Save OriaGr/060d29279df5e91cff7271c2edf3fb9e to your computer and use it in GitHub Desktop.
Save OriaGr/060d29279df5e91cff7271c2edf3fb9e to your computer and use it in GitHub Desktop.
# Related blog post: https://oriamathematics.wordpress.com/2016/08/21/multivariate-logistic-regression-with-example-mnist/
import numpy as np
import matplotlib.pyplot as plt
def predictProbabilities(X, W):
    """Return the softmax class probabilities for every row of X.

    Args:
        X: (m, n) array of samples, one sample per row.
        W: (n, k) weight matrix, one column per class.

    Returns:
        (m, k) array; each row is non-negative and sums to 1.
    """
    scores = np.dot(X, W)
    # Softmax is unchanged by subtracting a per-row constant. Shifting by the
    # row maximum keeps every exponent <= 0, so np.exp cannot overflow to inf
    # (which previously produced inf/inf = nan rows for large scores).
    shifted = scores - np.max(scores, axis=1, keepdims=True)
    expScores = np.exp(shifted)
    # axis=1 sums across the classes of each row; keepdims lets the division
    # broadcast row-wise.
    return expScores / np.sum(expScores, axis=1, keepdims=True)
def logLikelihood(X, Y, W):
    """Return the total log-likelihood of the labels Y under weights W.

    The log-probability of each sample's true class is computed directly as
    ``score_true - log(sum(exp(scores)))`` (log-sum-exp) instead of taking
    the log of a softmax probability, so very confident predictions yield
    large finite values rather than ``-inf`` or ``nan``.

    Args:
        X: (m, n) array of samples, one sample per row.
        Y: (m, k) label matrix. Assumed to be one-hot per row; for one-hot
            rows the result equals the sum of log(probability of the actual
            label). Soft (non one-hot) rows are scored as cross-entropy.
        W: (n, k) weight matrix, one column per class.

    Returns:
        Scalar sum over all samples of the log-probability of the true label.
    """
    scores = np.dot(X, W)
    # Shift by the row maximum so that np.exp never overflows; the shift
    # cancels out between the two terms below.
    shifted = scores - np.max(scores, axis=1, keepdims=True)
    logNormalizers = np.log(np.sum(np.exp(shifted), axis=1))
    # Element-wise product with one-hot Y selects the true-class score per row.
    trueClassScores = np.sum(np.asarray(Y) * shifted, axis=1)
    return np.sum(trueClassScores - logNormalizers)
def gradient(X, Y, W):
    """Return the gradient of the log-likelihood with respect to W.

    Args:
        X: (m, n) array of samples, one sample per row.
        Y: (m, k) label matrix.
        W: (n, k) weight matrix, one column per class.

    Returns:
        (n, k) array: X transposed times (labels minus predicted
        probabilities).
    """
    # Difference between the given labels and the model's probabilities.
    residuals = Y - predictProbabilities(X, W)
    return np.dot(X.T, residuals)
def predictLabels(X, W):
    """Return a one-hot matrix marking the most likely class of each sample.

    The winning class is taken from the raw scores ``np.dot(X, W)``. Softmax
    is strictly increasing in each score, so the most probable class is the
    one with the largest score; skipping the exponentials avoids overflow
    affecting the prediction.

    Args:
        X: (m, n) array of samples, one sample per row.
        W: (n, k) weight matrix, one column per class.

    Returns:
        (m, k) float array with a single 1 per row (at the first maximal
        score) and 0 elsewhere.
    """
    scores = np.dot(X, W)
    m, k = scores.shape
    predictedLabels = np.zeros((m, k))
    # Integer-array indexing sets one entry per row in a single vectorized
    # assignment instead of a Python loop over the samples.
    predictedLabels[np.arange(m), np.argmax(scores, axis=1)] = 1
    return predictedLabels
def successRate(X, Y, W):
    """Return the percentage of samples whose predicted class matches Y.

    Args:
        X: (m, n) array of samples, one sample per row.
        Y: (m, k) label matrix; the true class of a row is the index of its
            largest entry (one-hot labels in the accompanying script).
        W: (n, k) weight matrix, one column per class.

    Returns:
        Accuracy as a Python float in the range [0, 100].

    Raises:
        ZeroDivisionError: if Y has no rows.
    """
    # The most probable class is the one with the largest score, so the
    # class indices can be compared directly without building one-hot rows.
    predictedClasses = np.argmax(np.dot(X, W), axis=1)
    realClasses = np.argmax(Y, axis=1)
    m = realClasses.shape[0]
    correctCounter = int(np.count_nonzero(predictedClasses == realClasses))
    return 100.0 * correctCounter / m
def _appendBiasColumn(samples):
    """Return samples with one extra trailing column of ones (bias term)."""
    biasColumn = np.ones((samples.shape[0], 1))
    return np.concatenate((samples, biasColumn), axis=1)


# Load the train/test splits from .npy files in the working directory.
# NOTE(review): the label files are presumably one-hot (m, k) matrices --
# confirm against how the .npy files were generated.
trX = _appendBiasColumn(np.load("trXMNIST.npy"))
trY = np.load("trYMNIST.npy")
teX = _appendBiasColumn(np.load("teXMNIST.npy"))
teY = np.load("teYMNIST.npy")

# One weight per (feature incl. bias, class), starting from all zeros.
n = trX.shape[1]
k = trY.shape[1]
W = np.zeros((n, k))

numOfIterations = 500
learningRate = 0.000001

# Despite its name, costArray records the log-likelihood before each update.
costArray = np.zeros((numOfIterations, 1))
for i in range(numOfIterations):
    costArray[i, 0] = logLikelihood(trX, trY, W)
    # Gradient ascent: step along the gradient to increase the log-likelihood.
    W = W + learningRate * gradient(trX, trY, W)

plt.plot(costArray)
plt.show()

# Python's % formatting ignores the "l" length modifier, so "%lf" prints
# exactly like "%f".
for caption, samples, labels in (
    ("train success rate is", trX, trY),
    ("test success rate is", teX, teY),
):
    print(caption)
    print("%lf" % successRate(samples, labels, W))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment