@wassname
Last active August 22, 2018 05:53
pytorch isfinite: like numpy.isfinite but for torch tensors
import torch
import numpy as np


def isfinite(x):
    """
    Quick pytorch test that there are no nan's or infs.
    note: torch now has torch.isnan
    url: https://gist.github.com/wassname/df8bc03e60f81ff081e1895aabe1f519
    """
    # inf + 1 == inf, so this is False exactly for +/-inf (nan fails the next check)
    not_inf = ((x + 1) != x)
    # nan is the only value that is not equal to itself
    not_nan = (x == x)
    return not_inf & not_nan


# quick sanity checks on python scalars, numpy scalars and torch tensors
assert isfinite(1)
assert isfinite(0)
assert not isfinite(np.nan)
assert not isfinite(np.inf)
assert not isfinite(-np.inf)
assert isfinite(torch.tensor(1))
assert isfinite(torch.tensor(0))
assert not isfinite(torch.tensor(np.nan))
assert not isfinite(torch.tensor(np.inf))
assert not isfinite(torch.tensor(-np.inf))
isfinite(torch.tensor([np.log(-1.), 1., np.log(0)]))  # element-wise: [False, True, False] (nan, finite, -inf)
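
As the docstring notes, newer PyTorch releases ship this check natively; a minimal sketch comparing against the built-in, assuming a PyTorch version that provides torch.isfinite:

# assumes a recent PyTorch version with torch.isfinite
x = torch.tensor([np.log(-1.), 1., np.log(0)])  # [nan, 1., -inf]
assert (isfinite(x) == torch.isfinite(x)).all()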
wassname commented Aug 22, 2018

Messy example for debugging NaN and Inf gradients inside a training loop:

loss.backward()
# Skip the optimiser step for this batch if any gradient contains nan/inf.
if not all(isfinite(p.grad).all() for p in net.parameters()):
    isfinite_report = {
        name: isfinite(param.grad).all().item()
        for name, param in net.named_parameters()
    }
    print("Warning: skipping non-finite gradients. Finite layers:", isfinite_report)
    continue  # assumes this runs inside a loop over batches
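
For context, a minimal sketch of where this check might sit in a full training loop; net, optimizer and loader are hypothetical placeholders, not part of the original gist:

# hypothetical training loop showing where the gradient check slots in
for x, y in loader:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(net(x), y)
    loss.backward()
    if not all(isfinite(p.grad).all() for p in net.parameters()):
        continue  # skip this batch rather than stepping with corrupted gradients
    optimizer.step()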
