Skip to content

Instantly share code, notes, and snippets.

@finlytics-hub
Last active October 24, 2020 17:37
Show Gist options
  • Save finlytics-hub/74346b9218ffb82a1f4a932e41f6d199 to your computer and use it in GitHub Desktop.
Save finlytics-hub/74346b9218ffb82a1f4a932e41f6d199 to your computer and use it in GitHub Desktop.
Linear Regression - Stochastic GD
'''
NORMAL EQUATION
'''
# numpy is used throughout this script but was never imported.
import numpy as np

# number of synthetic observations to generate
n_samples = 500
# generate some random numbers (independent variable), uniform on [0, 5)
X = 5 * np.random.rand(n_samples, 1)
# calculate linearly related (plus some noise) target variable in the form y = 10 + 2x + noise
y = 10 + 2 * X + np.random.randn(n_samples, 1)
# add a column of ones to X for each observation (X0, the intercept term)
X_2d = np.c_[np.ones((n_samples, 1)), X]
# Calculate the theta that minimizes MSE by solving the normal equation
#     (X^T X) theta = X^T y
# np.linalg.solve solves the linear system directly instead of forming the
# explicit inverse, which is both slower and less numerically stable.
theta_best = np.linalg.solve(X_2d.T @ X_2d, X_2d.T @ y)
print('Normal Equation:\n', theta_best)
'''
BATCH GRADIENT DESCENT
'''
# set the learning rate
eta = 0.1
# max number of iterations for GD to try and converge
n_iterations = 1000
# number of observations in the training data; derived from X_2d so it can
# never drift out of sync with the data that was actually generated
m = len(X_2d)
# random initialization of theta = [intercept, slope] (one row per X_2d column)
theta = np.random.randn(2, 1)
# perform iterations: each step uses the MSE gradient over the full batch,
#     2/m * X^T (X theta - y)
for _ in range(n_iterations):
    gradients = 2 / m * X_2d.T @ (X_2d @ theta - y)
    theta = theta - eta * gradients
# print theta values
print('Batch GD:\n', theta)
'''
Stochastic GD
'''
# number of passes (epochs) over the training data for Stochastic GD
n_epochs = 50
# learning schedule hyperparameters
t0, t1 = 5, 50


def learning_schedule(t):
    """Return the learning rate for global SGD step ``t``.

    The rate decays as ``t0 / (t + t1)``: it starts at ``t0 / t1`` (0.1 with
    the values above) and shrinks toward zero as ``t`` grows, which damps the
    noisy single-sample updates over time. The module-level hyperparameters
    ``t0`` and ``t1`` are read at call time.
    """
    return t0 / (t + t1)
# random initialization of theta = [intercept, slope]
theta = np.random.randn(2, 1)
# number of observations; recomputed here so this section does not depend on
# another section having set m to the right value
m = len(X_2d)
# perform iterations: n_epochs passes of m single-sample updates each
for epoch in range(n_epochs):
    for i in range(m):
        # pick up X and y values at random - random selection is critical for stochastic GD
        # (samples are drawn with replacement, so a row may repeat within an epoch)
        random_index = np.random.randint(m)
        # slicing (rather than scalar indexing) keeps both arrays 2-D
        x_random = X_2d[random_index:random_index + 1]
        y_random = y[random_index:random_index + 1]
        # squared-error gradient for this single observation
        gradients = 2 * x_random.T @ (x_random @ theta - y_random)
        # learning rate decays with the global step count
        eta = learning_schedule(epoch * m + i)
        theta = theta - eta * gradients
# print theta values
print('Stochastic GD:\n', theta)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment