import os

import tensorflow as tf

# Define the checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
# Name template for the checkpoint files (one file per epoch).
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Function for decaying the learning rate: 1e-3 for the first 3 epochs,
# 1e-4 for epochs 3 through 6, and 1e-5 from epoch 7 onwards.
def decay(epoch):
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5

# Callback for printing the learning rate at the end of each epoch.
# Defined before the callbacks list so PrintLR() exists when referenced;
# self.model is set by Keras when the callback is attached during fit().
class PrintLR(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print('\nLearning rate for epoch {} is {}'.format(
            epoch + 1, self.model.optimizer.lr.numpy()))

# Define the callbacks: TensorBoard logging, per-epoch weight checkpoints,
# the learning-rate schedule above, and the LR printer.
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
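
For context, a minimal sketch of how these callbacks plug into training. The gist itself defines no model or data, so the tiny model and random arrays below are purely illustrative stand-ins:

import numpy as np

# Hypothetical model and data, only to exercise the callbacks above;
# substitute your own model and dataset.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse')

x = np.random.rand(256, 10).astype('float32')
y = np.random.rand(256, 1).astype('float32')

model.fit(x, y, epochs=10, callbacks=callbacks)

# Because ModelCheckpoint saves weights only, restore by loading the
# latest checkpoint into a model with the same architecture:
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

Saving weights only keeps the checkpoints small; the trade-off is that restoring requires rebuilding the same architecture in code before calling load_weights.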