Last active: August 8, 2016 19:20
-
-
Save OriaGr/141a93760da981d8735296803788cf92 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Data used by this code: trXnew.npy and trYnew.npy from https://github.com/OriaGr/Blog-posts
# Blog post: https://oriamathematics.wordpress.com/2016/08/08/linear-regression-finale-multivariate-lr-with-real-example/
import numpy as np | |
import numpy.matlib | |
import matplotlib.pyplot as plt | |
def predict(X, W):
    """Return the linear-model predictions ``X . W``.

    In this script X is the design matrix (one sample per row, bias column
    appended) and W holds one column of weights per output, so the result
    has one row per sample and one column per output.
    """
    predictions = np.dot(X, W)
    return predictions
def gradient(X, Y, W, regTerm=0):
    """Return the gradient of the regularised squared-error cost w.r.t. W.

    Parameters
    ----------
    X : array of shape (m, n)
        Design matrix, one sample per row.
    Y : array of shape (m, k)
        Targets, one column per output.
    W : array of shape (n, k)
        Current weights.
    regTerm : float, optional
        L2 regularisation strength; 0 (the default) disables it.

    Returns
    -------
    array of shape (n, k)
        ``X^T (X W - Y) / (m k) + regTerm * W / (n k)``, the derivative of
        ``||Y - X W||_F^2 / (2 m k) + regTerm * ||W||_F^2 / (2 n k)``.
    """
    m, k = Y.shape
    n = W.shape[0]
    # Working from the residual avoids forming the n-by-n Gram matrix X^T X
    # on every call; mathematically it equals (-X^T Y + X^T X W).
    residual = np.dot(X, W) - Y
    return np.dot(X.T, residual) / (m * k) + regTerm * W / (n * k)
def cost(X, Y, W, regTerm=0):
    """Return the (optionally L2-regularised) mean squared-error cost.

    cost = ||Y - X W||_F^2 / (2 m k) + regTerm * ||W||_F^2 / (2 n k)

    Parameters
    ----------
    X : array of shape (m, n)
        Design matrix, one sample per row.
    Y : array of shape (m, k)
        Targets, one column per output.
    W : array of shape (n, k)
        Weights.
    regTerm : float, optional
        L2 regularisation strength; 0 (the default) disables it.

    Returns
    -------
    numpy scalar
        The cost value.
    """
    m, k = Y.shape
    n = W.shape[0]
    residual = Y - np.dot(X, W)
    # Sum of squared entries == trace(R R^T), but without materialising an
    # m-by-m (or n-by-n for W) matrix just to read its diagonal.
    data_term = np.sum(residual ** 2) / (2 * m * k)
    reg_term = regTerm * np.sum(W ** 2) / (2 * n * k)
    return data_term + reg_term
def Rsquared(X, Y, W):
    """Return the coefficient of determination R^2 of the model ``X . W``.

    R^2 = 1 - SS_res / SS_tot, where
      SS_res = sum of squared residuals  sum((Y - X W)^2)
      SS_tot = sum of squared deviations of Y from its per-column mean.
    Both sums are pooled over every output column of Y, giving one score.

    Parameters
    ----------
    X : array of shape (m, n)
        Design matrix, one sample per row.
    Y : array of shape (m, k)
        Targets, one column per output.
    W : array of shape (n, k)
        Weights.

    Returns
    -------
    numpy scalar
        R^2 (1.0 for a perfect fit; can be negative for a poor one). If Y is
        constant SS_tot is 0 and numpy division yields inf/nan with a
        RuntimeWarning rather than raising.
    """
    # SS_res must be a plain sum of squares: the scaled training cost
    # (divided by 2*m*k) is not comparable to SS_tot.
    residual = Y - np.dot(X, W)
    ss_res = np.sum(residual ** 2)
    # Column means broadcast across rows; no need to tile them with repmat.
    deviation = Y - np.mean(Y, axis=0)
    ss_tot = np.sum(deviation ** 2)
    return 1 - ss_res / ss_tot
# Load the full dataset (features and targets) saved as NumPy arrays.
trX = np.load("trXnew.npy")
trY = np.load("trYnew.npy")

# Hold out the first mTest rows as the test set and train on the rest.
# NOTE(review): rows are not shuffled, so this assumes the file order carries
# no meaning (e.g. not sorted by target) - confirm against the data source.
mTest = 15
teX = trX[:mTest, :]
teY = trY[:mTest, :]
# Start at mTest (not mTest + 1) so no sample is silently dropped from both sets.
trX = trX[mTest:, :]
trY = trY[mTest:, :]
mTrain = trX.shape[0]
k = trY.shape[1]

# Append a column of ones so the last row of W acts as the intercept.
# Size the test column from teX itself in case the data has < mTest rows.
trX = np.concatenate((trX, np.ones((mTrain, 1))), axis=1)
teX = np.concatenate((teX, np.ones((teX.shape[0], 1))), axis=1)
n = trX.shape[1]

# Plain batch gradient descent from a random start, recording the cost
# before each step so convergence can be plotted.
W = np.random.rand(n, k)
numIter = 100
learningRate = 0.000001
costArray = np.zeros((numIter, 1))
for i in range(numIter):
    costArray[i, 0] = cost(trX, trY, W)
    W = W - learningRate * gradient(trX, trY, W)

plt.plot(costArray)
plt.xlabel("Iteration")
plt.ylabel("Cost")
plt.show()

print("train Rsquared is %lf" % Rsquared(trX, trY, W))
print("test Rsquared is %lf" % Rsquared(teX, teY, W))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.