Created September 10, 2021 09:52
Save jseppanen/d220f5d0ba0b5f1b8859f8595a195827 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import warnings

from torch.optim.lr_scheduler import LambdaLR
class SanderLR(LambdaLR):
    """Linear warmup -> constant -> 10x decay at ~80% and ~95% of training.

    Badly tuned LR decay schedules are an excellent way to silently shoot yourself
    in the foot. Models can often look like they are converging but it's just LR
    getting too low too fast. FixedLR (+optional warmup) with 1 manual decay of 10X
    on plateau is a safe strong baseline.
    https://twitter.com/karpathy/status/1431380525759885313?s=20

    My go-to is Adam with LR 2e-4 (not 3e-4, I guess I'm old-fashioned) with two 10x
    decreases ~80% and ~95% of the way through. I find myself deviating from this very
    rarely. Usually there are plenty of other things to try that have a higher pay-off
    than tuning the learning rate.
    https://twitter.com/sedielem/status/1431382567299821569?s=20
    """

    def __init__(self, optimizer, *, warmup_steps: int = 100, total_steps: int = 0):
        """Build the schedule.

        Args:
            optimizer: wrapped torch optimizer whose base LR this schedule scales.
            warmup_steps: steps over which LR ramps linearly from 0 to its base
                value; 0 disables warmup entirely.
            total_steps: total training length in steps; must be positive.

        Raises:
            ValueError: if total_steps is not positive, or warmup extends past
                the first decay point (the schedule would never reach full LR).
        """
        if total_steps <= 0:
            raise ValueError("please provide total training length in total_steps")
        first_decay_steps = int(0.8 * total_steps)
        second_decay_steps = int(0.95 * total_steps)
        if warmup_steps >= first_decay_steps:
            raise ValueError(f"warmup is too long: {warmup_steps} >= {first_decay_steps}")

        def lr_lambda(current_step: int) -> float:
            # Multiplicative factor applied to the optimizer's base LR.
            if current_step > total_steps:
                warnings.warn(
                    f"learning rate scheduler total_steps is too short: "
                    f"{current_step} > {total_steps}"
                )
            # Guard warmup_steps == 0 (no warmup): the original division would
            # raise ZeroDivisionError on every step.
            if warmup_steps <= 0:
                warmup_coef = 1.0
            else:
                warmup_coef = min(1.0, current_step / warmup_steps)
            if current_step < first_decay_steps:
                decay_coef = 1.0
            elif current_step < second_decay_steps:
                decay_coef = 0.1
            else:
                decay_coef = 0.01
            return warmup_coef * decay_coef

        super().__init__(optimizer, lr_lambda)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.