Skip to content

Instantly share code, notes, and snippets.

@wassname
Last active March 2, 2019 08:30
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wassname/168b2167ac18c0671f41ae9f6fb86e66 to your computer and use it in GitHub Desktop.
Save wassname/168b2167ac18c0671f41ae9f6fb86e66 to your computer and use it in GitHub Desktop.
A learning-rate scheduler for PyTorch that interpolates between given values on a log or linear scale
from torch.optim.lr_scheduler import _LRScheduler
import numpy as np
class InterpolatingScheduler(_LRScheduler):
    """Learning-rate scheduler that interpolates between given (step, lr) points.

    At each epoch the learning rate is ``np.interp(last_epoch, steps, lrs)``.
    With ``scale='linear'`` the interpolation is applied to ``lrs`` directly.
    With ``scale='log'`` it is applied to ``log(lrs)`` and the result is
    exponentiated.

    Before ``steps[0]`` and after ``steps[-1]``, the first or last learning
    rate is held, because ``np.interp`` clamps. The schedule is absolute: the
    optimizer's initial ``lr`` (``base_lrs``) is ignored, and every parameter
    group receives the same interpolated value.
    """

    def __init__(self, optimizer, steps, lrs, scale='log', last_epoch=-1):
        """A scheduler that interpolates given values.

        Args:
            - optimizer: pytorch optimizer
            - steps: list or array with the x coordinates (epochs) of the
              interpolated values; must be sorted in non-decreasing order.
            - lrs: list or array with the learning rates corresponding to the
              steps; same length as ``steps``, and strictly positive when
              ``scale='log'``.
            - scale: one of ['linear', 'log'], the scale on which to
              interpolate. Log is useful since learning rates are usually
              varied over orders of magnitude.
            - last_epoch: forwarded to ``_LRScheduler``; the default -1 starts
              a fresh schedule.

        Raises:
            ValueError: if ``scale`` is unknown, if ``steps``/``lrs`` are empty
                or of different lengths, if ``steps`` is not sorted, or if a
                non-positive learning rate is given with ``scale='log'``.

        Usage:
            fc = nn.Linear(1,1)
            optimizer = optim.Adam(fc.parameters())
            lr_scheduler = InterpolatingScheduler(optimizer, steps=[0, 100, 400], lrs=[1e-6, 1e-4, 1e-8], scale='log')
        """
        if scale not in ('linear', 'log'):
            raise ValueError("scale should be one of ['linear', 'log']")
        xp = np.asarray(steps, dtype=float)
        fp = np.asarray(lrs, dtype=float)
        if xp.ndim != 1 or xp.size == 0 or fp.shape != xp.shape:
            raise ValueError(
                'steps and lrs must be non-empty 1-D sequences of equal length')
        # np.interp assumes sorted x-coordinates and does not check; unsorted
        # steps would silently produce a meaningless schedule.
        if np.any(np.diff(xp) < 0):
            raise ValueError('steps must be sorted in non-decreasing order')
        # log(0) = -inf and log(<0) = nan would yield a 0/NaN learning rate.
        if scale == 'log' and np.any(fp <= 0):
            raise ValueError("lrs must be strictly positive when scale='log'")
        self.scale = scale
        self.steps = steps
        self.lrs = lrs
        # Must come last: _LRScheduler.__init__ performs an initial step,
        # which calls get_lr() and therefore needs the attributes above.
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the interpolated learning rate once per parameter group.

        Raises:
            ValueError: if ``self.scale`` is not 'linear' or 'log'.
        """
        if self.scale == 'linear':
            lr = np.interp(self.last_epoch, self.steps, self.lrs)
        elif self.scale == 'log':
            lr = np.exp(np.interp(self.last_epoch, self.steps, np.log(self.lrs)))
        else:
            raise ValueError("scale should be one of ['linear', 'log']")
        # Plain Python float (not numpy.float64) so optimizer param_groups hold
        # the same type as the built-in torch schedulers produce. base_lrs is
        # only used for its length: one entry per parameter group.
        return [float(lr) for _ in self.base_lrs]
# Example of use
import torch.optim as optim
from torch import nn

# The demo only runs when executed as a script (or in a notebook), so
# importing this module does not build an optimizer or require matplotlib.
if __name__ == '__main__':
    import matplotlib.pyplot as plt  # only needed for the demo plot

    fc = nn.Linear(1, 1)
    optimizer = optim.Adam(fc.parameters())
    lr_scheduler = InterpolatingScheduler(
        optimizer,
        steps=[0, 100, 400, 800],
        lrs=[1e-6, 1e-4, 1e-8, 1e-9],
        scale='log',
    )

    # Plot the lr schedule. It is sampled by temporarily moving last_epoch.
    # The original value is then restored, rather than forced to -1:
    # construction already advanced it, and -1 would make the next step()
    # replay the first epoch.
    x = np.linspace(0, 1000, 6000)
    saved_epoch = lr_scheduler.last_epoch
    y = []
    for xx in x:
        lr_scheduler.last_epoch = xx
        y.append(lr_scheduler.get_lr()[0])
    lr_scheduler.last_epoch = saved_epoch

    plt.figure()
    plt.plot(x, y)
    plt.title('InterpolatingScheduler')
    plt.yscale('log')
    plt.xlabel('epoch')
    plt.ylabel('lr')
    plt.show()
@wassname
Copy link
Author

download

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment