Skip to content

Instantly share code, notes, and snippets.

@matttrent
Last active October 25, 2019 07:04
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save matttrent/87a5261d0a5c59f2b7f16d8f1fc3ef0a to your computer and use it in GitHub Desktop.
Save matttrent/87a5261d0a5c59f2b7f16d8f1fc3ef0a to your computer and use it in GitHub Desktop.
Keras learning-rate schedule built from a list of (rate, epochs) pairs
def list_rate_schedule(lrates, output=True):
    """Build an epoch -> learning-rate function from (rate, n_epochs) pairs.

    The pairs are expanded in order into a per-epoch table. For example,
    ``[(1e-5, 2), (1e-3, 4)]`` uses 1e-5 for epochs 0-1 and 1e-3 for
    epochs 2-5. Any epoch past the end of the table reuses the last rate.

    Args:
        lrates: iterable of ``(learning_rate, n_epochs)`` pairs.
        output: if true, print the rate whenever it differs from the rate
            last printed by this schedule.

    Returns:
        A function ``lr_sched(epoch, current_lr=None)`` suitable for
        ``keras.callbacks.LearningRateScheduler``. The optional second
        argument is accepted and ignored, so both the ``schedule(epoch)``
        and ``schedule(epoch, lr)`` calling conventions work.

    Raises:
        ValueError: if ``lrates`` covers no epochs at all (empty, or every
            ``n_epochs`` is zero). Previously this surfaced later as an
            IndexError from inside the schedule.
    """
    sched = []
    for lr, n in lrates:
        sched += [lr] * n
    if not sched:
        raise ValueError('lrates must cover at least one epoch')

    # A one-element list acts as a mutable cell that the closure can
    # update. It starts at None rather than 0 so that a first rate of 0
    # is still reported.
    last_lr = [None]

    def lr_sched(epoch, current_lr=None):
        # Epochs beyond the table keep using the final scheduled rate.
        lr = sched[epoch] if epoch < len(sched) else sched[-1]
        if output and lr != last_lr[0]:
            print('Learning rate: {}'.format(lr))
            last_lr[0] = lr
        return lr

    return lr_sched
# usage ----------------------------------------------------
# Each entry is (learning_rate, number_of_epochs). Epochs past the end of
# the list keep using the last rate.
rates = [
    (1e-5, 2),  # 2 epochs @ 1e-5
    (1e-3, 4),  # 4 epochs @ 1e-3
    (1e-4, 8),  # 8 epochs @ 1e-4
    # ...
]
lrsched = keras.callbacks.LearningRateScheduler(
    list_rate_schedule(rates))
model.fit_generator(
    trn_batches,
    # steps per epoch: number of full batches in the training set
    trn_batches.samples // trn_batches.batch_size,
    # Spelled `callacks` previously, which is not a valid fit_generator
    # keyword, so the scheduler was never attached.
    callbacks=[lrsched],
    # ...
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment