Flova / plot_grad_flow.py
Last active May 20, 2024 13:02
Plot the gradient flow (PyTorch)
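A gradient-flow plot is built from one or two summary numbers per parameter, such as the mean and maximum absolute gradient, read after `loss.backward()`. The sketch below computes those numbers without any plotting. It is a minimal illustration under stated assumptions, not this gist's own code. It assumes PyTorch is installed. The names `collect_grad_stats` and its `skip_bias` flag, along with the toy two-layer model, are hypothetical and exist only for the example.

```python
import torch
import torch.nn as nn


def collect_grad_stats(named_parameters, skip_bias=True):
    """Return (names, mean_abs_grads, max_abs_grads) for trainable parameters.

    Hypothetical helper: `skip_bias` is an illustrative choice that keeps the
    summary focused on weight tensors. Parameters with no gradient yet
    (e.g. unused in the forward pass) are skipped.
    """
    names, ave_grads, max_grads = [], [], []
    for name, param in named_parameters:
        if not param.requires_grad or param.grad is None:
            continue
        if skip_bias and "bias" in name:
            continue
        grad = param.grad.detach().abs()
        names.append(name)
        ave_grads.append(grad.mean().item())  # average gradient magnitude
        max_grads.append(grad.max().item())   # largest gradient magnitude
    return names, ave_grads, max_grads


# Toy usage: call after loss.backward() and before optimizer.zero_grad().
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
x = torch.randn(16, 4)
loss = model(x).pow(2).mean()
loss.backward()

layers, ave_grads, max_grads = collect_grad_stats(model.named_parameters())
for name, ave, mx in zip(layers, ave_grads, max_grads):
    print(f"{name:>10}  mean|g|={ave:.2e}  max|g|={mx:.2e}")
```

A layer whose mean gradient is orders of magnitude smaller than its neighbours' suggests vanishing gradients. Very large maxima point toward exploding gradients.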
# Based on https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/10
def plot_grad_flow(named_parameters):
    '''Plots the gradients flowing through different layers in the net during training.
    Useful for spotting vanishing or exploding gradients.
    Usage: call this function in the Trainer class after loss.backward() as
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow.'''
    ave_grads = []
    max_grads = []