@vadimkantorov
Last active August 13, 2019 10:16
LARC gradient clipping in PyTorch
# ported from https://github.com/NVIDIA/OpenSeq2Seq/blob/master/open_seq2seq/optimizers/optimizers.py
# paper: https://arxiv.org/abs/1708.03888
# more advanced PyTorch variant: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
# Usage: call larc_(optimizer.param_groups, larc_eta = 1e-3) after loss.backward() and before optimizer.step()
import torch

def larc_(param_groups, larc_eta = 1e-3, larc_mode = 'clip', min_update = 1e-7, eps = 1e-7):
    # Rescales every parameter's gradient in-place by the LARC trust ratio.
    # In 'clip' mode the per-parameter scale is larc_eta * ||w|| / (lr * ||g||),
    # capped at 1 (and floored at min_update), so the effective update never exceeds the group's learning rate.
    for group in param_groups:
        for p in filter(lambda p: p.grad is not None, group['params']):
            v_norm = p.data.norm()
            g_norm = p.grad.data.norm()
            if larc_mode == 'clip':
                larc_grad_update = torch.clamp(larc_eta * v_norm / (group['lr'] * (g_norm + eps)), min = min_update, max = 1)
            else:
                # LARS-style scaling: the trust ratio larc_eta * ||w|| / ||g|| is not capped at 1
                larc_grad_update = torch.clamp(larc_eta * v_norm / (g_norm + eps), min = min_update)
            p.grad.data.mul_(larc_grad_update)
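
A minimal usage sketch, assuming the larc_ function above is in scope; the model, data, and SGD hyperparameters below are illustrative placeholders, not part of the gist. The call sits between backward() and step(), since larc_ rescales p.grad in place and the optimizer then consumes the rescaled gradients.

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr = 0.1, momentum = 0.9)
inputs, targets = torch.randn(32, 10), torch.randint(0, 2, (32,))  # placeholder batch

optimizer.zero_grad()
loss = torch.nn.functional.cross_entropy(model(inputs), targets)
loss.backward()
larc_(optimizer.param_groups, larc_eta = 1e-3)  # rescale each parameter's gradient in-place
optimizer.step()                                # step uses the LARC-scaled gradients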