@soulmachine
Created May 6, 2019 01:26
LearningRateBatchScheduler copied from experimental/resnet50_keras/resnet50.py
# Copied from https://github.com/tensorflow/tpu/blob/master/models/experimental/resnet50_keras/resnet50.py#L117
import numpy as np
import tensorflow as tf

from absl import logging
from tensorflow.keras import backend as K

BASE_LEARNING_RATE = 0.4

# Learning rate schedule
LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]

def learning_rate_schedule_wrapper(training_steps_per_epoch):
  """Wrapper around the learning rate schedule."""

  def learning_rate_schedule(current_epoch, current_batch):
    """Handles linear scaling rule, gradual warmup, and LR decay.

    The learning rate starts at 0, then it increases linearly per step.
    After 5 epochs we reach the base learning rate (scaled to account
    for batch size).
    After 30, 60 and 80 epochs the learning rate is divided by 10.
    After 90 epochs training stops and the LR is set to 0. This ensures
    that we train for exactly 90 epochs for reproducibility.

    Args:
      current_epoch: integer, current epoch indexed from 0.
      current_batch: integer, current batch in current epoch, indexed from 0.

    Returns:
      Adjusted learning rate.
    """
    epoch = current_epoch + float(current_batch) / training_steps_per_epoch
    warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
    if epoch < warmup_end_epoch:
      # Learning rate increases linearly per step during warmup.
      return (BASE_LEARNING_RATE * warmup_lr_multiplier *
              epoch / warmup_end_epoch)
    for mult, start_epoch in LR_SCHEDULE:
      if epoch >= start_epoch:
        learning_rate = BASE_LEARNING_RATE * mult
      else:
        break
    return learning_rate

  return learning_rate_schedule
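
# Worked example of the schedule above (illustration only, assuming
# training_steps_per_epoch = 1000):
#   epoch 0,  batch 0   -> 0.4 * 1.0 * 0.0 / 5 = 0.0     (warmup starts at 0)
#   epoch 2,  batch 500 -> 0.4 * 1.0 * 2.5 / 5 = 0.2     (linear warmup)
#   epoch 5,  batch 0   -> 0.4 * 1.0           = 0.4     (base LR reached)
#   epoch 30, batch 0   -> 0.4 * 0.1           = 0.04    (first decay)
#   epoch 60, batch 0   -> 0.4 * 0.01          = 0.004   (second decay)
#   epoch 80, batch 0   -> 0.4 * 0.001         = 0.0004  (final decay)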

class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
  """Callback to update learning rate on every batch (not epoch boundaries).

  N.B. Only supports Keras optimizers, not TF optimizers.

  Args:
    schedule: a function that takes an epoch index and a batch index as input
      (both integers, indexed from 0) and returns a new learning rate as
      output (float).
  """

  def __init__(self, schedule):
    super(LearningRateBatchScheduler, self).__init__()
    self.schedule = schedule
    self.epochs = -1
    self.prev_lr = -1

  def on_epoch_begin(self, epoch, logs=None):
    if not hasattr(self.model.optimizer, 'lr'):
      raise ValueError('Optimizer must have a "lr" attribute.')
    self.epochs += 1

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    logs['lr'] = K.get_value(self.model.optimizer.lr)

  def on_batch_begin(self, batch, logs=None):
    lr = self.schedule(self.epochs, batch)
    if not isinstance(lr, (float, np.float32, np.float64)):
      raise ValueError('The output of the "schedule" function should be float.')
    if lr != self.prev_lr:
      K.set_value(self.model.optimizer.lr, lr)
      self.prev_lr = lr
      logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler change '
                    'learning rate to %s.', self.epochs, batch, lr)
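
A minimal usage sketch, showing one way to wire learning_rate_schedule_wrapper and LearningRateBatchScheduler into tf.keras training. build_model(), train_dataset, and the dataset/batch-size constants are hypothetical placeholders, not part of the snippet above.

# Usage sketch. Assumptions: build_model() returns a tf.keras.Model and
# train_dataset is a batched tf.data.Dataset of (image, label) pairs.
NUM_TRAIN_IMAGES = 1281167  # e.g. the ImageNet train split
BATCH_SIZE = 1024
training_steps_per_epoch = NUM_TRAIN_IMAGES // BATCH_SIZE

model = build_model()  # hypothetical helper
model.compile(
    optimizer=tf.keras.optimizers.SGD(BASE_LEARNING_RATE, momentum=0.9),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

# The callback queries the schedule at every batch and only writes to the
# optimizer when the learning rate actually changes.
lr_callback = LearningRateBatchScheduler(
    schedule=learning_rate_schedule_wrapper(training_steps_per_epoch))

model.fit(
    train_dataset,
    epochs=90,
    steps_per_epoch=training_steps_per_epoch,
    callbacks=[lr_callback])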