Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
LearningRateFinder for keras
class LearningRateFinder(Callback):
    """Callback implementing a learning rate finder (LRF).

    A different learning rate from the user-supplied list is applied at
    each epoch; the model weights are restored to their initial snapshot
    at the end of every epoch so that each candidate rate is evaluated
    from the same starting point. On training end, the training loss is
    plotted against the learning rate. One may choose a learning rate
    for a model based on the given graph, selecting a value slightly
    before the minimal training loss.

    # Example
        lrf = LearningRateFinder([0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05])
        model.fit(x_train, y_train, epochs=6, batch_size=128, callbacks=[lrf])

    # Arguments
        lrs: list of learning rates (floats) to try, one per epoch.
    """

    def __init__(self, lrs):
        super(LearningRateFinder, self).__init__()
        self.index = 0                # position in the candidate-rate list
        self.learningRateList = lrs   # candidate learning rates, one per epoch
        self.losses = []              # training loss recorded at each epoch end
        self.lrs = []                 # learning rate actually used at each epoch

    def on_train_begin(self, logs=None):
        # Snapshot the initial weights so every candidate rate is
        # evaluated from the same model state.
        self.initial_weights = self.model.get_weights()

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        lr = self.learningRateList[self.index]
        self.index = self.index + 1
        if self.index >= len(self.learningRateList):
            # Wrap around if training runs for more epochs than there
            # are candidate rates.
            self.index = 0
        if not isinstance(lr, (float, np.float32, np.float64)):
            raise ValueError('Learning rates given to '
                             'LearningRateFinder should be float.')
        K.set_value(self.model.optimizer.lr, lr)
        print('\nEpoch %05d: LearningRateFinder changing learning rate to %s.' % (epoch + 1, lr))

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Record the (learning rate, loss) pair for this epoch.
        lr = float(K.get_value(self.model.optimizer.lr))
        self.lrs.append(lr)
        self.losses.append(logs.get('loss'))
        # Restore the initial weights so the next candidate rate is
        # evaluated independently of this epoch's training.
        self.model.set_weights(self.initial_weights)

    def on_train_end(self, logs=None):
        # Plot training loss against learning rate; pick a rate slightly
        # before the minimum of this curve.
        plt.plot(self.lrs, self.losses)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment