NovoGrad optimizer in PyTorch
# ported from https://github.com/NVIDIA/OpenSeq2Seq/blob/master/open_seq2seq/optimizers/novograd.py
# paper: https://arxiv.org/abs/1905.11286
# NVIDIA's more recent PyTorch implementation: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper/optimizers.py
import torch

class NovoGrad(torch.optim.Optimizer):
    def __init__(self, params, lr=1.0, betas=(0.95, 0.98), eps=1e-8, weight_decay=0.0, dampening=False):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, dampening=dampening)
        super(NovoGrad, self).__init__(params, defaults)

    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            for p in filter(lambda p: p.grad is not None, group['params']):
                state = self.state[p]
                # squared L2 norm of this parameter's gradient
                g_2 = (p.grad.data ** 2).sum()
                # exponential moving average of the squared gradient norm (per-parameter second moment)
                state['_grads_ema'] = g_2 if '_grads_ema' not in state else state['_grads_ema'] * group['betas'][1] + g_2 * (1. - group['betas'][1])
                # gradient normalized by the running norm estimate, then (decoupled) weight decay
                grad = p.grad.data / (state['_grads_ema'] + group['eps']).sqrt()
                if group['weight_decay'] > 0:
                    grad.add_(p.data, alpha=group['weight_decay'])
                if group['dampening']:
                    grad *= 1 - group['betas'][0]
                # first-moment (momentum) accumulator, then the parameter update
                state['momentum_buffer'] = state['momentum_buffer'].mul_(group['betas'][0]).add_(grad) if 'momentum_buffer' in state else grad
                p.data.add_(state['momentum_buffer'], alpha=-group['lr'])
        return loss
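
A minimal usage sketch (not part of the original gist): the toy model, batch, and hyperparameters below are illustrative placeholders, chosen only to show how this optimizer plugs into a standard PyTorch training step.

import torch
import torch.nn as nn

# Hypothetical toy model and loss; substitute your own.
model = nn.Linear(10, 2)
loss_fn = nn.CrossEntropyLoss()
optimizer = NovoGrad(model.parameters(), lr=0.01, betas=(0.95, 0.98), weight_decay=1e-3)

# Stand-in batch of 4 examples with random features and labels.
inputs, targets = torch.randn(4, 10), torch.randint(0, 2, (4,))
optimizer.zero_grad()
loss = loss_fn(model(inputs), targets)
loss.backward()
optimizer.step()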