@Flova
Last active March 13, 2024 04:52
Plot the gradient flow (PyTorch)
# Based on https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/10
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D


def plot_grad_flow(named_parameters):
    '''Plots the gradients flowing through the different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.

    Usage: Plug this function into the Trainer class after loss.backward() as
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow.'''
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        if p.requires_grad and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean().item())
            max_grads.append(p.grad.abs().max().item())
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(ave_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=0.02)  # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)],
               ['max-gradient', 'mean-gradient', 'zero-gradient'])
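
For reference, a minimal usage sketch; the model, data, and loss below are hypothetical placeholders and not part of the gist. The function is called right after the backward pass, once the gradients are populated:

# Minimal usage sketch; model, data, and loss are hypothetical placeholders
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
criterion = nn.MSELoss()

inputs = torch.randn(8, 10)
targets = torch.randn(8, 1)

loss = criterion(model(inputs), targets)
loss.backward()                                   # gradients get populated here
plot_grad_flow(model.named_parameters())          # plot them right after the backward pass
plt.show()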
@Crystal-Spider

Thank you for this gist!
A little side note: when trying this function, I got a conversion error from Tensor to NumPy. This was because my model was running on CUDA, and p.grad.abs().mean() yields a Tensor (the same goes for max()). To fix it, it's enough to append .item(), as in p.grad.abs().mean().item().
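
For illustration, a small sketch of the difference (it assumes a CUDA device is available; the parameter here is hypothetical):

import torch

p = torch.nn.Parameter(torch.randn(4, 4, device="cuda"))
p.grad = torch.randn_like(p)          # pretend backward() already populated the gradient

g_tensor = p.grad.abs().mean()        # still a CUDA tensor; implicit NumPy conversion fails
g_float = p.grad.abs().mean().item()  # plain Python float, safe to pass to matplotlib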

@Flova
Author

Flova commented Jul 13, 2023

Thanks, I changed it.

@rschiewer

What an amazing little piece of code, many thanks! One suggestion to handle varying gradient magnitudes better: why not make the y-axis scale logarithmic?
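
If you want to try that, a possible tweak inside the function (a sketch, not part of the gist) would be:

plt.yscale("log")      # a log scale copes better with widely varying gradient magnitudes
plt.ylim(bottom=1e-8)  # replaces the linear zoom-in; keeps the lower limit positive for the log scale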
