Created
August 21, 2020 17:37
-
-
Save mfkasim1/9120b0864d52fecb59b1f2cbfacc3d82 to your computer and use it in GitHub Desktop.
Debugging multilevel autograd for anomaly detection
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Gradph object can be used to debug multi-level autograd for anomaly detection. | |
To see an example on how to use this, see the bottom of this script. | |
""" | |
from collections import deque

import torch
class GradphNode(object):
    """Tree node pairing one grad function with its parent and child nodes."""

    def __init__(self, gfn, parent_node):
        """Store ``gfn`` and, when a parent is given, attach this node to it.

        Args:
            gfn: the object this node wraps (stored as-is).
            parent_node: the parent ``GradphNode``, or ``None`` for a root.
        """
        self._gfn = gfn
        self._parent = parent_node
        self._children = []
        # link both directions at construction time so the tree is always
        # consistent: the child knows its parent and the parent lists the child
        if parent_node is not None:
            parent_node._add_child(self)

    def _add_child(self, child):
        """Append ``child`` to this node's list of children."""
        self._children.append(child)

    def gfn(self):
        """Return the wrapped grad function."""
        return self._gfn

    def parent(self):
        """Return the parent node, or ``None`` for a root node."""
        return self._parent

    def children(self):
        """Return the list of child nodes (the internal list, not a copy)."""
        return self._children
class Gradph(object):
    """Record which grad function created which, across nested backward passes.

    Every grad function reachable from the tensors passed to ``gradhook`` gets
    a ``GradphNode``. Hooks registered on those functions (1) check the
    gradients against the configured filter and raise on a hit, and (2)
    register any grad function that appears for the first time during that
    backward call as a child of the function whose hook observed it.

    Args:
        filter: ``"nan"`` or ``"inf"`` to flag gradients containing NaN / Inf,
            or a callable ``f(grad, grad_fn)`` returning a truthy value when
            ``grad`` is anomalous (``grad`` may be ``None``).

    Raises:
        RuntimeError: if ``filter`` is a string other than ``"nan"``/``"inf"``.
        TypeError: if ``filter`` is neither a string nor a callable.
    """

    def __init__(self, filter="nan"):
        self.gfn2node = {}     # grad_fn -> GradphNode
        self.gfnset = set()    # every grad_fn registered so far
        self._rootnodes = []   # nodes registered without a parent
        self._hooked = set()   # grad_fns that already carry our hook

        if isinstance(filter, str):
            if filter == "nan":
                self._filterfcn = lambda g, gfn: g is not None and torch.any(torch.isnan(g))
            elif filter == "inf":
                self._filterfcn = lambda g, gfn: g is not None and torch.any(torch.isinf(g))
            else:
                raise RuntimeError("Unknown filter: %s" % filter)
        elif callable(filter):
            self._filterfcn = filter
        else:
            # fail here instead of with an AttributeError inside a backward hook
            raise TypeError(
                "filter must be 'nan', 'inf' or a callable, got %s" % type(filter).__name__
            )

    def gradhook(self, rootgfn):
        """Register the graph rooted at ``rootgfn`` and hook all of its nodes.

        Args:
            rootgfn: a non-leaf tensor, or a ``grad_fn`` object.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            TypeError: if ``rootgfn`` is a tensor without a ``grad_fn``.
        """
        # if the input is a tensor, then get grad_fn; every tensor has the
        # attribute, so a leaf is recognised by the attribute being None
        if isinstance(rootgfn, torch.Tensor):
            if rootgfn.grad_fn is None:
                raise TypeError("The input must be non-leaf tensor or tensor.grad_fn")
            rootgfn = rootgfn.grad_fn

        # make unregistered nodes as roots
        def reg_root_mnodes(gfn):
            if not self._has(gfn):
                self._add(gfn, parent_gfn=None)
        _iter_graph(rootgfn, reg_root_mnodes)

        # register the backward hook to assign the new gfn in the gradph
        def callbackfcn(gfn):
            # graphs passed to successive gradhook() calls can share nodes;
            # hooking a node once is enough
            if gfn in self._hooked:
                return
            self._hooked.add(gfn)

            # NOTE(review): the parameter names follow this file's convention
            # (first argument = what is scanned for new grad_fns); torch's
            # docs name the hook arguments differently - confirm against the
            # Node.register_hook documentation.
            def _hook(grad_outputs, grad_inputs):
                # perform anomaly detection, raise an error if detected
                self._anomaly_detection(gfn, grad_outputs, grad_inputs)

                # collect all new grad_fn and make this gfn their parent
                for output in grad_outputs:
                    if output is None or output.grad_fn is None:
                        continue
                    # O(n^2) overall: the whole graph is rescanned per output
                    for f in self._get_new_gfns(output.grad_fn):
                        self._add(f, parent_gfn=gfn)
            gfn.register_hook(_hook)
        _iter_graph(rootgfn, callbackfcn)
        return self

    def rootnodes(self):
        """Return the list of nodes that were registered without a parent."""
        return self._rootnodes

    def _anomaly_detection(self, gfn, grad_outputs, grad_inputs):
        """Raise ``RuntimeError`` if the filter flags any of the gradients.

        The message lists the flagged indices, the gradient values and the
        chain of parent nodes of ``gfn`` with their recorded tracebacks.
        """
        inps_normal = [not self._filterfcn(g, gfn) for g in grad_inputs]
        outs_normal = [not self._filterfcn(g, gfn) for g in grad_outputs]
        inp_normal = all(inps_normal)
        out_normal = all(outs_normal)
        if out_normal and inp_normal:
            return

        def _false_idxs(flags):
            return [i for i, ok in enumerate(flags) if not ok]

        def _shape2str(shape):
            return "(%s)" % (",".join([str(s) for s in shape]))

        # craft the message to be informative
        inout_msgs = []
        if not out_normal:
            inout_msgs.append("%s-th output" % str(_false_idxs(outs_normal)))
        if not inp_normal:
            inout_msgs.append("%s-th input" % str(_false_idxs(inps_normal)))
        inout_msg = " and ".join(inout_msgs)

        parts = ["\n------------------------------------------------"]
        parts.append("\nanomaly detected at %s of function: %s at %s"
                     % (inout_msg, gfn.name(), id(gfn)))
        parts.append("\nGrad_inputs:")
        for i, g in enumerate(grad_inputs):
            if g is not None:
                parts.append("\n%d: %s %s" % (i, _shape2str(g.shape), g))
        parts.append("\nGrad_outputs:")
        for i, g in enumerate(grad_outputs):
            if g is not None:
                parts.append("\n%d: %s %s" % (i, _shape2str(g.shape), g))

        parts.append("\n------------------- parents -------------------")
        level = 0
        # .get(): an unregistered gfn must not replace the report by a KeyError
        node = self.gfn2node.get(gfn)
        while node is not None:
            nodegfn = node.gfn()
            if level > 0:
                parts.append("\n(grand^%d)-parent:" % (level - 1))
            else:
                parts.append("\nNode:")
            parts.append("\n * %s at %d" % (nodegfn.name(), id(nodegfn)))
            # print traceback; the "traceback_" entry may be absent
            # (presumably it is only recorded in anomaly-detection mode -
            # TODO confirm), and a missing entry must not mask the anomaly
            tbs = nodegfn.metadata.get("traceback_")
            if tbs is None:
                parts.append("\n * traceback: (not recorded)\n")
            else:
                parts.append("\n * traceback:\n")
                parts.extend(tbs)
            level += 1
            node = node.parent()
        raise RuntimeError("".join(parts))

    def _add(self, gfn, parent_gfn):
        """Create the node for ``gfn`` under ``parent_gfn`` (``None`` = root)."""
        self.gfnset.add(gfn)
        parent_node = self._getnode(parent_gfn) if parent_gfn is not None else None
        node = GradphNode(gfn, parent_node=parent_node)
        self.gfn2node[gfn] = node
        # mark the rootnodes
        if parent_node is None:
            self._rootnodes.append(node)

    def _getnode(self, gfn):
        """Return the node registered for ``gfn``; ``KeyError`` if unknown."""
        return self.gfn2node[gfn]

    def _has(self, gfn):
        """Return whether ``gfn`` has already been registered."""
        return gfn in self.gfnset

    def _get_new_gfns(self, output_gfn):
        """Return the not-yet-registered grad_fns reachable from ``output_gfn``.

        The search walks the entire graph below ``output_gfn`` and keeps what
        is not registered yet; doing this for every hook call is what makes
        the overall bookkeeping O(n^2).
        """
        res = []

        def callbackfcn(gfn):
            if not self._has(gfn):
                res.append(gfn)
        _iter_graph(output_gfn, callbackfcn)
        return res
# thanks to Joel Richard: https://discuss.pytorch.org/t/tracking-down-nan-gradients/78112 | |
def _iter_graph(rootgfn, callbackfcn): | |
if rootgfn is None: | |
return | |
queue = [rootgfn] | |
seen = set() | |
while queue: | |
gfn = queue.pop(0) | |
if gfn in seen: | |
continue | |
seen.add(gfn) | |
children = [] | |
for next_gfn, _ in gfn.next_functions: | |
if next_gfn is None: | |
continue | |
queue.append(next_gfn) | |
children.append(next_gfn) | |
callbackfcn(gfn) | |
############## example ############## | |
def example(with_gradph=True):
    """Differentiate ``(x + y) / (x * y)`` four times with respect to ``y``.

    Args:
        with_gradph: when truthy, a ``Gradph`` with the ``"nan"`` filter is
            hooked onto the current result before every differentiation.

    Returns:
        The fourth-order derivative tensor.
    """
    x = torch.tensor(1.0, requires_grad=True)
    # tiny denominator; the original author's note says the intent is to
    # induce nan in a higher-order backward pass
    y = torch.tensor(1e-8, requires_grad=True)
    a = x + y
    b = x * y
    current = a / b

    tracker = Gradph(filter="nan") if with_gradph else None

    n_derivatives = 4
    for _ in range(n_derivatives):
        # hook the graph of the current result before differentiating it, so
        # nodes created by this backward pass get attributed to their parents
        if tracker is not None:
            tracker.gradhook(current)
        current, = torch.autograd.grad(current, (y,), create_graph=True)
    return current
if __name__ == "__main__": | |
with torch.autograd.detect_anomaly(): | |
example(with_gradph=1) # try 0 and 1 to see the difference |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment