import numpy as np
import matplotlib.pyplot as plt

BASE_LR = 1e-7
EPOCHS = 500
TRIALS = 3


class LRscanner:
    """Grow the learning rate each step, either linearly or exponentially."""

    def __init__(self, method, factor):
        self.method = method
        self.factor = factor

    def __call__(self, lr):
        if self.method == 'linear':
            lr += self.factor
        elif self.method == 'exp':
            lr *= self.factor
        return lr
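
# After n scan steps the 'linear' schedule gives lr = BASE_LR + n * factor,
# while 'exp' gives lr = BASE_LR * factor**n. With BASE_LR = 1e-7 and
# factor = 1.1, the exponential scan crosses lr = 1 after roughly
# log(1e7) / log(1.1) ≈ 169 steps, well within the 500 epochs.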


class Loss:
    """Mean squared error of the linear model y_ = w * x against the data y."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __call__(self, w):
        return np.mean((self.y - w * self.x)**2)
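

# For reference: the MSE gradient also has a closed form. A minimal sketch,
# not used by the scan below (which estimates the gradient with finite
# differences instead):
def analytic_dldw(w, x, y):
    """d/dw of mean((y - w*x)**2) = 2 * mean(x * (w*x - y))."""
    return 2 * np.mean(x * (w * x - y))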


# Uncomment one or the other to try the two learning-rate schedules
increase_lr = LRscanner(method='linear', factor=5e-4)
# increase_lr = LRscanner(method='exp', factor=1.1)


if __name__ == "__main__":
    plt.ion()
    plt.close('all')

    # Equation to fit. We want to find w such that y_ = w * x matches y
    x = np.linspace(0, 10, 100)
    y = 2.5 * x

    plt.figure()
    plt.plot(x, y, 'o')
    plt.xlabel("x")
    plt.ylabel("y")

    # Plot the loss landscape: the mean squared error (y - y_)**2 for a sweep
    # of weights w. np.outer(w, x) gives a (200, 100) grid of predictions, so
    # averaging over axis=1 yields one loss value per candidate weight.
    w = np.linspace(-5, 10, 200)
    loss = np.mean((y - np.outer(w, x))**2, axis=1)

    plt.figure()
    plt.plot(w, loss)
    plt.xlabel("weight value")
    plt.ylabel("loss")

    # Find the best learning rate to use with gradient descent to reach the
    # optimal value of w:
    #   w <- w - lr * dl/dw
    loss_fn = Loss(x, y)

    plt.figure()
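    # This sweep is a learning-rate range test: start from a tiny lr, grow it
    # every step, and watch where the loss first dives and where it blows up.
    # (The idea is usually credited to Leslie Smith's work on cyclical
    # learning rates.)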
    for trial in range(TRIALS):
        lr = BASE_LR

        # Randomly select a starting value for w and compute the loss
        # associated with that weight
        w_ = np.random.uniform(w.min(), w.max())
        loss_ = loss_fn(w_)

        weights = [w_]
        losses = [loss_]
        lrs = [lr]
        dldw = 1  # seed gradient so the first update moves w at all

        # Each epoch: update the weight with gradient descent, then increase
        # the learning rate
        for _ in range(EPOCHS):
            # Update the weight and compute the new loss (error)
            w_ -= lr * dldw
            loss_ = loss_fn(w_)

            # Estimate the gradient as a finite difference between the last
            # two (weight, loss) points
            dldw = (loss_ - losses[-1]) / (w_ - weights[-1])

            # Store the values
            lrs.append(lr)
            losses.append(loss_)
            weights.append(w_)

            # Increase the learning rate
            lr = increase_lr(lr)

            # Once the loss starts to diverge, break out of the loop; there's
            # no coming back
            if loss_ > 1e7:
                break

        plt.plot(lrs, losses, label=f"trial: {trial}")

    plt.xlabel("learning rate")
    plt.ylabel("loss")
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()
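
    # Hypothetical helper, not part of the original gist: estimate a usable
    # learning rate from the last trial by taking the lr where log(loss)
    # drops fastest (the steepest-descent heuristic of the range test).
    slope = np.gradient(np.log(losses), np.log(lrs))
    best_lr = lrs[int(np.argmin(slope))]
    print(f"suggested lr from the last trial: {best_lr:.2e}")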