
@kohya-ss
Created May 22, 2024 13:40
Constant until a specified step, then a linear decrease, then constant again
# If you place this file in `logs`, specify it as follows:
# --lr_scheduler_type logs.stepwise_linear_decay_lr_scheduler.get_stepwise_linear_decay_lr_scheduler
# --lr_scheduler_args "step_a=50" "step_b=80" "factor_1=1.0" "factor_2=0.1"
#
# Learning rate up to step_a: the specified learning_rate * factor_1
# From step_a to step_b: linear decrease (or increase)
# Learning rate from step_b onward: the specified learning_rate * factor_2
from torch.optim.lr_scheduler import LambdaLR


def get_stepwise_linear_decay_lr_scheduler(optimizer, step_a, step_b, factor_1, factor_2):
    def lr_lambda(step):  # LambdaLR passes the current step count (its `last_epoch`)
        if step < step_a:
            return factor_1
        elif step < step_b:
            # Interpolate linearly from factor_1 (at step_a) to factor_2 (at step_b)
            scale = (step_b - step) / (step_b - step_a)
            return factor_1 * scale + factor_2 * (1 - scale)
        else:
            return factor_2

    return LambdaLR(optimizer, lr_lambda)
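As a sanity check, here is a minimal standalone sketch (not part of the gist) that exercises the schedule with a dummy SGD optimizer. The `step_a=50, step_b=80, factor_1=1.0, factor_2=0.1` values mirror the example arguments above; the function is repeated here only so the snippet runs on its own.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR


# Same scheduler as above, repeated so this sketch is self-contained.
def get_stepwise_linear_decay_lr_scheduler(optimizer, step_a, step_b, factor_1, factor_2):
    def lr_lambda(step):
        if step < step_a:
            return factor_1
        elif step < step_b:
            scale = (step_b - step) / (step_b - step_a)
            return factor_1 * scale + factor_2 * (1 - scale)
        else:
            return factor_2

    return LambdaLR(optimizer, lr_lambda)


# Dummy parameter and optimizer, base learning rate 1e-3.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=1e-3)
scheduler = get_stepwise_linear_decay_lr_scheduler(optimizer, 50, 80, 1.0, 0.1)

lrs = []
for step in range(100):
    lrs.append(optimizer.param_groups[0]["lr"])  # lr in effect at this step
    optimizer.step()
    scheduler.step()

# lrs[0]  -> 1e-3       (constant at learning_rate * factor_1 before step 50)
# lrs[65] -> 5.5e-4     (halfway through the linear segment, 50 -> 80)
# lrs[90] -> 1e-4       (constant at learning_rate * factor_2 after step 80)
```

Note that `LambdaLR` multiplies the optimizer's base learning rate by the returned factor, so the factors are relative, not absolute learning rates.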