Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
class Schedule(LearningRateSchedule):
    """Warmup-then-inverse-square-root learning-rate schedule.

    Computes::

        lr(step) = num_neurons ** -0.5 * min(step ** -0.5,
                                             step * warmup_steps ** -1.5)

    This is the schedule formula of Vaswani et al. (2017), with
    ``num_neurons`` in the role of ``d_model``. The rate grows linearly
    while ``step < warmup_steps``. The two terms of the ``min`` are equal
    at ``step == warmup_steps``, and after that the rate decays as
    ``step ** -0.5``. At ``step == 0`` the result is 0, because
    ``rsqrt(0)`` is ``inf`` and the linear term is 0.

    Args:
        num_neurons: Model width used to scale the whole schedule.
        warmup_steps: Number of linear-warmup steps (default 4000).
    """

    def __init__(self, num_neurons, warmup_steps=4000):
        super(Schedule, self).__init__()
        # Keep the constructor argument as given so get_config can
        # serialize it; the float32 cast below is a tensor.
        self._num_neurons_arg = num_neurons
        self.num_neurons = tf.cast(num_neurons, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate for ``step``.

        The step is cast to float32 first. Optimizers commonly pass an
        integer step counter, and ``tf.math.rsqrt`` does not accept
        integer dtypes.
        """
        step = tf.cast(step, tf.float32)
        warmup = tf.cast(self.warmup_steps, tf.float32)
        arg1 = tf.math.rsqrt(step)            # decay branch: step ** -0.5
        arg2 = step * (warmup ** -1.5)        # linear warmup branch
        return tf.math.rsqrt(self.num_neurons) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        """Return constructor arguments so Keras can re-create the schedule."""
        return {
            "num_neurons": self._num_neurons_arg,
            "warmup_steps": self.warmup_steps,
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment