Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ryanholbrook/5ff804433fa33e0b9b36e39beb6baebe to your computer and use it in GitHub Desktop.
Save ryanholbrook/5ff804433fa33e0b9b36e39beb6baebe to your computer and use it in GitHub Desktop.
Optax schedules figure
=
# +
import numpy as np
import matplotlib.pyplot as plt
# Matplotlib 3.6 renamed the bundled seaborn style sheets to "seaborn-v0_8-*"
# and later releases dropped the old names, so use whichever spelling this
# installation provides (falling back to the default style if neither exists).
_whitegrid_style = next(
    (
        candidate
        for candidate in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid")
        if candidate in plt.style.available
    ),
    "default",
)
plt.style.use(_whitegrid_style)

# Figure-wide defaults: automatic tight layout plus bold axis labels/titles.
plt.rc("figure", autolayout=True)
plt.rc(
    "axes",
    labelweight="bold",
    labelsize="large",
    titleweight="bold",
    titlesize=16,
    titlepad=10,
)
# --- Constant, cosine and exponential schedules -----------------------------
# The schedule classes are not defined in this section; they are assumed to be
# provided elsewhere in the project.

# Fixed value of 0.1 at every step.
const = ConstantSchedule(value=0.1)

# Cosine decay starting from 0.1 over 256 steps.
cosdecay = CosineDecaySchedule(init_value=0.1, decay_steps=256, alpha=1e-5)

# CycleSchedule (placeholder note from the original; nothing is instantiated
# for it in this section)

# Smooth exponential decay, with an explicit begin step and end value.
expdecay = ExponentialDecaySchedule(
    init_value=0.1,
    transition_steps=32,
    decay_rate=0.5,
    transition_begin=32,
    staircase=False,
    end_value=0.005,
)

# Same rate and step size as above, but with staircase=True and no explicit
# begin step or end value.
expdecay_stair = ExponentialDecaySchedule(
    init_value=0.1,
    transition_steps=32,
    decay_rate=0.5,
    staircase=True,
)
# --- One-cycle and linear schedules -----------------------------------------

# Linear one-cycle schedule over 256 steps with a peak value of 0.1.
linear_oc = LinearOneCycleSchedule(
    transition_steps=256,
    peak_value=0.1,
    pct_start=0.3,
    pct_final=0.85,
    div_factor=25.0,
    final_div_factor=1e4,
)

# Cosine one-cycle schedule with the same peak and divisors (no pct_final).
cosine_oc = CosineOneCycleSchedule(
    transition_steps=256,
    peak_value=0.1,
    pct_start=0.3,
    div_factor=25.0,
    final_div_factor=1e4,
)

# Linear schedule from 0.1 to 0.0001 over 256 steps, beginning at step 0.
linear = LinearSchedule(
    init_value=0.1,
    end_value=0.0001,
    transition_steps=256,
    transition_begin=0,
)
# --- Piecewise schedules ----------------------------------------------------
# Each `boundaries_and_scales` mapping is keyed by a step number and holds a
# scale value; how the scale is applied is decided by the schedule class.

# Piecewise-constant schedule starting at 0.1.
piecewise_const = PiecewiseConstantSchedule(
    init_value=0.1,
    boundaries_and_scales={32: 0.8, 128: 0.5, 192: 0.9},
)

# Piecewise-interpolated schedule using the 'linear' interpolation type.
piecewise_interp_linear = PiecewiseInterpolateSchedule(
    interpolate_type="linear",
    init_value=0.1,
    boundaries_and_scales={32: 1.5, 192: 0.5, 224: 0.3},
)

# Same boundaries and scales, using the 'cosine' interpolation type.
piecewise_interp_cosine = PiecewiseInterpolateSchedule(
    interpolate_type="cosine",
    init_value=0.1,
    boundaries_and_scales={32: 1.5, 192: 0.5, 224: 0.3},
)
# --- Polynomial and warmup schedules ----------------------------------------

# Degree-5 polynomial schedule from 0.1 to 0.0001, beginning at step 32.
poly = PolynomialSchedule(
    power=5,
    init_value=0.1,
    end_value=0.0001,
    transition_steps=256,
    transition_begin=32,
)

# Warmup from 0.0001 to a peak of 0.1 over 32 steps, then cosine decay.
warmup_cosine = WarmupCosineDecaySchedule(
    init_value=0.0001,
    peak_value=0.1,
    warmup_steps=32,
    decay_steps=224,
    end_value=0.00001,
)

# Same warmup settings, followed by a non-staircase exponential decay.
warmup_exp = WarmupExponentialDecaySchedule(
    init_value=0.0001,
    peak_value=0.1,
    warmup_steps=32,
    transition_steps=32,
    decay_rate=0.5,
    transition_begin=32,
    staircase=False,
    end_value=0.00001,
)
# Panel title -> schedule object, in the order the panels are drawn.
# The constant schedule is deliberately left disabled below.
schedules = {
    # 'Constant': const,
    "Cosine Decay": cosdecay,
    "Exponential Decay": expdecay,
    "Stairstep Exponential Decay": expdecay_stair,
    "Linear 1Cycle": linear_oc,
    "Cosine 1Cycle": cosine_oc,
    "Linear": linear,
    "Piecewise Constant": piecewise_const,
    "Interpolated Linear": piecewise_interp_linear,
    "Interpolated Cosine": piecewise_interp_cosine,
    "Polynomial": poly,
    "Warmup Cosine Decay": warmup_cosine,
    "Warmup Exponential Decay": warmup_exp,
}
# -
# +
# Steps at which every schedule is evaluated.
grid = np.arange(0, 256)

# Size the subplot grid from the number of schedules so that none is silently
# dropped by zip() when entries are added; 12 schedules still give the
# original 4x3 layout at figsize (12, 10).
n_cols = 3
n_rows = max(1, -(-len(schedules) // n_cols))  # ceiling division
fig, axes = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(12, 2.5 * n_rows),
    squeeze=False,  # always a 2-D array, even for a single row
)
panels = axes.ravel()

# One panel per schedule: plot the value at each step, titled by name, with
# tick marks removed from both axes.
for (name, fn), ax in zip(schedules.items(), panels):
    ax.plot(grid, [fn(s) for s in grid])
    ax.set_title(name)
    ax.set_xticks([])
    ax.set_yticks([])

# Blank any leftover panels in the last row.
for ax in panels[len(schedules):]:
    ax.set_axis_off()
# -
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment