Skip to content

Instantly share code, notes, and snippets.

@apaszke
Last active April 19, 2024 16:31
Show Gist options
  • Save apaszke/f93a377244be9bfcb96d3547b9bc424d to your computer and use it in GitHub Desktop.
Save apaszke/f93a377244be9bfcb96d3547b9bc424d to your computer and use it in GitHub Desktop.
from graphviz import Digraph
import torch
from torch.autograd import Variable, Function
def iter_graph(root, callback):
    """Call ``callback(fn)`` once for every autograd node reachable from ``root``.

    Traversal is depth-first, using a list as a LIFO stack, and follows each
    node's ``next_functions``. ``None`` entries in ``next_functions`` are
    skipped. A node reached along several paths is visited only once.

    Args:
        root: the starting node (typically ``tensor.grad_fn``). ``None`` is
            accepted and treated as an empty graph: ``callback`` is never
            called. This is the case for a tensor that does not require grad
            or that is itself a leaf.
        callback: a one-argument callable invoked with each node.
    """
    if root is None:
        return
    stack = [root]
    seen = set()
    while stack:
        fn = stack.pop()
        # The same node may have been pushed more than once before it was
        # first visited; skip the repeats.
        if fn in seen:
            continue
        seen.add(fn)
        stack.extend(next_fn for next_fn, _ in fn.next_functions
                     if next_fn is not None)
        callback(fn)
def register_hooks(var, threshold=1e6):
    """Hook every autograd node reachable from ``var`` to detect bad gradients.

    Call this *before* ``var.backward()``. While backward runs, each node's
    hook records whether any gradient it produced looks bad. A gradient is
    bad if it contains NaN, or has an element whose magnitude exceeds
    ``threshold``; ``inf`` and ``-inf`` count as exceeding it.

    After the backward pass, call the returned ``make_dot()`` to get a
    ``graphviz.Digraph`` of the graph, colored as follows:
    - leaf variables are light blue;
    - nodes that produced a bad gradient are red;
    - all other nodes are white.

    Only one boolean per node is kept, not the gradient tensors themselves,
    so memory use does not grow with the size of the model's gradients.
    Note that reducing to that boolean inside the hook forces a device sync
    for every node during backward.

    Args:
        var: output tensor whose ``grad_fn`` roots the graph to inspect.
        threshold: magnitude above which a gradient element counts as
            exploding. Defaults to ``1e6``.

    Returns:
        A zero-argument callable that builds and returns the ``Digraph``.
    """
    # autograd node -> True if any gradient it produced during backward was bad
    fn_dict = {}

    def is_bad_grad(grad_output):
        """Return True if ``grad_output`` contains NaN or any |x| > threshold."""
        # Recent PyTorch passes None for gradients it knows are zero, or for
        # inputs that do not need a gradient. A missing/zero gradient is
        # neither NaN nor exploding.
        if grad_output is None:
            return False
        g = grad_output.detach()
        # abs() so that large negative values (including -inf) are flagged too.
        return bool(torch.isnan(g).any() or g.abs().gt(threshold).any())

    def hook_cb(fn):
        def register_grad(grad_input, grad_output):
            # Reduce to a single flag immediately instead of keeping every
            # gradient tensor alive until make_dot() is called.
            fn_dict[fn] = any(is_bad_grad(gi) for gi in grad_input)
        fn.register_hook(register_grad)

    iter_graph(var.grad_fn, hook_cb)

    def make_dot():
        node_attr = dict(style='filled',
                         shape='box',
                         align='left',
                         fontsize='12',
                         ranksep='0.1',
                         height='0.2')
        dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

        def size_to_str(size):
            return '(' + ', '.join(map(str, size)) + ')'

        def build_graph(fn):
            if hasattr(fn, 'variable'):  # if GradAccumulator
                u = fn.variable
                node_name = 'Variable\n ' + size_to_str(u.size())
                dot.node(str(id(u)), node_name, fillcolor='lightblue')
            else:
                # Every reachable node was hooked above. A missing entry means
                # its hook never fired, presumably because backward() has not
                # been run yet.
                assert fn in fn_dict, fn
                fillcolor = 'red' if fn_dict[fn] else 'white'
                dot.node(str(id(fn)), str(type(fn).__name__), fillcolor=fillcolor)
            for next_fn, _ in fn.next_functions:
                if next_fn is not None:
                    # Leaf nodes are drawn under the id of their tensor (see
                    # the 'variable' branch above), so edges must use it too.
                    next_id = id(getattr(next_fn, 'variable', next_fn))
                    dot.edge(str(next_id), str(id(fn)))

        iter_graph(var.grad_fn, build_graph)
        return dot

    return make_dot
if __name__ == '__main__':
    # Demo: x / (y * 0) divides by zero, so the gradients flowing back through
    # the division are inf/NaN and the affected nodes are drawn red.
    # Tensors created with requires_grad=True replace the deprecated
    # torch.autograd.Variable wrapper.
    x = torch.randn(10, 10, requires_grad=True)
    y = torch.randn(10, 10, requires_grad=True)
    z = x / (y * 0)
    z = z.sum() * 2
    # Hooks must be registered before backward() so they see the gradients.
    get_dot = register_hooks(z)
    z.backward()
    dot = get_dot()
    # Writes the DOT source; render with e.g. `dot -Tpng tmp.dot -o tmp.png`.
    dot.save('tmp.dot')
@david-waterworth
Copy link

This fails with

AttributeError: 'NoneType' object has no attribute 'data'

on the first line of `is_bad_grad` when using torch==1.9.1+cu111. Any suggestions on what to change?

@psampathkumar
Copy link

psampathkumar commented Jul 18, 2022

Probably change `is_bad_grad` to:

def is_bad_grad(grad_output):
    """Return True if a gradient is missing, contains NaN, or has a huge entry.

    ``None`` is treated as bad here. Recent PyTorch may pass ``None`` in place
    of a zero gradient, so return False for it instead if you consider a zero
    gradient acceptable.
    """
    if grad_output is None:
        return True
    grad_output = grad_output.detach()
    # NaN is the only value unequal to itself. abs() makes large negative
    # values (and -inf) count as exploding, not just large positive ones.
    return bool(grad_output.ne(grad_output).any()
                or grad_output.abs().gt(1e6).any())

Recent PyTorch versions sometimes represent zero gradients as None to speed up gradient accumulation. Return True or False depending on whether you consider a zero gradient good or bad.

@raquelhortab
Copy link

I run out of RAM using this code, and if I instead run it only once after n iterations, it crashes.
However, I got an interesting graph before the RAM ran out. Does anyone have a suggestion on how to approach the problem? The graph is very large and there are red nodes everywhere, but this is the end of it:

Selection_763

My model is not very complicated (apart from the transformer itself):

# def __init__
self.bert = transformers.BertModel.from_pretrained(config.BASE_MODEL_PATH, return_dict=False)
self.bert_drop_1 = nn.Dropout(0.3)
self.out_labels = nn.Linear(768, 1) 
self.sigmoid = nn.Sigmoid()
#... def forward
o1, o2 = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
dropout = self.bert_drop_1(o2)
logits = self.out_labels(dropout)
labels = self.sigmoid(logits)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment