@bdhammel
Last active March 30, 2019 03:37
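"""Toy learning-rate scan.

Fit a one-parameter linear model y_ = w * x to data generated from
y = 2.5 * x with gradient descent, while increasing the learning rate every
step (linearly or exponentially), and plot loss versus learning rate for a
few random restarts.
"""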
import numpy as np
import matplotlib.pyplot as plt

BASE_LR = 1e-7   # learning rate the scan starts from
EPOCHS = 500     # maximum number of update steps per trial
TRIALS = 3       # number of random restarts


class LRscanner:
    """Increase the learning rate, linearly or exponentially, on each call."""

    def __init__(self, method, factor):
        self.method = method
        self.factor = factor

    def __call__(self, lr):
        if self.method == 'linear':
            lr += self.factor
        elif self.method == 'exp':
            lr *= self.factor
        return lr
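
# Note on the two schedules: starting from BASE_LR, after n steps the
# 'linear' mode gives lr = BASE_LR + n * factor, while the 'exp' mode gives
# lr = BASE_LR * factor**n, i.e. points that are evenly spaced on a log axis.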
class Loss:
    """Mean squared error between y and the model prediction w * x."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __call__(self, w):
        return np.mean((self.y - w * self.x)**2)
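
# For reference: this quadratic loss has the closed-form gradient
#   dL/dw = -2 * np.mean(x * (y - w * x)),
# but the scan in __main__ below estimates dL/dw numerically from
# successive (weight, loss) pairs instead.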

# Uncomment one or the other to try the different learning-rate schedules
increase_lr = LRscanner(method='linear', factor=5e-4)
# increase_lr = LRscanner(method='exp', factor=1.1)

if __name__ == "__main__":
    plt.ion()
    plt.close('all')

    # Relationship to recover. We want to find w such that y_ = w * x
    # matches y.
    x = np.linspace(0, 10, 100)
    y = 2.5 * x

    plt.figure()
    plt.plot(x, y, 'o')
    plt.xlabel("x")
    plt.ylabel("y")
    # Plot the loss, the mean squared error (y - y_)**2,
    # for various weights w
    w = np.linspace(-5, 10, 200)
    loss = np.mean((y - np.outer(w, x))**2, axis=1)

    plt.figure()
    plt.plot(w, loss)
    plt.xlabel("weight value")
    plt.ylabel("loss")
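    # Note: for this data the loss is a parabola in w,
    #   L(w) = np.mean(x**2) * (w - 2.5)**2,
    # so the curve above bottoms out at zero loss at w = 2.5.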

    # Find the best learning rate to use with gradient descent to find the
    # optimal value of w:
    #   w <- w - lr * dL/dw
    loss_fn = Loss(x, y)

    plt.figure()
    for trial in range(TRIALS):
        lr = BASE_LR

        # Randomly select a starting value for w and
        # compute the loss associated with this weight
        w_ = np.random.uniform(w.min(), w.max())
        loss_ = loss_fn(w_)

        weights = [w_]
        losses = [loss_]
        lrs = [lr]
        dldw = 1  # dummy initial gradient for the very first update

        # For each epoch, update the weight with gradient descent, then
        # increase the learning rate
        for _ in range(EPOCHS):

            # Update the weight and calculate the new loss (error)
            w_ -= lr * dldw
            loss_ = loss_fn(w_)

            # Estimate the gradient with a finite difference between the
            # current and previous (weight, loss) pairs
            dldw = (loss_ - losses[-1]) / (w_ - weights[-1])

            # Store the values
            lrs.append(lr)
            losses.append(loss_)
            weights.append(w_)

            # Increase the learning rate
            lr = increase_lr(lr)

            # Once we start to diverge, break out of the loop; there's
            # no coming back
            if loss_ > 1e7:
                break

        plt.plot(lrs, losses, label=f"trial: {trial}")

    plt.xlabel("learning rate")
    plt.ylabel("loss")
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()
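
    # A common heuristic for reading a suggested learning rate off curves
    # like these (illustrative, not part of the scan above) is to take the
    # lr at the steepest drop in loss before divergence, e.g. for the last
    # trial:
    #
    #   steepest = int(np.argmin(np.diff(np.log(losses))))
    #   print("suggested lr:", lrs[steepest])
    #
    # When run non-interactively, an explicit plt.show(block=True) may be
    # needed at the end to keep the figures open, since plt.ion() is enabled.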