@jseppanen
Created September 10, 2021 09:52
import warnings

from torch.optim.lr_scheduler import LambdaLR


class SanderLR(LambdaLR):
    """Linear warmup -> constant -> 10x decay.

    Badly tuned LR decay schedules are an excellent way to silently shoot yourself
    in the foot. Models can often look like they are converging but it's just LR
    getting too low too fast. FixedLR (+optional warmup) with 1 manual decay of 10X
    on plateau is a safe strong baseline.
    https://twitter.com/karpathy/status/1431380525759885313?s=20

    My go-to is Adam with LR 2e-4 (not 3e-4, I guess I'm old-fashioned) with two 10x
    decreases ~80% and ~95% of the way through. I find myself deviating from this very
    rarely. Usually there are plenty of other things to try that have a higher pay-off
    than tuning the learning rate.
    https://twitter.com/sedielem/status/1431382567299821569?s=20
    """

    def __init__(self, optimizer, *, warmup_steps: int = 100, total_steps: int = 0):
        if total_steps <= 0:
            raise ValueError("please provide total training length in total_steps")
        # decay the learning rate 10x at 80% of training and again at 95%
        first_decay_steps = int(0.8 * total_steps)
        second_decay_steps = int(0.95 * total_steps)
        if warmup_steps >= first_decay_steps:
            raise ValueError(f"warmup is too long: {warmup_steps} >= {first_decay_steps}")

        def lr_lambda(current_step: int) -> float:
            if current_step > total_steps:
                warnings.warn(
                    f"learning rate scheduler total_steps is too short: "
                    f"{current_step} > {total_steps}"
                )
            # ramp up linearly from 0 to the base LR over warmup_steps
            warmup_coef = min(1.0, current_step / warmup_steps)
            # multiply the base LR by 0.1 after the first decay point, 0.01 after the second
            decay_coef = (
                1.0 if current_step < first_decay_steps else (
                    0.1 if current_step < second_decay_steps else 0.01
                )
            )
            return warmup_coef * decay_coef

        super().__init__(optimizer, lr_lambda)
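
A minimal usage sketch (the placeholder model, the base LR of 2e-4 and the 10,000-step
training length are illustrative assumptions, not part of the gist): attach SanderLR to
an Adam optimizer and call scheduler.step() once per training step.

    import torch

    model = torch.nn.Linear(10, 2)  # placeholder model
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    scheduler = SanderLR(optimizer, warmup_steps=100, total_steps=10_000)

    for step in range(10_000):
        # ... forward/backward pass and loss.backward() would go here ...
        optimizer.step()
        scheduler.step()
        # LR ramps linearly to 2e-4 over the first 100 steps, stays constant,
        # then drops to 2e-5 at step 8,000 (80%) and to 2e-6 at step 9,500 (95%).

Since LambdaLR just multiplies each parameter group's base LR by lr_lambda(step), the
same schedule shape applies whatever base LR you configure on the optimizer.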