Created
October 17, 2021 09:56
-
-
Save martenlienen/200cd9298871698a1ab3ca7732fb210c to your computer and use it in GitHub Desktop.
Utilities to inspect the autograd backward graph in PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gc | |
import sys | |
from collections.abc import Iterable | |
from dataclasses import dataclass | |
import torch | |
from calmsize import size | |
def get_tensors():
    """Collect every live torch tensor currently tracked by the garbage collector."""
    found = []
    for candidate in gc.get_objects():
        if torch.is_tensor(candidate):
            found.append(candidate)
    return found
def largest_tensor():
    """Return the live tensor that occupies the most memory.

    Uses ``max`` with a key instead of sorting the whole list and taking
    the last element — O(n) instead of O(n log n), and avoids mutating
    the collected list.

    Raises:
        ValueError: if no tensors are alive (``max`` on an empty
            sequence; the original raised IndexError in that case).
    """
    return max(get_tensors(), key=tensor_size)
def largest_tensor_ref_count():
    """Reference count of the biggest live tensor, measured after a GC pass."""
    target = largest_tensor()
    # Collect first so dead referrers don't inflate the count.
    gc.collect()
    return sys.getrefcount(target)
def largest_tensor_tree():
    """One-level referrer tree rooted at the biggest live tensor."""
    target = largest_tensor()
    # Drop garbage before walking referrers so the tree only shows live owners.
    gc.collect()
    return get_ref_tree(target, levels=1)
def largest_tensor_referrers():
    """Objects that directly refer to the biggest live tensor (post-GC)."""
    target = largest_tensor()
    gc.collect()
    return gc.get_referrers(target)
def get_backward_graph_tensors():
    """Live tensors that are part of an autograd graph (have a grad_fn)."""
    result = []
    for t in get_tensors():
        if t.grad_fn is not None:
            result.append(t)
    return result
def get_referencing_objects(needle):
    """Find every gc-tracked object holding a direct reference to ``needle``.

    An object counts as a referrer when ``needle`` appears among its
    ``__dict__`` values, or is bound to one of its ``__slots__``.

    Fixes over the original:
      * The bare ``except:`` in the hasattr probe also swallowed
        KeyboardInterrupt/SystemExit; narrowed to ``except Exception``.
      * A string ``__slots__`` previously fell through into the
        Iterable branch as well (a str is Iterable), probing each
        *character* as an attribute name; now handled exclusively.
    """

    def _safe_hasattr(obj, attr):
        # hasattr can execute arbitrary __getattr__ code on exotic
        # objects; never let an exception from the probe escape.
        try:
            return hasattr(obj, attr)
        except Exception:
            return False

    def _safe_getattr(obj, attr):
        if not _safe_hasattr(obj, attr):
            return None
        try:
            return getattr(obj, attr)
        except (AttributeError, RuntimeError):
            # RuntimeError covers e.g. uninitialized torch internals.
            return None

    def _holds_needle(obj):
        if _safe_hasattr(obj, "__dict__") and any(
            value is needle for value in obj.__dict__.values()
        ):
            return True
        if _safe_hasattr(obj, "__slots__"):
            slots = obj.__slots__
            if isinstance(slots, str):
                # A lone string names a single slot.
                return _safe_getattr(obj, slots) is needle
            if isinstance(slots, Iterable):
                return any(_safe_getattr(obj, name) is needle for name in slots)
        return False

    return [obj for obj in gc.get_objects() if _holds_needle(obj)]
def get_ref_tree(root, levels=1):
    """Build a tree of referrers of ``root``, up to ``levels`` deep.

    Returns a nested list of ``(referrer, subtree)`` pairs; at the depth
    limit the subtree is the marker string ``"seen"``.

    Fix: the original populated ``seen`` but never consulted it, so a
    reference cycle among referrers recursed without bound. Already
    visited nodes are now pruned with the same ``"seen"`` marker.
    """
    seen = {id(root)}

    def recurse(node, remaining):
        seen.add(id(node))
        branches = []
        for parent in get_referencing_objects(node):
            if id(parent) in seen or remaining <= 0:
                # Cycle or depth limit: stop expanding this branch.
                branches.append((parent, "seen"))
            else:
                branches.append((parent, recurse(parent, remaining - 1)))
        return branches

    return recurse(root, levels)
def get_objects_and_holding_tensor_size():
    """Rank objects by the total byte size of live tensors in their __dict__.

    Returns a list of ``(obj, [attr, ...], total_bytes)`` triples sorted
    ascending by total size, one triple per holding object.
    """
    live_tensors = {id(t): t for t in get_tensors()}

    totals = {}
    for owner in gc.get_objects():
        if not hasattr(owner, "__dict__"):
            continue
        for name, value in owner.__dict__.items():
            if id(value) not in live_tensors:
                continue
            held = tensor_size(getattr(owner, name))
            entry = totals.get(id(owner))
            if entry is None:
                totals[id(owner)] = (owner, [name], held)
            else:
                prev_owner, prev_names, prev_size = entry
                totals[id(owner)] = (prev_owner, prev_names + [name], prev_size + held)

    return sorted(totals.values(), key=lambda row: row[-1])
@dataclass
class Holder:
    """A (container object, attribute name) pair pointing at a held value."""

    obj: object  # the object holding the reference
    attr: str  # name of the attribute on obj that holds it

    def __repr__(self):
        holder_type = type(self.obj)
        return f"<Holder {holder_type} {self.attr}>"

    def get(self):
        """Dereference the record: fetch the held value off the container."""
        return getattr(self.obj, self.attr)
def get_objects_holding_tensors():
    """Holder records for every __dict__ attribute pointing at a live tensor."""
    # NOTE: all live tensors, not just backward-graph ones — the local
    # is named accordingly (the original reused a misleading name).
    tensor_ids = {id(t) for t in get_tensors()}
    found = []
    for owner in gc.get_objects():
        if not hasattr(owner, "__dict__"):
            continue
        for name, value in owner.__dict__.items():
            if id(value) in tensor_ids:
                found.append(Holder(owner, name))
    return found
def get_objects_holding_backward_tensors():
    """Holder records for __dict__ attributes pointing at backward-graph tensors."""
    graph_ids = {id(t) for t in get_backward_graph_tensors()}
    return [
        Holder(owner, name)
        for owner in gc.get_objects()
        if hasattr(owner, "__dict__")
        for name, value in owner.__dict__.items()
        if id(value) in graph_ids
    ]
def tensor_size(tensor):
    """Bytes consumed by the tensor's elements (element count × per-element size)."""
    return tensor.nelement() * tensor.element_size()
def print_backward_graph(tensor, depth=5):
    """Pretty-print the autograd graph hanging off ``tensor`` using rich.

    Args:
        tensor: tensor whose ``grad_fn`` chain is walked.
        depth: maximum recursion depth into ``next_functions``.

    Fix: a tensor with no ``grad_fn`` (leaf / built under no_grad)
    previously crashed with AttributeError on ``.name()``; it now prints
    a single "None" node, matching how missing children are rendered.
    """
    import rich
    from rich.tree import Tree

    if tensor.grad_fn is None:
        # Nothing to walk — render the same marker used for absent children.
        rich.print(Tree("None"))
        return

    root = Tree(tensor.grad_fn.name())

    def recurse(parent, grad_fn, level):
        if grad_fn is None:
            parent.add("None")
            return
        node = parent.add(grad_fn.name())
        if level < depth:
            for next_fn, _ in grad_fn.next_functions:
                recurse(node, next_fn, level + 1)

    for next_fn, _ in tensor.grad_fn.next_functions:
        recurse(root, next_fn, 0)
    rich.print(root)
def total_size(tensor):
    """Bytes held by ``tensor`` plus all backward-graph tensors reachable from it.

    Walks ``grad_fn.next_functions`` and sums ``tensor_size`` over every
    live tensor attached to the backward graph.

    Fix: the original had no visited set, so a diamond-shaped autograd
    graph double-counted shared subtrees (and made the traversal
    exponential in the worst case). Each tensor is now counted once.
    """
    by_grad_fn = {t.grad_fn: t for t in get_backward_graph_tensors()}
    counted = set()

    def accumulate(t):
        if id(t) in counted:
            return 0
        counted.add(id(t))
        subtotal = tensor_size(t)
        if t.grad_fn is not None:
            for next_fn, _ in t.grad_fn.next_functions:
                if next_fn in by_grad_fn:
                    subtotal += accumulate(by_grad_fn[next_fn])
        return subtotal

    return accumulate(tensor)
def get_objects_and_total_bw_graph_size():
    """(obj, attr, human-readable size) for every backward-tensor holder.

    Bug fix: ``get_objects_holding_backward_tensors()`` returns
    ``Holder`` dataclass instances, which are not iterable — the
    original ``for obj, attr in ...`` unpacking raised TypeError on the
    first item. Access the fields through the Holder API instead.
    """
    return [
        (holder.obj, holder.attr, size(total_size(holder.get())))
        for holder in get_objects_holding_backward_tensors()
    ]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment