Gradient Adaptive Factor
import torch
import itertools

# Based on https://github.com/bfs18/tacotron2
def grads_for_params(loss, parameters, optimizer):
    """Backprop `loss` and return detached copies of the gradients of `parameters`."""
    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    grads = []
    for p in parameters:
        if p.grad is not None:
            grads.append(p.grad.detach().clone())
    optimizer.zero_grad()
    return grads
def calc_grad_norm(grads, method='max'):
    """Reduce each gradient tensor to a scalar: 'max' = largest absolute value, 'l1' = sum of absolute values."""
    if method == 'max':
        return torch.stack([torch.max(torch.abs(g)) for g in grads])
    elif method == 'l1':
        return torch.stack([torch.sum(torch.abs(g)) for g in grads])
    else:
        raise ValueError('Unsupported method [{}]'.format(method))
def calc_grad_adapt_factor(loss1, loss2, parameters, optimizer):
    # Return a factor for loss2 that brings the largest gradients
    # of loss1 and loss2 to a similar scale.
    parameters, parameters_backup = itertools.tee(parameters)
    grads1 = grads_for_params(loss1, parameters, optimizer)
    grads2 = grads_for_params(loss2, parameters_backup, optimizer)
    norms1 = calc_grad_norm(grads1)
    norms2 = calc_grad_norm(grads2)
    # Only compare parameters that receive a non-zero gradient from both losses.
    indices = (norms1 != 0) & (norms2 != 0)
    norms1 = norms1[indices]
    norms2 = norms2[indices]
    # return torch.mean(norms1 / norms2)
    return torch.min(norms1 / norms2)
def calc_gaf(model, optimizer, loss1, loss2, max_gaf):
    # Touch every parameter with a zero-valued term so both backward passes
    # populate a gradient for the same parameter set.
    safe_loss = 0. * sum([x.sum() for x in model.parameters()])
    gaf = calc_grad_adapt_factor(
        loss1 + safe_loss, loss2 + safe_loss, model.parameters(), optimizer)
    gaf = min(gaf, max_gaf)
    return gaf
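
A minimal usage sketch, assuming the helpers above are in scope. The model, losses, learning rate, and the max_gaf cap below are illustrative placeholders, not part of the original gist: calc_gaf returns a scale for the second loss so that its largest per-parameter gradient lands on roughly the same order as the first loss's, clipped to max_gaf.

import torch

model = torch.nn.Linear(8, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(16, 8)
out = model(x)
loss1 = out.pow(2).mean()   # primary loss
loss2 = out.abs().mean()    # auxiliary loss to be rescaled

# Both losses must share the graph built above; calc_gaf backprops each one
# with retain_graph=True, so the graph is still alive for the final backward.
gaf = calc_gaf(model, optimizer, loss1, loss2, max_gaf=10.0)

optimizer.zero_grad()
(loss1 + gaf * loss2).backward()
optimizer.step()

The commented-out torch.mean line in calc_grad_adapt_factor shows the alternative of matching the average rather than the worst-case gradient ratio.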