@snakers4
Created February 2, 2021 12:44
Gradient Adaptive Factor
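The snippet below computes a gradient adaptive factor (GAF): a scalar weight for a secondary loss, chosen so that its largest per-parameter gradients land on the same scale as those of the primary loss, capped at `max_gaf`. It is based on the implementation in https://github.com/bfs18/tacotron2.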
import torch
import itertools

# Based on https://github.com/bfs18/tacotron2


def grads_for_params(loss, parameters, optimizer):
    # Backprop `loss` and collect detached copies of the gradients of
    # `parameters`, leaving the gradients zeroed afterwards.
    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    grads = []
    for p in parameters:
        if p.grad is not None:
            grads.append(p.grad.detach().clone())
    optimizer.zero_grad()
    return grads


def calc_grad_norm(grads, method='max'):
    # Per-parameter gradient magnitude: max absolute value ('max') or L1 norm ('l1').
    if method == 'max':
        return torch.stack([torch.max(torch.abs(g)) for g in grads])
    elif method == 'l1':
        return torch.stack([torch.sum(torch.abs(g)) for g in grads])
    else:
        raise ValueError('Unsupported method [{}]'.format(method))


def calc_grad_adapt_factor(loss1, loss2, parameters, optimizer):
    # Return a factor for loss2 that brings the largest gradients
    # of loss1 and loss2 onto a similar scale.
    parameters, parameters_backup = itertools.tee(parameters)
    grads1 = grads_for_params(loss1, parameters, optimizer)
    grads2 = grads_for_params(loss2, parameters_backup, optimizer)
    norms1 = calc_grad_norm(grads1)
    norms2 = calc_grad_norm(grads2)
    # Only compare parameters that receive non-zero gradients from both losses.
    indices = (norms1 != 0) & (norms2 != 0)
    norms1 = norms1[indices]
    norms2 = norms2[indices]
    # return torch.mean(norms1 / norms2)
    return torch.min(norms1 / norms2)


def calc_gaf(model, optimizer, loss1, loss2, max_gaf):
    # `safe_loss` touches every parameter, so both backward passes
    # populate gradients for the same set of parameters.
    safe_loss = 0. * sum([x.sum() for x in model.parameters()])
    gaf = calc_grad_adapt_factor(
        loss1 + safe_loss, loss2 + safe_loss, model.parameters(), optimizer)
    gaf = min(gaf, max_gaf)
    return gaf
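
Below is a minimal usage sketch, not part of the original gist: the toy model, the Adam optimizer, and the two placeholder losses are illustrative assumptions. It shows how the returned factor is typically applied to the secondary loss before the actual backward pass.

import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(8, 10)
target = torch.randn(8, 2)

out = model(x)
loss1 = F.mse_loss(out, target)  # primary loss (placeholder)
loss2 = F.l1_loss(out, target)   # auxiliary loss (placeholder)

# Scale the auxiliary loss so its largest gradients do not dominate loss1's.
gaf = calc_gaf(model, optimizer, loss1, loss2, max_gaf=10.0)

optimizer.zero_grad()
(loss1 + gaf * loss2).backward()
optimizer.step()

Note that `calc_gaf` performs two extra backward passes (with `retain_graph=True`), so the graph is still available for the final `backward()` above.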