Last active
August 5, 2016 16:56
-
-
Save OriaGr/843736ceb6b95a81a1fa1d03fbad1b1b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Written by Oria Gruber | |
#Related blog post - https://oriamathematics.wordpress.com/2016/08/04/intro-to-machine-learning-linear-regression/ | |
import numpy as np | |
import matplotlib.pyplot as plt | |
#Code for blog post regarding Linear Regression | |
def predict(X, W):
    """Return the linear model's predictions for the rows of X.

    The prediction is the matrix product of the design matrix X with the
    weight matrix W, computed with ``np.dot``.  With the (N, 2) design matrix
    (first column all ones) and (2, 1) weights built in this script, the
    result is an (N, 1) column of predicted y values.
    """
    fitted = np.dot(X, W)
    return fitted
def cost(X, Y, W):
    """Return half the mean squared error of weights W on the data (X, Y).

    With residuals r = Y - predict(X, W) and N = Y.shape[0] samples, this is
    r^T r / (2N).  The factor 1/2 cancels the 2 produced when differentiating
    the square, which keeps the gradient simple.

    Because the value is built with ``np.dot`` on column vectors, for the
    (N, 1) targets used in this script it is a 1x1 array, not a Python float.
    """
    sampleCount = Y.shape[0]
    residuals = Y - predict(X, W)
    return np.dot(residuals.T, residuals) / (2 * sampleCount)
def gradient(X, Y, W):
    """Return the gradient of ``cost`` with respect to the weights W.

    For J(W) = (Y - XW)^T (Y - XW) / (2N), the gradient is

        dJ/dW = -X^T (Y - XW) / N

    A single matrix product gives every component at once.  It works for any
    number of columns in X and is the exact derivative of ``cost`` whatever
    the first column holds.  The previous version hard-coded two weights and
    used ``sum(residuals)`` for the bias, which is only correct when X[:, 0]
    is all ones.  For the script's design matrix (first column of ones,
    second column x), row 0 is still the bias derivative and row 1 the slope
    derivative, as before.

    Parameters
    ----------
    X : design matrix, one row per sample.
    Y : targets; an (N, 1) column in this script.
    W : current weights; (X.shape[1], 1) in this script.

    Returns
    -------
    ndarray with the same shape as W ((2, 1) in this script).
    """
    N = Y.shape[0]
    residuals = Y - predict(X, W)
    # Vectorised form.  It also avoids writing a size-1 array into a scalar
    # slot (grad[1, 0] = np.dot(...)), which NumPy >= 1.25 deprecates.
    return -np.dot(X.T, residuals) / N
def gradientDescent(X, Y, numOfIterations, learningRate):
    """Fit linear-regression weights with batch gradient descent.

    Starts from uniformly random weights (np.random.rand, so the result
    depends on NumPy's global RNG state).  It then performs
    ``numOfIterations`` updates of the form W <- W - learningRate * gradient.

    Parameters
    ----------
    X : design matrix, one row per sample; one weight is fitted per column.
    Y : targets; assumed to be an (N, 1) column, as built in this script.
    numOfIterations : number of gradient steps (0 returns the random start).
    learningRate : step size multiplying the gradient.

    Returns
    -------
    [W, costArray]
        W is the fitted (X.shape[1], 1) weight array.  costArray is a
        (numOfIterations, 1) array whose row i holds the cost *before* update
        i was applied.
    """
    # One weight per column of X.  For the script's 2-column X this is the
    # same (2, 1) draw as before, so RNG consumption is unchanged.
    W = np.random.rand(X.shape[1], 1)
    costArray = np.zeros((numOfIterations, 1))
    for i in range(numOfIterations):
        # cost() yields a 1x1 array.  Extract the scalar explicitly rather
        # than relying on NumPy's deprecated size-1-array-to-scalar
        # conversion during assignment.
        costArray[i, 0] = np.asarray(cost(X, Y, W)).item()
        W = W - learningRate * gradient(X, Y, W)
    return [W, costArray]
# --- Demo: simulate noisy linear data, fit it, and plot the results. ---
N = 50
X = np.random.rand(N, 2)
X[:, 0] = 1  # first column of ones, so W[0, 0] acts as the bias (intercept)
realWeights = np.random.rand(2, 1)
# Simulate Y = X . realWeights plus Gaussian noise with mean 0 and standard
# deviation 0.05 (variance 0.0025); np.random.randn draws from N(0, 1).
Y = np.dot(X, realWeights) + 0.05 * np.random.randn(N, 1)
numOfIterations = 1000
learningRate = 0.1
[W, costArray] = gradientDescent(X, Y, numOfIterations, learningRate)
# Plot cost against the actual iteration index (0 .. numOfIterations - 1).
# The axis is labelled "iteration", so it must not be rescaled to 0..100.
plt.plot(np.arange(numOfIterations), costArray)
plt.xlabel("iteration")
plt.ylabel("cost", rotation=0, fontsize=10)
plt.show()
print("real weights are %s" % realWeights)
print("fitted weights are %s" % W)
# Overlay the fitted line on the samples; the x values were drawn from [0, 1).
domain = np.linspace(0, 1, 1000)
plt.plot(X[:, 1], Y, 'o', domain, W[1, 0] * domain + W[0, 0], 'r')
plt.xlabel(r'$x$')
plt.ylabel(r'$y$', rotation=0)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment