# Keras Optax Schedules
# @ryanholbrook, May 2021
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
from absl import logging
from typing import Dict, Optional, Sequence, Union
from tensorflow.keras.optimizers.schedules import LearningRateSchedule

# The tnp ops below mix integer `step` tensors with float constants, which
# relies on NumPy-style type promotion; this switch is available in TF >= 2.5.
tnp.experimental_enable_numpy_behavior()

# Schedules ported from Optax
# https://github.com/deepmind/optax/blob/master/optax/_src/schedule.py
class ConstantSchedule(LearningRateSchedule):
    """Constructs a constant schedule.

    Args:
      value: value to be held constant throughout.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        value: Union[float, int],
    ):
        self.value = value

    def __call__(self, step):
        return tf.constant(self.value, shape=tf.convert_to_tensor(step).shape)
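
# A minimal usage sketch (assumed, not part of the original gist): Keras
# passes the optimizer's integer step counter when it calls a schedule.
#
#   constant_lr = ConstantSchedule(1e-3)
#   constant_lr(tf.constant(0))    # -> 0.001
#   constant_lr(tf.constant(500))  # -> 0.001
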
class PolynomialSchedule(LearningRateSchedule):
    """Constructs a schedule with polynomial transition from init to end value.

    Args:
      init_value: initial value for the scalar to be annealed.
      end_value: end value of the scalar to be annealed.
      power: the power of the polynomial used to transition from init to end.
      transition_steps: number of steps over which annealing takes place.
        The scalar starts changing at `transition_begin` steps and completes
        the transition by `transition_begin + transition_steps` steps.
        If `transition_steps <= 0`, the entire annealing process is disabled
        and the value is held fixed at `init_value`.
      transition_begin: must be positive. After how many steps to start
        annealing (before this many steps the scalar value is held fixed at
        `init_value`).

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        init_value: float,
        end_value: float,
        power: Union[float, int],
        transition_steps: int,
        transition_begin: int = 0,
    ):
        self.init_value = init_value
        self.end_value = end_value
        self.power = power
        self.transition_steps = transition_steps
        self.transition_begin = transition_begin
        if self.transition_steps <= 0:
            logging.info(
                'A polynomial schedule was set with a non-positive '
                '`transition_steps` value; this results in a constant '
                'schedule with value `init_value`.')
        if transition_begin < 0:
            logging.info(
                'A polynomial schedule was set with a negative '
                '`transition_begin` value; this will result in '
                '`transition_begin` falling back to `0`.')
            self.transition_begin = 0

    def __call__(self, step):
        if self.transition_steps <= 0:
            return self.init_value
        step = tnp.clip(
            step - self.transition_begin,
            0,
            self.transition_steps,
        )
        frac = 1 - step / self.transition_steps
        return ((self.init_value - self.end_value) * (frac**self.power) +
                self.end_value)
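
# A usage sketch (assumed): anneal 1e-2 down to 1e-4 over 100 steps with a
# quadratic curve, holding the value fixed for the first 10 steps.
#
#   poly = PolynomialSchedule(init_value=1e-2, end_value=1e-4, power=2,
#                             transition_steps=100, transition_begin=10)
#   poly(tf.constant(10))   # -> 1e-2 (annealing is just starting)
#   poly(tf.constant(110))  # -> 1e-4 (annealing is complete)
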
class LinearSchedule(PolynomialSchedule):
    """Constructs a `PolynomialSchedule` with `power=1`."""

    def __init__(
        self,
        init_value: float,
        end_value: float,
        transition_steps: int,
        transition_begin: int = 0,
    ):
        super().__init__(
            init_value=init_value,
            end_value=end_value,
            transition_steps=transition_steps,
            transition_begin=transition_begin,
            power=1,
        )


class PiecewiseConstantSchedule(LearningRateSchedule):
    """Returns a function which implements a piecewise constant schedule.

    Args:
      init_value: An initial value `init_v`.
      boundaries_and_scales: A map from boundaries `b_i` to non-negative
        scaling factors `f_i`. For any step count `s`, the schedule returns
        `init_v` scaled by the product of all factors `f_i` such that
        `b_i < s`.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        init_value: float,
        boundaries_and_scales: Optional[Dict[int, float]] = None,
    ):
        if boundaries_and_scales is not None:
            all_non_negative = all(
                scale >= 0. for scale in boundaries_and_scales.values())
            if not all_non_negative:
                raise ValueError(
                    '`PiecewiseConstantSchedule` expects non-negative scale '
                    'factors')
        self.init_value = init_value
        self.boundaries_and_scales = boundaries_and_scales

    def __call__(self, step):
        v = self.init_value
        if self.boundaries_and_scales is not None:
            for threshold, scale in sorted(self.boundaries_and_scales.items()):
                indicator = tf.maximum(0., tnp.sign(threshold - step))
                v = v * indicator + (1 - indicator) * scale * v
        return v
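
# A usage sketch (assumed): start at 1e-3 and multiply by 0.1 after steps
# 100 and 200, giving the classic step-decay staircase.
#
#   pw = PiecewiseConstantSchedule(init_value=1e-3,
#                                  boundaries_and_scales={100: 0.1, 200: 0.1})
#   pw(tf.constant(50))   # -> 1e-3
#   pw(tf.constant(150))  # -> 1e-4
#   pw(tf.constant(250))  # -> 1e-5
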
class ExponentialDecaySchedule(LearningRateSchedule):
    """Constructs a schedule with either continuous or discrete exponential decay.

    This function applies an exponential decay function to a provided initial
    value. The function returns the decayed value as follows:

    ```
    decayed_value = init_value * decay_rate ^ (count / transition_steps)
    ```

    If the argument `staircase` is `True`, then `count / transition_steps` is
    an integer division and the decayed value follows a staircase function.

    Args:
      init_value: the initial learning rate.
      transition_steps: must be positive. See the decay computation above.
      decay_rate: must not be zero. The decay rate.
      transition_begin: must be positive. After how many steps to start
        annealing (before this many steps the scalar value is held fixed at
        `init_value`).
      staircase: if `True`, decay the values at discrete intervals.
      end_value: the value at which the exponential decay stops. When
        `decay_rate < 1`, `end_value` is treated as a lower bound, otherwise
        as an upper bound. Has no effect when `decay_rate` = 0.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        init_value: float,
        transition_steps: int,
        decay_rate: float,
        transition_begin: int = 0,
        staircase: bool = False,
        end_value: Optional[float] = None,
    ):
        if transition_steps <= 0:
            logging.info(
                'An exponential schedule was set with a non-positive '
                '`transition_steps` value; this will result in a constant '
                'schedule with value `init_value`.')
        if decay_rate == 0:
            logging.info(
                'An exponential schedule was set with a zero `decay_rate` '
                'value; this will result in a constant schedule with value '
                '`init_value`.')
        if transition_begin < 0:
            logging.info(
                'An exponential schedule was set with a negative '
                '`transition_begin` value; this will result in '
                '`transition_begin` falling back to `0`.')
            transition_begin = 0
        if end_value is not None:
            self.clip_fn = tnp.maximum if decay_rate < 1.0 else tnp.minimum
        self.init_value = init_value
        self.transition_steps = transition_steps
        self.decay_rate = decay_rate
        self.transition_begin = transition_begin
        self.staircase = staircase
        self.end_value = end_value

    def __call__(self, step):
        # A non-positive `transition_steps` or zero `decay_rate` disables
        # decay entirely, matching the warnings logged above.
        if self.transition_steps <= 0 or self.decay_rate == 0:
            return self.init_value
        decreased_count = step - self.transition_begin
        p = decreased_count / self.transition_steps
        if self.staircase:
            p = tnp.floor(p)
        decayed_value = tnp.where(
            decreased_count <= 0,
            self.init_value,
            self.init_value * tnp.power(self.decay_rate, p),
        )
        if self.end_value is not None:
            decayed_value = self.clip_fn(decayed_value, self.end_value)
        return decayed_value
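
# A usage sketch (assumed): halve the learning rate every 1000 steps. With
# `staircase=True` the exponent `count / transition_steps` is floored, so
# the value drops in discrete jumps instead of decaying continuously.
#
#   exp = ExponentialDecaySchedule(init_value=1e-3, transition_steps=1000,
#                                  decay_rate=0.5, staircase=True)
#   exp(tf.constant(500))   # -> 1e-3   (floor(0.5) = 0, no decay yet)
#   exp(tf.constant(2000))  # -> 2.5e-4 (1e-3 * 0.5**2)
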
class CosineDecaySchedule(LearningRateSchedule):
    """Returns a function which implements cosine learning rate decay.

    For more details see: https://arxiv.org/abs/1608.03983

    Args:
      init_value: An initial value `init_v`.
      decay_steps: Positive integer - the number of steps over which the
        decay is applied.
      alpha: Float. The minimum value of the multiplier used to adjust the
        learning rate.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        init_value: float,
        decay_steps: int,
        alpha: float = 0.0,
    ):
        if decay_steps <= 0:
            raise ValueError(
                '`CosineDecaySchedule` requires positive `decay_steps`!')
        self.init_value = init_value
        self.decay_steps = decay_steps
        self.alpha = alpha

    def __call__(self, step):
        step = tnp.minimum(step, self.decay_steps)
        cosine_decay = 0.5 * (1 + tnp.cos(tnp.pi * step / self.decay_steps))
        decayed = (1 - self.alpha) * cosine_decay + self.alpha
        return self.init_value * decayed
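
# A usage sketch (assumed): with `alpha=0.1` the multiplier decays from 1
# to 0.1 along a half cosine wave over `decay_steps`.
#
#   cos = CosineDecaySchedule(init_value=1e-3, decay_steps=1000, alpha=0.1)
#   cos(tf.constant(0))     # -> 1e-3
#   cos(tf.constant(500))   # -> 5.5e-4 (halfway: cosine_decay = 0.5)
#   cos(tf.constant(1000))  # -> 1e-4   (the floor, alpha * init_value)
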
def _linear_interpolate(start: float, end: float, pct: float):
    return (end - start) * pct + start


def _cosine_interpolate(start: float, end: float, pct: float):
    return end + (start - end) / 2.0 * (tnp.cos(tnp.pi * pct) + 1)


class PiecewiseInterpolateSchedule(LearningRateSchedule):
    """Returns a function which implements a piecewise interpolated schedule.

    Args:
      interpolate_type: 'linear' or 'cosine', specifying the interpolation
        strategy.
      init_value: An initial value `init_v`.
      boundaries_and_scales: A map from boundaries `b_i` to non-negative
        scaling factors `f_i`. At boundary step `b_i`, the schedule returns
        `init_v` scaled by the product of all factors `f_j` such that
        `b_j < b_i`. The values between each boundary are interpolated as per
        `interpolate_type`.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        interpolate_type: str,
        init_value: float,
        boundaries_and_scales: Optional[Dict[int, float]] = None,
    ):
        self.interpolate_type = interpolate_type
        self.init_value = init_value
        self.boundaries_and_scales = boundaries_and_scales
        if interpolate_type == 'linear':
            self.interpolate_fn = _linear_interpolate
        elif interpolate_type == 'cosine':
            self.interpolate_fn = _cosine_interpolate
        else:
            raise ValueError(
                "`interpolate_type` must be either 'cosine' or 'linear'")
        if boundaries_and_scales is not None:
            self.boundaries, self.scales = zip(
                *sorted(boundaries_and_scales.items()))
            if not all(scale >= 0. for scale in self.scales):
                raise ValueError(
                    '`PiecewiseInterpolateSchedule` expects non-negative '
                    'scale factors')
        else:
            self.boundaries, self.scales = (), ()
        self.bounds = tnp.stack((0, ) + self.boundaries)
        self.values = tnp.cumprod(tnp.stack((self.init_value, ) + self.scales))
        self.interval_sizes = (self.bounds[1:] - self.bounds[:-1])

    def __call__(self, step):
        indicator = (tf.cast(self.bounds[:-1] <= step, tf.int8) *
                     tf.cast(step < self.bounds[1:], tf.int8))
        pct = (step - self.bounds[:-1]) / self.interval_sizes
        interp_vals = self.interpolate_fn(
            self.values[:-1],
            self.values[1:],
            pct,
        )
        return (tnp.dot(indicator, interp_vals) +
                (self.bounds[-1] <= step) * self.values[-1])
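
# A usage sketch (assumed): hold 1e-3 over the first segment (scale 1.0),
# then ramp down to 1e-4 with a cosine curve between steps 100 and 200.
#
#   pwi = PiecewiseInterpolateSchedule(
#       interpolate_type='cosine', init_value=1e-3,
#       boundaries_and_scales={100: 1.0, 200: 0.1})
#   pwi(tf.constant(50))   # -> 1e-3   (flat over the first segment)
#   pwi(tf.constant(150))  # -> 5.5e-4 (midpoint of the cosine ramp)
#   pwi(tf.constant(300))  # -> 1e-4   (past the last boundary)
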
class LinearOneCycleSchedule(PiecewiseInterpolateSchedule):
    """Returns a function which implements the onecycle learning rate schedule.

    This function uses a linear annealing strategy.
    For more details see: https://arxiv.org/abs/1708.07120

    Args:
      transition_steps: Number of steps over which annealing takes place.
      peak_value: Maximum value attained by schedule at pct_start percent
        of the cycle (in number of steps).
      pct_start: The percentage of the cycle (in number of steps) spent
        increasing the learning rate.
      pct_final: The percentage of the cycle (in number of steps) spent
        increasing to peak_value then decreasing back to init_value.
      div_factor: Determines the initial value via
        init_value = peak_value / div_factor.
      final_div_factor: Determines the final value via
        final_value = init_value / final_div_factor.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        transition_steps: int,
        peak_value: float,
        pct_start: float = 0.3,
        pct_final: float = 0.85,
        div_factor: float = 25.0,
        final_div_factor: float = 1e4,
    ):
        if transition_steps <= 0:
            raise ValueError(
                'A linear onecycle schedule was set with a non-positive '
                '`transition_steps`')
        super().__init__(
            interpolate_type='linear',
            init_value=peak_value / div_factor,
            boundaries_and_scales={
                int(pct_start * transition_steps): div_factor,
                int(pct_final * transition_steps): 1. / div_factor,
                transition_steps: 1. / final_div_factor
            },
        )
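
# A usage sketch (assumed), with the default div factors: the value ramps
# from peak_value / 25 up to peak_value over the first 30% of the cycle,
# declines back to the initial value by 85% of the cycle, and finishes at
# init_value / final_div_factor.
#
#   onecycle = LinearOneCycleSchedule(transition_steps=1000, peak_value=1e-2)
#   onecycle(tf.constant(0))     # -> 4e-4 (1e-2 / 25)
#   onecycle(tf.constant(300))   # -> 1e-2 (peak at pct_start * 1000 steps)
#   onecycle(tf.constant(1000))  # -> 4e-8 (init_value / final_div_factor)
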
class CosineOneCycleSchedule(PiecewiseInterpolateSchedule):
    """Returns a function which implements the onecycle learning rate schedule.

    This function uses a cosine annealing strategy.
    For more details see: https://arxiv.org/abs/1708.07120

    Args:
      transition_steps: Number of steps over which annealing takes place.
      peak_value: Maximum value attained by schedule at pct_start percent
        of the cycle (in number of steps).
      pct_start: The percentage of the cycle (in number of steps) spent
        increasing the learning rate.
      div_factor: Determines the initial value via
        init_value = peak_value / div_factor.
      final_div_factor: Determines the final value via
        final_value = init_value / final_div_factor.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        transition_steps: int,
        peak_value: float,
        pct_start: float = 0.3,
        div_factor: float = 25.0,
        final_div_factor: float = 1e4,
    ):
        if transition_steps <= 0:
            raise ValueError(
                'A cosine onecycle schedule was set with a non-positive '
                '`transition_steps`')
        super().__init__(
            interpolate_type='cosine',
            init_value=peak_value / div_factor,
            boundaries_and_scales={
                int(pct_start * transition_steps): div_factor,
                int(transition_steps): 1. / (div_factor * final_div_factor)
            },
        )


class JoinedSchedule(LearningRateSchedule):
    """Sequentially apply multiple schedules.

    Args:
      schedules: A list of `LearningRateSchedule` callables. Each schedule
        will receive a step count indicating the number of steps since the
        previous boundary transition.
      boundaries: A list of integers (of length one less than schedules) that
        indicate when to transition between schedules.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        schedules: Sequence[LearningRateSchedule],
        boundaries: Sequence[int],
    ):
        self.schedules = schedules
        self.boundaries = boundaries

    def __call__(self, step):
        lr = self.schedules[0](step)
        for boundary, schedule in zip(self.boundaries, self.schedules[1:]):
            lr = tf.where(step < boundary, lr, schedule(step - boundary))
        return lr
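
# A usage sketch (assumed): hold 1e-3 for 100 steps, then hand off to a
# cosine decay. Each schedule after the first sees `step - boundary`, i.e.
# the number of steps since its own start.
#
#   joined = JoinedSchedule(
#       schedules=[ConstantSchedule(1e-3),
#                  CosineDecaySchedule(init_value=1e-3, decay_steps=900)],
#       boundaries=[100])
#   joined(tf.constant(550))  # cosine value at local step 450, i.e. ~5e-4
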
class WarmupCosineDecaySchedule(JoinedSchedule):
    """Linear warmup followed by cosine decay.

    Args:
      init_value: Initial value for the scalar to be annealed.
      peak_value: Peak value for scalar to be annealed at end of warmup.
      warmup_steps: Positive integer, the length of the linear warmup.
      decay_steps: Positive integer, the total length of the schedule. Note
        that this includes the warmup time, so the number of steps during
        which cosine annealing is applied is `decay_steps - warmup_steps`.
      end_value: End value of the scalar to be annealed.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        init_value: float,
        peak_value: float,
        warmup_steps: int,
        decay_steps: int,
        end_value: float = 0.0,
    ):
        schedules = [
            LinearSchedule(init_value=init_value,
                           end_value=peak_value,
                           transition_steps=warmup_steps),
            CosineDecaySchedule(init_value=peak_value,
                                decay_steps=decay_steps - warmup_steps,
                                alpha=end_value / peak_value),
        ]
        super().__init__(
            schedules=schedules,
            boundaries=[warmup_steps],
        )
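
# A usage sketch (assumed): ramp linearly from 0 to 1e-3 over 100 steps,
# then cosine-decay to 0 over the remaining 1000 steps.
#
#   warm_cos = WarmupCosineDecaySchedule(init_value=0.0, peak_value=1e-3,
#                                        warmup_steps=100, decay_steps=1100)
#   warm_cos(tf.constant(50))    # -> 5e-4 (halfway through warmup)
#   warm_cos(tf.constant(100))   # -> 1e-3 (peak)
#   warm_cos(tf.constant(1100))  # -> 0.0  (end_value)
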
class WarmupExponentialDecaySchedule(JoinedSchedule):
    """Linear warmup followed by exponential decay.

    Args:
      init_value: Initial value for the scalar to be annealed.
      peak_value: Peak value for scalar to be annealed at end of warmup.
      warmup_steps: Positive integer, the length of the linear warmup.
      transition_steps: must be positive. See the decay computation in
        `ExponentialDecaySchedule`.
      decay_rate: must not be zero. The decay rate.
      transition_begin: must be positive. After how many steps to start
        annealing (before this many steps the scalar value is held fixed at
        `init_value`).
      staircase: if `True`, decay the values at discrete intervals.
      end_value: the value at which the exponential decay stops. When
        `decay_rate < 1`, `end_value` is treated as a lower bound, otherwise
        as an upper bound. Has no effect when `decay_rate` = 0.

    Returns:
      schedule: A function that maps step counts to values.
    """

    def __init__(
        self,
        init_value: float,
        peak_value: float,
        warmup_steps: int,
        transition_steps: int,
        decay_rate: float,
        transition_begin: int = 0,
        staircase: bool = False,
        end_value: Optional[float] = None,
    ):
        schedules = [
            LinearSchedule(init_value=init_value,
                           end_value=peak_value,
                           transition_steps=warmup_steps),
            ExponentialDecaySchedule(init_value=peak_value,
                                     transition_steps=transition_steps,
                                     decay_rate=decay_rate,
                                     transition_begin=transition_begin,
                                     staircase=staircase,
                                     end_value=end_value),
        ]
        super().__init__(
            schedules=schedules,
            boundaries=[warmup_steps],
        )
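
# An end-to-end sketch (assumed usage, not part of the original gist): any of
# these schedules can be passed as the `learning_rate` of a Keras optimizer.
if __name__ == '__main__':
    schedule = WarmupExponentialDecaySchedule(
        init_value=1e-5,
        peak_value=1e-3,
        warmup_steps=1000,
        transition_steps=10000,
        decay_rate=0.5,
        end_value=1e-6,
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)
    # Print the learning rate at a few step counts: warmup, peak, and decay.
    for s in (0, 500, 1000, 5000, 50000):
        print(s, float(schedule(tf.constant(s))))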