Skip to content

Instantly share code, notes, and snippets.

@stengoes
Created December 21, 2021 07:43
Show Gist options
  • Save stengoes/b94417c2d5e161507dbd61a4c8c8e211 to your computer and use it in GitHub Desktop.
Save stengoes/b94417c2d5e161507dbd61a4c8c8e211 to your computer and use it in GitHub Desktop.
CosineDecayWithWarmup
class CosineDecayWithWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup followed by a half-cosine decay to zero.

    * ``step < warmup_steps``: the learning rate ramps linearly from 0 to
      ``base_lr``.
    * ``warmup_steps <= step <= total_steps``: it follows a half-cosine from
      ``base_lr`` down to 0.
    * ``step > total_steps``: it stays at 0 (progress is clipped so the
      cosine does not climb back up).

    Args:
        warmup_steps: Number of linear-warmup steps. 0 disables warmup.
        total_steps: Step at which the cosine reaches 0; expected to be
            ``>= warmup_steps``.
        base_lr: Peak learning rate, reached at the end of warmup.
    """

    def __init__(
        self,
        warmup_steps,
        total_steps,
        base_lr=0.001,
    ):
        super().__init__()
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.base_lr = base_lr

    def __call__(
        self,
        step
    ):
        """Return the learning rate for ``step`` as a float32 scalar tensor."""
        # Normalise everything to float32. Newer Keras optimizers pass the raw
        # int64 ``iterations`` counter. Without this cast, the cosine branch
        # would come out as float64 while the warmup branch is float32, and
        # tf.where would reject the mismatched dtypes.
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        total_steps = tf.cast(self.total_steps, tf.float32)
        base_lr = tf.cast(self.base_lr, tf.float32)

        # tf.maximum(..., 1.0) avoids dividing by zero when warmup_steps == 0.
        # In that case this branch is never selected anyway, because
        # step < 0 is false. For integer step counts >= 1 it is a no-op.
        warmup_lr = base_lr * step / tf.maximum(warmup_steps, 1.0)

        # Same guard for total_steps == warmup_steps, which previously
        # produced 0/0 = NaN at the end of warmup.
        decay_steps = tf.maximum(total_steps - warmup_steps, 1.0)
        # Clip to [0, 1] so the LR stays at 0 after total_steps instead of
        # following the cosine back up.
        progress = tf.clip_by_value((step - warmup_steps) / decay_steps, 0.0, 1.0)
        cosine_lr = base_lr * (tf.cos(math.pi * progress) + 1.0) / 2.0

        return tf.where(step < warmup_steps, warmup_lr, cosine_lr)

    def get_config(self):
        """Return the constructor arguments so the schedule can be serialised."""
        return {
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps,
            "base_lr": self.base_lr
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment