Created
May 1, 2021 12:43
-
-
Save ryanholbrook/5ff804433fa33e0b9b36e39beb6baebe to your computer and use it in GitHub Desktop.
Optax schedules figure
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
= | |
# +
import numpy as np
import matplotlib.pyplot as plt

# Global figure styling for the schedule plots below.
# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old names (plt.style.use then raises), so prefer the new
# name and fall back to the old one on older matplotlib installations.
_WHITEGRID_STYLE = (
    "seaborn-v0_8-whitegrid"
    if "seaborn-v0_8-whitegrid" in plt.style.available
    else "seaborn-whitegrid"
)
plt.style.use(_WHITEGRID_STYLE)
# Let matplotlib lay out subplots automatically so titles don't overlap.
plt.rc('figure', autolayout=True)
plt.rc(
    'axes',
    labelweight='bold',
    labelsize='large',
    titleweight='bold',
    titlesize=16,
    titlepad=10,
)
# Build every schedule drawn in the figure.
# NOTE(review): the *Schedule classes used below are not defined or imported
# in this file as shown -- presumably they come from an earlier cell or a thin
# wrapper around optax's schedule functions; confirm before running standalone.
const = ConstantSchedule(value=0.1)
cosdecay = CosineDecaySchedule(
    init_value=1e-1,
    decay_steps=256,
    alpha=1e-5,
)
# NOTE(review): a CycleSchedule example was presumably planned at this point
# but is not implemented.
expdecay = ExponentialDecaySchedule(
    init_value=1e-1,
    transition_steps=32,
    decay_rate=0.5,
    transition_begin=32,
    staircase=False,
    end_value=5e-3,
)
# Same decay as above but with staircase=True (and no delayed start or floor).
expdecay_stair = ExponentialDecaySchedule(
    init_value=1e-1,
    transition_steps=32,
    decay_rate=0.5,
    staircase=True,
)
linear_oc = LinearOneCycleSchedule(
    transition_steps=256,
    peak_value=1e-1,
    pct_start=0.3,
    pct_final=0.85,
    div_factor=25.0,
    final_div_factor=10000.0,
)
cosine_oc = CosineOneCycleSchedule(
    transition_steps=256,
    peak_value=1e-1,
    pct_start=0.3,
    div_factor=25.0,
    final_div_factor=10000.0,
)
linear = LinearSchedule(
    init_value=1e-1,
    end_value=1e-4,
    transition_steps=256,
    transition_begin=0,
)
# Keys are step boundaries, values are scale factors.
piecewise_const = PiecewiseConstantSchedule(
    init_value=1e-1,
    boundaries_and_scales={
        32: 0.8,
        128: 0.5,
        192: 0.9,
    },
)
# The two interpolated schedules share boundaries/scales and differ only in
# interpolate_type; each gets its own dict literal so neither can observe
# changes made to the other's.
piecewise_interp_linear = PiecewiseInterpolateSchedule(
    interpolate_type='linear',
    init_value=1e-1,
    boundaries_and_scales={
        32: 1.5,
        192: 0.5,
        224: 0.3,
    },
)
piecewise_interp_cosine = PiecewiseInterpolateSchedule(
    interpolate_type='cosine',
    init_value=1e-1,
    boundaries_and_scales={
        32: 1.5,
        192: 0.5,
        224: 0.3,
    },
)
poly = PolynomialSchedule(
    power=5,
    init_value=1e-1,
    end_value=1e-4,
    transition_steps=256,
    transition_begin=32,
)
warmup_cosine = WarmupCosineDecaySchedule(
    init_value=1e-4,
    peak_value=1e-1,
    warmup_steps=32,
    decay_steps=224,
    end_value=1e-5,
)
warmup_exp = WarmupExponentialDecaySchedule(
    init_value=1e-4,
    peak_value=1e-1,
    warmup_steps=32,
    transition_steps=32,
    decay_rate=0.5,
    transition_begin=32,
    staircase=False,
    end_value=1e-5,
)

# Panel title -> schedule callable, in the order the panels are drawn
# (dicts preserve insertion order).
# `const` is built above but deliberately not plotted -- presumably so the
# 12 remaining panels exactly fill the 4 x 3 figure grid.
schedules = {
    'Cosine Decay': cosdecay,
    'Exponential Decay': expdecay,
    'Stairstep Exponential Decay': expdecay_stair,
    'Linear 1Cycle': linear_oc,
    'Cosine 1Cycle': cosine_oc,
    'Linear': linear,
    'Piecewise Constant': piecewise_const,
    'Interpolated Linear': piecewise_interp_linear,
    'Interpolated Cosine': piecewise_interp_cosine,
    'Polynomial': poly,
    'Warmup Cosine Decay': warmup_cosine,
    'Warmup Exponential Decay': warmup_exp,
}
# -
# +
# Plot each schedule on its own panel, evaluated at every integer step in
# [0, 256).
grid = np.arange(0, 256)

# Size the grid from the number of schedules instead of hard-coding 4 x 3, so
# an extra entry in `schedules` is not silently dropped by zip() below.
# For the 12 schedules above this is the same 4 x 3, (12, 10) figure.
n_cols = 3
n_rows = max(1, (len(schedules) + n_cols - 1) // n_cols)  # ceil, >= 1 row
fig, _ = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(12, 2.5 * n_rows))
for (name, fn), ax in zip(schedules.items(), fig.axes):
    ax.plot(grid, [fn(s) for s in grid])
    ax.set_title(name)
    # Only the curve shapes are of interest, so drop tick marks/labels.
    ax.set_xticks([])
    ax.set_yticks([])
# Hide any panels left empty when the schedules don't fill the last row.
for ax in fig.axes[len(schedules):]:
    ax.set_visible(False)
# -
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment