@Pangoraw
Last active September 2, 2022 10:04
A helper class similar to `torch.autograd.detect_anomaly()`, but for the forward pass: it raises as soon as a NaN (or Inf) tensor enters or leaves any submodule.
from typing import Tuple

import torch
from torch import Tensor, nn


def get_hook(mod_name: str, layer_name: str, check_inf: bool = True):
    """Build a forward hook that raises as soon as a NaN (or Inf) appears
    in the inputs or outputs of the given layer."""

    def hook(_module: nn.Module, input: Tuple, output) -> None:
        for i, t in enumerate(input):
            if not isinstance(t, Tensor):
                continue  # forward hooks can also receive non-tensor arguments
            if torch.any(t.isnan()):
                raise Exception(
                    f"input #{i} to layer {layer_name} of module {mod_name} has NaNs ({t.isnan().sum()} NaNs / {t.numel()})"
                )
            if check_inf and torch.any(t.isinf()):
                raise Exception(
                    f"input #{i} to layer {layer_name} of module {mod_name} has Infs ({t.isinf().sum()} Infs / {t.numel()})"
                )
        # Most modules return a single Tensor rather than a tuple; normalize
        # so we don't accidentally iterate over the tensor's first dimension.
        outputs = output if isinstance(output, tuple) else (output,)
        for i, t in enumerate(outputs):
            if not isinstance(t, Tensor):
                continue
            if torch.any(t.isnan()):
                raise Exception(
                    f"output #{i} of layer {layer_name} of module {mod_name} has NaNs ({t.isnan().sum()} NaNs / {t.numel()})"
                )
            if check_inf and torch.any(t.isinf()):
                raise Exception(
                    f"output #{i} of layer {layer_name} of module {mod_name} has Infs ({t.isinf().sum()} Infs / {t.numel()})"
                )

    return hook
class forward_detect_anomaly:
    """Context manager that installs a NaN/Inf-checking forward hook on every
    submodule of `module` and removes all hooks again on exit."""

    def __init__(self, module: nn.Module, check_inf: bool = True) -> None:
        self.handles = []
        self.module = module
        self.check_inf = check_inf

    def __enter__(self):
        mod_name = self.module.__class__.__name__
        for name, submod in self.module.named_modules():
            self.handles.append(
                submod.register_forward_hook(
                    get_hook(mod_name, name, check_inf=self.check_inf)
                )
            )
        return self

    def __exit__(self, *_):
        for handle in self.handles:
            handle.remove()
        self.handles = []
if __name__ == "__main__":

    class ToNaN(nn.Module):
        """Deliberately injects NaNs so the hooks have something to catch."""

        def forward(self, x):
            return x + torch.nan

    class MyNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer1 = nn.Sequential(
                ToNaN(),
                nn.Linear(10, 2),
            )

        def forward(self, x):
            return self.layer1(x)

    model = MyNetwork()
    with forward_detect_anomaly(model):
        x = torch.randn(2, 10)
        y = model(x)  # raises: the output of the ToNaN layer contains NaNs
        print(y)
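
The built-in `torch.autograd.detect_anomaly()` only covers the backward pass, which is why this helper exists; the two compose naturally when a whole training step needs checking. A minimal sketch of nesting them, assuming the gist above is importable as `forward_detect_anomaly` (the model, data, and loss here are placeholders):

import torch
from torch import nn

model = nn.Linear(10, 2)      # placeholder model
x = torch.randn(4, 10)        # placeholder batch
target = torch.randn(4, 2)

# The forward hooks catch bad activations as they happen; detect_anomaly()
# flags the backward op that produces a NaN gradient.
with forward_detect_anomaly(model):
    with torch.autograd.detect_anomaly():
        loss = nn.functional.mse_loss(model(x), target)
        loss.backward()

Both tools add per-call overhead (an extra scan of every tensor in the forward case), so they are best enabled only while hunting a bug, not in regular training runs.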