Skip to content

Instantly share code, notes, and snippets.

@Mirodil
Last active December 14, 2018 17:13
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Mirodil/5340ac9950df0f3d52522f3ecc481aac to your computer and use it in GitHub Desktop.
Save Mirodil/5340ac9950df0f3d52522f3ecc481aac to your computer and use it in GitHub Desktop.
LearningRateFinder for Keras
class LearningRateFinder(Callback):
    """Keras callback that tries one learning rate per epoch from a fixed list.

    At the start of each epoch the model weights are restored to the values
    captured when training began, and the optimizer's learning rate is set to
    the next entry of ``lrs`` (wrapping back to the first entry after the
    last one). At the end of each epoch the training loss and the learning
    rate in effect are recorded. When training ends, the training loss is
    plotted against the learning rate; one may choose a learning rate for the
    model slightly below the one that gave the minimal training loss.

    Because each learning rate is used for exactly one epoch, ``fit`` should
    run for ``len(lrs)`` epochs to try every candidate once.

    # Example
        lrs = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05]
        lrf = LearningRateFinder(lrs)
        model.fit(x_train, y_train, epochs=len(lrs), batch_size=128,
                  callbacks=[lrf])

    # Arguments
        lrs: sequence of learning rates to try, in order. Must be non-empty
            and every entry must be convertible to ``float``.

    # Attributes
        losses: training loss recorded at the end of each epoch.
        lrs: optimizer learning rate recorded at the end of each epoch.

    # Raises
        ValueError: if ``lrs`` is empty, if the optimizer has no ``lr``
            attribute, or if an entry of ``lrs`` is not numeric.
    """

    def __init__(self, lrs):
        # Run the base Callback initialiser so any state it sets up exists;
        # the original implementation skipped this call.
        super(LearningRateFinder, self).__init__()
        if len(lrs) == 0:
            raise ValueError('lrs must contain at least one learning rate.')
        self.index = 0                # position of the next candidate to try
        self.learningRateList = lrs   # candidate learning rates, in order
        self.losses = []              # training loss recorded per epoch
        self.lrs = []                 # learning rate recorded per epoch

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's training loss and the learning rate used."""
        # Avoid a shared mutable default; Keras may pass None.
        logs = logs or {}
        lr = float(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs.get('loss'))
        self.lrs.append(lr)

    def on_epoch_begin(self, epoch, logs=None):
        """Restore the initial weights and switch to the next candidate lr.

        Raises:
            ValueError: if the optimizer has no ``lr`` attribute or the
                current candidate is not numeric.
        """
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        candidate = self.learningRateList[self.index]
        # Validate before touching any state, so a bad entry leaves both the
        # model weights and the cursor unchanged.
        try:
            lr = float(candidate)
        except (TypeError, ValueError):
            raise ValueError('Learning rate at index %d should be a number, '
                             'got %r.' % (self.index, candidate))
        # Start every trial from the same weights so the recorded losses are
        # comparable across learning rates.
        self.model.set_weights(self.initial_weights)
        # Advance to the next candidate, wrapping around after the last one.
        self.index = (self.index + 1) % len(self.learningRateList)
        K.set_value(self.model.optimizer.lr, lr)
        print('\nEpoch %05d: LearningRateFinder changing learning rate to %s.'
              % (epoch + 1, lr))

    def on_train_end(self, logs=None):
        """Plot the recorded training losses against their learning rates."""
        # Sort by learning rate so the line is drawn left-to-right even when
        # the candidates were not supplied in increasing order.
        points = sorted(zip(self.lrs, self.losses), key=lambda p: p[0])
        xs = [p[0] for p in points]
        ys = [p[1] for p in points]
        plt.plot(xs, ys)
        plt.ylabel('losses')
        plt.xlabel('lrs')
        plt.show()

    def on_train_begin(self, logs=None):
        """Snapshot the model weights so every epoch can start from them."""
        self.initial_weights = self.model.get_weights()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment