Skip to content

Instantly share code, notes, and snippets.

@1pha
Created January 13, 2021 10:07
Show Gist options
  • Save 1pha/f215f0edd0ed88a68ae6495bcc3ebeb1 to your computer and use it in GitHub Desktop.
Save 1pha/f215f0edd0ed88a68ae6495bcc3ebeb1 to your computer and use it in GitHub Desktop.
class NoamOpt:
    """Optimizer wrapper that drives the learning rate with the Noam schedule.

    On every ``step()`` the wrapped optimizer's learning rate (``lr`` of every
    entry in ``optimizer.param_groups``) is overwritten with::

        factor * model_size**-0.5 * min(step**-0.5, step * warmup**-1.5)

    i.e. a linear ramp-up for the first ``warmup`` steps followed by an
    inverse-square-root decay.

    Args:
        model_size: Model dimension used to scale the rate (``model_size**-0.5``).
        factor: Overall multiplier applied to the schedule.
        warmup: Number of warm-up steps; ``warmup <= 0`` disables the ramp-up.
        optimizer: Wrapped optimizer exposing ``param_groups`` (a list of
            dicts with an ``'lr'`` key) and ``step()``.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0  # number of completed step() calls
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0  # last learning rate written to the optimizer

    def step(self):
        """Advance the schedule by one step, update every param group's lr, then step the optimizer."""
        self._step += 1
        rate = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the Noam learning rate for ``step``.

        Args:
            step: Step number to evaluate; defaults to the current step count.

        Returns:
            The learning rate. For ``step <= 0`` this is ``0.0`` -- the limit of
            the schedule at step 0 -- instead of raising ``ZeroDivisionError``
            (``0 ** -0.5``) or producing a complex power for negative steps.
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.0
        if self.warmup <= 0:
            # No warm-up phase: pure inverse-square-root decay
            # (avoids ZeroDivisionError from ``0 ** -1.5``).
            schedule = step ** (-0.5)
        else:
            schedule = min(step ** (-0.5), step * self.warmup ** (-1.5))
        return self.factor * (self.model_size ** (-0.5) * schedule)
def get_std_opt(model, factor=2, warmup=4000):
    """Build the standard Noam-scheduled Adam optimizer for ``model``.

    Args:
        model: Model whose parameters are optimized. Its model dimension is read
            from ``model.src_embed[0].d_model`` (presumably an embedding module
            as the first element of ``src_embed`` -- confirm against the model
            definition).
        factor: Schedule multiplier passed to ``NoamOpt`` (default 2).
        warmup: Number of warm-up steps passed to ``NoamOpt`` (default 4000).

    Returns:
        A ``NoamOpt`` wrapping ``torch.optim.Adam`` with betas (0.9, 0.98) and
        eps 1e-9.
    """
    # lr=0 is only a placeholder: NoamOpt.step() overwrites every param
    # group's lr before each optimizer step.
    adam = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(model.src_embed[0].d_model, factor, warmup, adam)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment