Debugging multilevel autograd for anomaly detection
"""
The Gradph object can be used to debug multi-level autograd for anomaly detection.
For an example of how to use it, see the bottom of this script.
"""
import torch

class GradphNode(object):
    def __init__(self, gfn, parent_node):
        self._gfn = gfn
        self._children = []
        self._parent = parent_node
        if parent_node is not None:
            self._parent._add_child(self)

    def _add_child(self, child):
        self._children.append(child)

    def children(self):
        return self._children

    def parent(self):
        return self._parent

    def gfn(self):
        return self._gfn

class Gradph(object):
    def __init__(self, filter="nan"):
        self.gfn2node = {}
        self.gfnset = set()
        self._rootnodes = []
        if isinstance(filter, str):
            if filter == "nan":
                self._filterfcn = lambda g, gfn: g is not None and torch.any(torch.isnan(g))
            elif filter == "inf":
                self._filterfcn = lambda g, gfn: g is not None and torch.any(torch.isinf(g))
            else:
                raise RuntimeError("Unknown filter: %s" % filter)
        elif hasattr(filter, "__call__"):
            self._filterfcn = filter
        else:
            raise TypeError("filter must be 'nan', 'inf', or a callable taking (grad, grad_fn)")

    def gradhook(self, rootgfn):
        # if the input is a tensor, then get its grad_fn
        if isinstance(rootgfn, torch.Tensor):
            if rootgfn.grad_fn is not None:
                rootgfn = rootgfn.grad_fn
            else:
                raise TypeError("The input must be a non-leaf tensor or a tensor.grad_fn")

        # register nodes that are not yet in the graph as roots
        def reg_root_mnodes(gfn):
            if self._has(gfn):
                return
            self._add(gfn, parent_gfn=None)
        _iter_graph(rootgfn, reg_root_mnodes)

        # register the backward hook that assigns new grad_fns to the gradph
        def callbackfcn(gfn):
            def _hook(grad_outputs, grad_inputs):
                # perform anomaly detection, raise an error if detected
                self._anomaly_detection(gfn, grad_outputs, grad_inputs)
                # collect all new grad_fns and assign this gfn as the parent of the new gfns
                for output in grad_outputs:
                    if output is None or output.grad_fn is None:
                        continue
                    new_gfns = self._get_new_gfns(output.grad_fn)  # O(n^2) (!!!)
                    # register the new gfns and meta nodes
                    for f in new_gfns:
                        self._add(f, parent_gfn=gfn)
            gfn.register_hook(_hook)
        _iter_graph(rootgfn, callbackfcn)
        return self

    def rootnodes(self):
        return self._rootnodes

    def _anomaly_detection(self, gfn, grad_outputs, grad_inputs):
        inps_normal = [not self._filterfcn(g, gfn) for g in grad_inputs]
        outs_normal = [not self._filterfcn(g, gfn) for g in grad_outputs]
        inp_normal = all(inps_normal)
        out_normal = all(outs_normal)
        # craft the message to be informative
        if not (out_normal and inp_normal):
            inout_msgs = []
            _get_false_idxs = lambda lst: [i for i, e in enumerate(lst) if not e]
            if not out_normal:
                idxs = _get_false_idxs(outs_normal)
                inout_msgs.append("%s-th output" % str(idxs))
            if not inp_normal:
                idxs = _get_false_idxs(inps_normal)
                inout_msgs.append("%s-th input" % str(idxs))
            inout_msg = " and ".join(inout_msgs)
            msg = "\n------------------------------------------------"
            msg += "\nanomaly detected at %s of function: %s at %s" % (inout_msg, gfn.name(), id(gfn))
            msg += "\nGrad_inputs:"
            shape2str = lambda shape: "(%s)" % (",".join([str(s) for s in shape]))
            for i, g in enumerate(grad_inputs):
                if g is None:
                    continue
                msg += "\n%d: %s %s" % (i, shape2str(g.shape), g)
            msg += "\nGrad_outputs:"
            for i, g in enumerate(grad_outputs):
                if g is None:
                    continue
                msg += "\n%d: %s %s" % (i, shape2str(g.shape), g)
            msg += "\n------------------- parents -------------------"
            level = 0
            node = self._getnode(gfn)
            while node is not None:
                nodegfn = node.gfn()
                msg += ("\n(grand^%d)-parent:" % (level - 1)) if level > 0 else "\nNode:"
                msg += "\n * %s at %d" % (nodegfn.name(), id(nodegfn))
                # print the traceback (available when torch.autograd.detect_anomaly is active)
                tbs = nodegfn.metadata["traceback_"]
                msg += "\n * traceback:\n"
                for tb in tbs:
                    msg += tb
                level += 1
                node = node.parent()
            raise RuntimeError(msg)

    def _add(self, gfn, parent_gfn):
        self.gfnset.add(gfn)
        if parent_gfn is not None:
            parent_node = self._getnode(parent_gfn)
        else:
            parent_node = None
        node = GradphNode(gfn, parent_node=parent_node)
        self.gfn2node[gfn] = node
        # mark the root nodes
        if parent_node is None:
            self._rootnodes.append(node)

    def _getnode(self, gfn):
        return self.gfn2node[gfn]

    def _has(self, gfn):
        return gfn in self.gfnset

    def _get_new_gfns(self, output_gfn):
        # currently, new grad_fns are found by:
        # 1. iterating over all grad_fns in the graph
        # 2. keeping the grad_fns that are not yet registered in the graph
        # this is slow because it runs in O(n^2) complexity overall
        res = []
        def callbackfcn(gfn):
            if self._has(gfn):
                return
            res.append(gfn)
        _iter_graph(output_gfn, callbackfcn)
        return res
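
# Hypothetical helper (not part of the original gist): illustrates that Gradph also
# accepts a callable filter taking (grad, grad_fn), e.g. to flag very large gradients
# instead of NaNs. The threshold value here is only an assumption for illustration.
def make_large_grad_gradph(threshold=1e10):
    return Gradph(filter=lambda g, gfn: g is not None and torch.any(torch.abs(g) > threshold))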

# thanks to Joel Richard: https://discuss.pytorch.org/t/tracking-down-nan-gradients/78112
def _iter_graph(rootgfn, callbackfcn):
    if rootgfn is None:
        return
    queue = [rootgfn]
    seen = set()
    while queue:
        gfn = queue.pop(0)
        if gfn in seen:
            continue
        seen.add(gfn)
        children = []
        for next_gfn, _ in gfn.next_functions:
            if next_gfn is None:
                continue
            queue.append(next_gfn)
            children.append(next_gfn)
        callbackfcn(gfn)
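
# Hypothetical debugging helpers (not in the original gist), written against the classes
# above: `print_graph_nodes` lists every grad_fn reachable from a tensor using
# `_iter_graph`, and `print_gradph_tree` walks the multi-level tree collected by
# Gradph.gradhook via rootnodes()/children()/gfn().
def print_graph_nodes(tensor):
    # print the name of every grad_fn node in the autograd graph of `tensor`
    _iter_graph(tensor.grad_fn, lambda gfn: print(gfn.name()))

def print_gradph_tree(gradph, node=None, indent=0):
    # recursively print each registered grad_fn, indented by its level in the tree
    nodes = gradph.rootnodes() if node is None else [node]
    for n in nodes:
        print(" " * indent + n.gfn().name())
        for child in n.children():
            print_gradph_tree(gradph, child, indent + 2)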

############## example ##############
def example(with_gradph=True):
    x = torch.tensor(1.0, requires_grad=True)
    y = torch.tensor(1e-8, requires_grad=True)  # near-zero value to induce nan in the 3rd backward
    a = x + y
    b = x * y
    z = a / b
    if with_gradph:
        gradph = Gradph(filter="nan")
        gradph.gradhook(z)
    gy, = torch.autograd.grad(z, (y,), create_graph=True)
    if with_gradph:
        gradph.gradhook(gy)
    gy2, = torch.autograd.grad(gy, (y,), create_graph=True)
    if with_gradph:
        gradph.gradhook(gy2)
    gy3, = torch.autograd.grad(gy2, (y,), create_graph=True)
    if with_gradph:
        gradph.gradhook(gy3)
    gy4, = torch.autograd.grad(gy3, (y,), create_graph=True)
    return gy4
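
# Expected difference (an informal note, not from the original gist): with with_gradph
# enabled, the Gradph hooks raise a RuntimeError that names the grad_fn producing the
# anomalous gradient and lists its parent grad_fns from earlier backward passes; without
# it, one only gets whatever torch.autograd.detect_anomaly reports for the final backward.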

if __name__ == "__main__":
    with torch.autograd.detect_anomaly():
        example(with_gradph=1)  # try 0 and 1 to see the difference