[Logger for gradient norms in PyTorch with Tensorboard] #pytorch
from collections import defaultdict

import numpy as np
import torch


class GradNormLogger:
    """Accumulates per-module and total gradient norms for Tensorboard logging."""

    def __init__(self):
        self.grad_norms = defaultdict(list)
    def update(self, model: torch.nn.Module, norm_type: float = 2.):
        """Record the gradient norm of each top-level module and the total norm.

        The total norm is computed as (sum_i ||g_i||_p^p)^(1/p) over all
        parameter gradients g_i, i.e. the same quantity that
        torch.nn.utils.clip_grad_norm_ operates on.
        """
        total_norm = 0.
        for name, p in model.named_parameters():
            # Skip frozen parameters and parameters that have no gradient yet.
            if not p.requires_grad or p.grad is None:
                continue
            param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
            # Group norms by top-level module name, e.g. 'encoder.0.weight' -> 'encoder'.
            module_name = name.split('.')[0]
            self.grad_norms[module_name].append(round(param_norm.item(), 3))
        total_norm = total_norm ** (1. / norm_type)
        self.grad_norms['grad_norm_total'].append(round(total_norm, 3))
    def reset(self):
        self.grad_norms = defaultdict(list)
    def write(self, writer, global_step: int):
        """Write the gradient norms to Tensorboard as histograms.

        Args:
            writer: Tensorboard SummaryWriter instance.
            global_step: Global step parameter for the Tensorboard writer.
        """
        for module, grads in self.grad_norms.items():
            writer.add_histogram(f'gradient_histograms/{module}', np.array(grads), global_step)
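
# A minimal usage sketch, assuming a standard training loop with
# torch.utils.tensorboard.SummaryWriter. The model, data, and optimizer below
# are placeholders; any torch.nn.Module works. Norms should be recorded after
# backward() and before optimizer.step().

import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
writer = SummaryWriter()
grad_logger = GradNormLogger()

for step in range(100):
    inputs, targets = torch.randn(8, 10), torch.randn(8, 1)  # dummy batch
    loss = F.mse_loss(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    grad_logger.update(model)   # record norms while gradients are populated
    optimizer.step()

    if step % 10 == 9:          # flush the accumulated norms every 10 steps
        grad_logger.write(writer, global_step=step)
        grad_logger.reset()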