from matplotlib import pyplot as plt
import math
from tensorflow.keras.callbacks import LambdaCallback
import tensorflow.keras.backend as K
import numpy as np

class LRFinder:
    """
    Plots the change of the loss function of a
    Keras model when the learning rate is exponentially increasing.
    See for details:
    https://towardsdatascience.com/estimating-optimal-learning-rate-for-a-deep-neural-network-ce32f2556ce0
    """

    def __init__(self, model):
        self.model = model
        self.losses = []
        self.lrs = []
        self.best_loss = 1e9
    def on_batch_end(self, batch, logs):
        # Log the learning rate
        lr = K.get_value(self.model.optimizer.learning_rate)
        self.lrs.append(lr)

        # Log the loss
        loss = logs['loss']
        self.losses.append(loss)

        # Stop training if the loss got too large or became NaN
        if batch > 5 and (math.isnan(loss) or loss > self.best_loss * 4):
            self.model.stop_training = True
            return

        if loss < self.best_loss:
            self.best_loss = loss

        # Increase the learning rate for the next batch
        lr *= self.lr_mult
        K.set_value(self.model.optimizer.learning_rate, lr)
    def find(self, x_train, y_train, start_lr, end_lr, batch_size=64, epochs=1):
        # If x_train contains data for multiple inputs, use the length of the first input.
        # Assumption: the first element in the list is a single input; NOT a list of inputs.
        N = x_train[0].shape[0] if isinstance(x_train, list) else x_train.shape[0]

        # Compute the number of batches and the per-batch LR multiplier
        num_batches = epochs * N / batch_size
        self.lr_mult = (float(end_lr) / float(start_lr)) ** (1.0 / num_batches)

        # Save weights into a file so they can be restored after the search
        self.model.save_weights('tmp.h5')

        # Remember the original learning rate
        original_lr = K.get_value(self.model.optimizer.learning_rate)

        # Set the initial learning rate
        K.set_value(self.model.optimizer.learning_rate, start_lr)

        callback = LambdaCallback(on_batch_end=lambda batch, logs: self.on_batch_end(batch, logs))

        self.model.fit(x_train, y_train,
                       batch_size=batch_size, epochs=epochs,
                       callbacks=[callback])

        # Restore the weights to the state before model fitting
        self.model.load_weights('tmp.h5')

        # Restore the original learning rate
        K.set_value(self.model.optimizer.learning_rate, original_lr)
    def find_generator(self, generator, start_lr, end_lr, epochs=1, steps_per_epoch=None, **kw_fit):
        if steps_per_epoch is None:
            try:
                steps_per_epoch = len(generator)
            except (TypeError, ValueError, NotImplementedError) as e:
                raise ValueError('`steps_per_epoch=None` is only valid for a'
                                 ' generator based on the'
                                 ' `keras.utils.Sequence` class.'
                                 ' Please specify `steps_per_epoch` or use'
                                 ' the `keras.utils.Sequence` class.') from e

        self.lr_mult = (float(end_lr) / float(start_lr)) ** (1.0 / (epochs * steps_per_epoch))

        # Save weights into a file so they can be restored after the search
        self.model.save_weights('tmp.h5')

        # Remember the original learning rate
        original_lr = K.get_value(self.model.optimizer.learning_rate)

        # Set the initial learning rate
        K.set_value(self.model.optimizer.learning_rate, start_lr)

        callback = LambdaCallback(on_batch_end=lambda batch, logs: self.on_batch_end(batch, logs))

        self.model.fit_generator(generator=generator,
                                 epochs=epochs,
                                 steps_per_epoch=steps_per_epoch,
                                 callbacks=[callback],
                                 **kw_fit)

        # Restore the weights to the state before model fitting
        self.model.load_weights('tmp.h5')

        # Restore the original learning rate
        K.set_value(self.model.optimizer.learning_rate, original_lr)
    def plot_loss(self, n_skip_beginning=10, n_skip_end=5, x_scale='log'):
        """
        Plots the loss.
        Parameters:
            n_skip_beginning - number of batches to skip on the left.
            n_skip_end - number of batches to skip on the right.
        """
        plt.ylabel("loss")
        plt.xlabel("learning rate (log scale)")
        plt.plot(self.lrs[n_skip_beginning:-n_skip_end], self.losses[n_skip_beginning:-n_skip_end])
        plt.xscale(x_scale)
        plt.show()

    def plot_loss_change(self, sma=1, n_skip_beginning=10, n_skip_end=5, y_lim=(-0.01, 0.01)):
        """
        Plots the rate of change of the loss function.
        Parameters:
            sma - number of batches for the simple moving average used to smooth out the curve.
            n_skip_beginning - number of batches to skip on the left.
            n_skip_end - number of batches to skip on the right.
            y_lim - limits for the y axis.
        """
        derivatives = self.get_derivatives(sma)[n_skip_beginning:-n_skip_end]
        lrs = self.lrs[n_skip_beginning:-n_skip_end]
        plt.ylabel("rate of loss change")
        plt.xlabel("learning rate (log scale)")
        plt.plot(lrs, derivatives)
        plt.xscale('log')
        plt.ylim(y_lim)
        plt.show()
    def get_derivatives(self, sma):
        assert sma >= 1
        derivatives = [0] * sma
        for i in range(sma, len(self.lrs)):
            derivatives.append((self.losses[i] - self.losses[i - sma]) / sma)
        return derivatives
    def get_best_lr(self, sma, n_skip_beginning=10, n_skip_end=5):
        derivatives = self.get_derivatives(sma)
        # Pick the learning rate at which the loss decreases fastest,
        # i.e. the most negative smoothed derivative.
        best_der_idx = np.argmin(derivatives[n_skip_beginning:-n_skip_end])
        return self.lrs[n_skip_beginning:-n_skip_end][best_der_idx]
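
# Example usage: a minimal sketch of how this class might be driven. The tiny
# dense model, the random data, and the start/end learning rates below are
# illustrative assumptions, not part of the original gist; substitute your own
# compiled Keras model and training set.
if __name__ == '__main__':
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense

    x_train = np.random.rand(4096, 20).astype('float32')
    y_train = np.random.rand(4096, 1).astype('float32')

    model = Sequential([Dense(32, activation='relu', input_shape=(20,)),
                        Dense(1)])
    model.compile(optimizer='adam', loss='mse')

    lr_finder = LRFinder(model)
    lr_finder.find(x_train, y_train, start_lr=1e-6, end_lr=1.0,
                   batch_size=64, epochs=1)

    lr_finder.plot_loss()              # loss vs. learning rate (log-scale x axis)
    lr_finder.plot_loss_change(sma=5)  # smoothed rate of loss change
    print('Suggested learning rate:', lr_finder.get_best_lr(sma=5))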