Skip to content

Instantly share code, notes, and snippets.

@martenlienen
Created October 17, 2021 09:56
Show Gist options
  • Save martenlienen/200cd9298871698a1ab3ca7732fb210c to your computer and use it in GitHub Desktop.
Save martenlienen/200cd9298871698a1ab3ca7732fb210c to your computer and use it in GitHub Desktop.
Utilities to inspect the backward graph in PyTorch
import gc
import sys
from collections.abc import Iterable
from dataclasses import dataclass
import torch
from calmsize import size
def get_tensors():
    """Return every object tracked by the garbage collector that is a tensor.

    Each object from ``gc.get_objects()`` is tested with ``torch.is_tensor``.

    Returns:
        A list of the objects for which ``torch.is_tensor`` returned true.
    """
    tensors = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                tensors.append(obj)
        except Exception:
            # The type check may have to look up ``obj.__class__``, which can
            # raise for exotic objects (e.g. ``ReferenceError`` from a
            # ``weakref.proxy`` whose referent is gone). Skip such objects
            # instead of aborting the whole scan.
            continue
    return tensors
def largest_tensor():
    """Return the tensor from ``get_tensors()`` that ranks highest by ``tensor_size``.

    Raises:
        IndexError: if ``get_tensors()`` returns no tensors.
    """
    by_size = sorted(get_tensors(), key=tensor_size)
    # Ascending order, so the biggest tensor sits at the end.
    return by_size[-1]
def largest_tensor_ref_count():
    """Return ``sys.getrefcount`` for the largest tensor after a gc pass."""
    target = largest_tensor()
    # Presumably run so that collectable garbage no longer holds references
    # to the tensor when the count is taken.
    gc.collect()
    count = sys.getrefcount(target)
    return count
def largest_tensor_tree():
    """Return the one-level reference tree (``get_ref_tree``) of the largest tensor."""
    target = largest_tensor()
    # Run a collection before inspecting who still refers to the tensor.
    gc.collect()
    tree = get_ref_tree(target, levels=1)
    return tree
def largest_tensor_referrers():
    """Return ``gc.get_referrers`` for the largest tensor after a gc pass."""
    target = largest_tensor()
    # Run a collection before asking the collector for the referrers.
    gc.collect()
    referrers = gc.get_referrers(target)
    return referrers
def get_backward_graph_tensors():
    """Return the tensors from ``get_tensors()`` whose ``grad_fn`` is not ``None``."""
    attached = []
    for candidate in get_tensors():
        if candidate.grad_fn is None:
            continue
        attached.append(candidate)
    return attached
def get_referencing_objects(needle):
    """Find gc-tracked objects that hold ``needle`` directly in an attribute.

    An object matches when ``needle`` (compared by identity) is one of the
    values of its ``__dict__`` or is stored in one of the slots declared by
    any class in its type's MRO.

    Args:
        needle: The object to look for.

    Returns:
        A list of the matching objects from ``gc.get_objects()``.
    """
    # Private sentinel so that a missing attribute can never be mistaken for
    # the needle (e.g. when the needle itself is ``None``).
    missing = object()

    def safe_get_attr(obj, attr):
        # Attribute access on arbitrary objects can run arbitrary code, so any
        # ordinary exception is treated as "attribute not available".
        # ``Exception`` (not a bare except) keeps Ctrl-C working during a scan.
        try:
            return getattr(obj, attr, missing)
        except Exception:
            return missing

    def slot_names(obj):
        # ``obj.__slots__`` would only expose the nearest declaration, so walk
        # the MRO to also pick up slots declared on base classes.
        names = []
        try:
            for klass in type(obj).__mro__:
                declared = klass.__dict__.get("__slots__", ())
                if isinstance(declared, str):
                    # A plain string declares exactly one slot; it must not be
                    # iterated character by character.
                    names.append(declared)
                elif isinstance(declared, Iterable):
                    names.extend(n for n in declared if isinstance(n, str))
        except Exception:
            pass
        return names

    def references_needle(obj):
        namespace = safe_get_attr(obj, "__dict__")
        if namespace is not missing:
            try:
                values = list(namespace.values())
            except Exception:
                # ``__dict__`` is not a usable mapping on this object.
                values = []
            if any(val is needle for val in values):
                return True
        return any(safe_get_attr(obj, name) is needle for name in slot_names(obj))

    return [obj for obj in gc.get_objects() if references_needle(obj)]
def get_ref_tree(root, levels=1):
    """Build a nested list describing which objects reference ``root``.

    Each level is a list of ``(referrer, subtree)`` pairs, where the referrers
    come from ``get_referencing_objects``. Once ``levels`` further expansions
    have been done, the subtree of each referrer is the string ``"seen"``
    instead of another list.

    Args:
        root: Object whose referrers are looked up first.
        levels: Number of additional levels to expand below the first one.

    Returns:
        The nested list for ``root``.
    """
    # NOTE: ids are recorded here but never consulted, so repeated objects are
    # expanded again wherever they show up.
    visited = {id(root)}

    def expand(node, remaining):
        visited.add(id(node))
        holders = get_referencing_objects(node)
        if remaining <= 0:
            return [(holder, "seen") for holder in holders]
        return [(holder, expand(holder, remaining - 1)) for holder in holders]

    return expand(root, levels)
def get_objects_and_holding_tensor_size():
    """Group tensor-valued ``__dict__`` entries by owner and total their sizes.

    Every object from ``gc.get_objects()`` that exposes a ``__dict__`` is
    inspected; entries whose value is one of the tensors returned by
    ``get_tensors()`` are attributed to that object.

    Returns:
        A list of ``(obj, attrs, total)`` tuples sorted ascending by ``total``,
        where ``attrs`` is the list of matching ``__dict__`` keys and ``total``
        is the sum of ``tensor_size`` over the matching values.
    """
    # Keyed by id; the dict also keeps the tensors alive so ids stay unique
    # for the duration of the scan.
    tensors = {id(t): t for t in get_tensors()}

    totals = {}
    for obj in gc.get_objects():
        try:
            entries = list(obj.__dict__.items())
        except Exception:
            # No ``__dict__``, or attribute access on this object fails
            # (``hasattr`` would let non-AttributeError exceptions escape).
            continue
        for attr, val in entries:
            if id(val) not in tensors:
                continue
            record = totals.setdefault(id(obj), [obj, [], 0])
            record[1].append(attr)
            # Measure the value found in ``__dict__`` itself; re-fetching with
            # ``getattr(obj, attr)`` fails for non-string keys and can return
            # a different object when a class-level descriptor shadows it.
            record[2] += tensor_size(val)

    return sorted((tuple(record) for record in totals.values()), key=lambda item: item[-1])
@dataclass
class Holder:
    """Pairs an object with the name of one of its attributes."""

    # Object that owns the attribute.
    obj: object
    # Name of the attribute on ``obj``.
    attr: str

    def __repr__(self):
        """Show the owner's type and the attribute name."""
        owner_type = type(self.obj)
        return "<Holder {} {}>".format(owner_type, self.attr)

    def get(self):
        """Return the current value of the named attribute on the owner."""
        value = getattr(self.obj, self.attr)
        return value
def get_objects_holding_tensors():
    """Return a ``Holder`` for every ``__dict__`` entry whose value is a tensor.

    Values are matched by ``id`` against the tensors from ``get_tensors()``.
    """
    tensor_ids = {id(t) for t in get_tensors()}

    holders = []
    for obj in gc.get_objects():
        if not hasattr(obj, "__dict__"):
            continue
        for attr, val in obj.__dict__.items():
            if id(val) in tensor_ids:
                holders.append(Holder(obj, attr))
    return holders
def get_objects_holding_backward_tensors():
    """Return a ``Holder`` for every ``__dict__`` entry whose value is a tensor with a ``grad_fn``.

    Values are matched by ``id`` against the tensors from
    ``get_backward_graph_tensors()``.
    """
    wanted_ids = {id(t) for t in get_backward_graph_tensors()}

    holders = []
    for obj in gc.get_objects():
        if not hasattr(obj, "__dict__"):
            continue
        for attr, val in obj.__dict__.items():
            if id(val) in wanted_ids:
                holders.append(Holder(obj, attr))
    return holders
def tensor_size(tensor):
    """Return ``tensor.element_size()`` multiplied by ``tensor.numel()``."""
    per_element = tensor.element_size()
    element_count = tensor.numel()
    return per_element * element_count
def print_backward_graph(tensor, depth=5):
    """Print the ``grad_fn`` graph behind ``tensor`` as a ``rich`` tree.

    The root is labelled with ``tensor.grad_fn.name()``. Its direct inputs are
    at level 0; a node's own inputs are only added while the node's level is
    below ``depth``. Inputs without a function are shown as ``"None"``.

    Args:
        tensor: Tensor whose ``grad_fn`` is the root of the printed tree.
        depth: Level at which nodes are no longer expanded.
    """
    import rich
    from rich.tree import Tree

    tree = Tree(tensor.grad_fn.name())

    def attach(parent, functions, level):
        for fn, _ in functions:
            if fn is None:
                parent.add("None")
                continue
            branch = parent.add(fn.name())
            if level < depth:
                attach(branch, fn.next_functions, level + 1)

    attach(tree, tensor.grad_fn.next_functions, 0)
    rich.print(tree)
def total_size(tensor):
    """Sum ``tensor_size`` over ``tensor`` and the tensors reachable behind it.

    Starting at ``tensor``, the traversal follows ``grad_fn.next_functions``.
    A function is followed only when one of the tensors returned by
    ``get_backward_graph_tensors()`` has it as its ``grad_fn``; other branches
    end there. Each tensor is counted once, even when it is reachable through
    several edges (e.g. ``y = x * x``).

    Args:
        tensor: Tensor at which the traversal starts.

    Returns:
        The summed ``tensor_size`` of all distinct tensors visited.
    """
    tensor_gfns = {t.grad_fn: t for t in get_backward_graph_tensors()}

    counted = set()
    pending = [tensor]
    total = 0
    # Explicit stack instead of recursion so that deep graphs cannot hit the
    # interpreter's recursion limit.
    while pending:
        current = pending.pop()
        if id(current) in counted:
            # Already included via another edge; counting it again would
            # overstate the total.
            continue
        counted.add(id(current))
        total += tensor_size(current)

        grad_fn = current.grad_fn
        if grad_fn is None:
            continue
        for nf, _ in grad_fn.next_functions:
            if nf in tensor_gfns:
                pending.append(tensor_gfns[nf])
    return total
def get_objects_and_total_bw_graph_size():
    """Report the backward-graph size behind every held tensor with a ``grad_fn``.

    Returns:
        A list of ``(obj, attr, graph_size)`` tuples, one per ``Holder`` from
        ``get_objects_holding_backward_tensors()``, where ``graph_size`` is
        ``total_size`` of the held tensor passed through ``calmsize.size``.
    """
    return [
        # ``Holder`` is a dataclass and cannot be tuple-unpacked, so its
        # fields are read by name.
        (holder.obj, holder.attr, size(total_size(holder.get())))
        for holder in get_objects_holding_backward_tensors()
    ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment