Exponential learning rate decay in PyTorch
def exp_lr_scheduler(optimizer, global_step, init_lr, decay_steps, decay_rate, lr_clip, staircase=True):
    """Decay the learning rate by a factor of `decay_rate` every `decay_steps` steps.

    With `staircase=True` the rate drops in discrete jumps; otherwise it
    decays smoothly with the step count. The result is clipped from below
    at `lr_clip`.
    """
    if staircase:
        lr = init_lr * decay_rate ** (global_step // decay_steps)
    else:
        lr = init_lr * decay_rate ** (global_step / decay_steps)
    # Never let the learning rate fall below the clip value.
    lr = max(lr, lr_clip)
    # Log only when a staircase boundary is crossed.
    if global_step % decay_steps == 0:
        print('LR is set to {}'.format(lr))
    # Apply the new learning rate to every parameter group.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer
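
For context, here is a minimal usage sketch showing the scheduler being called once per training step. The model, data, and hyperparameter values below are illustrative assumptions, not part of the gist.

import torch

# Toy setup; the model, data, and hyperparameters are illustrative assumptions.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
inputs = torch.randn(64, 10)
targets = torch.randint(0, 2, (64,))

for global_step in range(3000):
    # Recompute and apply the decayed learning rate before each step.
    optimizer = exp_lr_scheduler(
        optimizer, global_step,
        init_lr=0.1, decay_steps=1000, decay_rate=0.7, lr_clip=1e-5)
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

Note that PyTorch also ships torch.optim.lr_scheduler.StepLR and torch.optim.lr_scheduler.ExponentialLR, which provide similar staircase and smooth decay schedules, though without the lower clipping that lr_clip gives here.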