Skip to content

Instantly share code, notes, and snippets.

@maskaravivek
Last active June 22, 2020 06:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maskaravivek/40c3fe20ea24238067bf6b79b505e7d3 to your computer and use it in GitHub Desktop.
Save maskaravivek/40c3fe20ea24238067bf6b79b505e7d3 to your computer and use it in GitHub Desktop.
from tensorflow import keras
class CustomLearningRateScheduler(keras.callbacks.Callback):
    """Keras callback that sets the optimizer's learning rate before every epoch.

    Args:
        schedule: Callable ``schedule(epoch, lr) -> new_lr``. It receives the
            epoch index that Keras passes to ``on_epoch_begin`` and the
            optimizer's current learning rate as a Python ``float``. It returns
            the learning rate to use for that epoch.
    """

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule

    def on_epoch_begin(self, epoch, logs=None):
        """Apply ``self.schedule`` to the optimizer's learning rate and log it.

        Args:
            epoch: Epoch index supplied by Keras.
            logs: Unused; part of the ``Callback`` hook signature.

        Raises:
            ValueError: If the model's optimizer has no ``learning_rate``
                attribute.
        """
        optimizer = self.model.optimizer
        # Check, read, and write the same attribute so they cannot disagree.
        # ``lr`` is only a deprecated alias of ``learning_rate`` and is absent
        # on newer optimizers.
        if not hasattr(optimizer, "learning_rate"):
            raise ValueError('Optimizer must have a "learning_rate" attribute.')
        # Go through the ``keras`` module imported at the top of the file.
        # That import (``from tensorflow import keras``) does not bind ``tf``.
        lr = float(keras.backend.get_value(optimizer.learning_rate))
        scheduled_lr = self.schedule(epoch, lr)
        # Write the new value before this epoch's first training step.
        keras.backend.set_value(optimizer.learning_rate, scheduled_lr)
        print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment