Last active
August 26, 2021 18:36
-
-
Save zimonitrome/6cf32f3db075f9d1cb9b185a7d7fb726 to your computer and use it in GitHub Desktop.
Python generators that ramp up and then ramp down, for ML schedules (or other applications).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
top_y →  |     o_
         |    /  \_
         |   /     \_
low_y →  |  o        o
         +--------------
            ↑  ↑     ↑
            0  mid   last

One value is produced per step 0..last; the peak (top_y) is at step mid.
if mid == 0: no warmup, only decay
if mid == last: only warmup (the final value is top_y)
""" | |
import numpy as np | |
def linear_warmup_decay(low_y, top_y, mid, last):
    """Yield a piecewise-linear warmup/decay schedule, one value per step.

    For ``0 <= mid <= last``, steps ``0..mid-1`` rise linearly from ``low_y``
    towards ``top_y`` (the endpoint is excluded). Step ``mid`` equals
    ``top_y``. The remaining steps fall linearly so that step ``last`` equals
    ``low_y``.

    Args:
        low_y: Start value of the warmup and end value of the decay.
        top_y: Peak value, reached at step ``mid``.
        mid: Number of warmup steps. ``0`` means decay only; ``last`` means
            warmup only (the final value is ``top_y``).
        last: Index of the final step; ``last + 1`` values are yielded.

    Yields:
        The schedule values, as NumPy float scalars.

    Raises:
        ValueError: Propagated from ``np.linspace`` when ``mid`` is negative
            or ``mid > last + 1``.
    """
    # The warmup excludes its endpoint so top_y appears exactly once, as the
    # first sample of the decay segment.
    warmup = np.linspace(low_y, top_y, mid, endpoint=False)
    decay = np.linspace(top_y, low_y, last - mid + 1, endpoint=True)
    schedule = np.concatenate([warmup, decay])
    # Together the two segments hold last + 1 samples (steps 0..last).
    yield from schedule[: last + 1]
def cosine_warmup_decay(low_y, top_y, mid, last):
    """Yield a half-cosine warmup/decay schedule, one value per step.

    Steps ``0..mid-1`` follow a half cosine rising from ``low_y`` towards
    ``top_y``. Step ``mid`` equals ``top_y``. Steps ``mid..last`` follow a
    half cosine falling back to ``low_y``, which is reached at step ``last``.

    Args:
        low_y: Start value of the warmup and end value of the decay.
        top_y: Peak value, reached at step ``mid``.
        mid: Number of warmup steps. ``0`` means decay only; ``last`` means
            warmup only (the final value is ``top_y``).
        last: Index of the final step; ``last + 1`` values are yielded.

    Yields:
        The schedule values.
    """
    amplitude = top_y - low_y
    decay_span = last - mid
    step = 0
    while step <= last:
        if step < mid:
            # Warmup: 0.5 - 0.5*cos(pi*p) goes from 0 at p=0 towards 1 at p=1.
            progress = step / mid
            yield amplitude * (0.5 - 0.5 * np.cos(progress * np.pi)) + low_y
        else:
            # Decay: 0.5 - 0.5*cos(pi*p - pi) goes from 1 at p=0 to 0 at p=1.
            # When mid == last there is no decay segment. The only step here
            # is the peak, so use p=0 instead of dividing by zero.
            progress = (step - mid) / decay_span if decay_span else 0.0
            yield amplitude * (0.5 - 0.5 * np.cos(progress * np.pi - np.pi)) + low_y
        step += 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment