@gothos-folly
Created June 15, 2018 07:34
Pytorch Print Tensors
```python
import gc

import torch


def tensor_meta_data(tensor):
    """Return a one-line summary: size in MB, dtype, shape, class name and device."""
    # numel() is the product of all dimensions; element_size() is bytes per element.
    size_in_bytes = tensor.numel() * tensor.element_size()
    dtype = str(tensor.dtype).replace("torch.", "")  # e.g. "float32"
    size = str(list(tensor.size()))  # e.g. "[3, 4]"
    return f"{size_in_bytes / 1_000_000:5.1f}MB {dtype}{size} {type(tensor).__name__} {tensor.device}"


def print_tensors():
    """Print metadata for every tensor object tracked by Python's garbage collector."""
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            print(tensor_meta_data(obj))
```
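The size that `tensor_meta_data` reports is just element count times bytes per element, shown in decimal megabytes (1 MB = 1,000,000 bytes). As a worked check of that arithmetic that runs without PyTorch, here is a minimal sketch using a hypothetical helper `estimate_mb` (not part of the gist), assuming you already know the shape and the per-element byte width (4 for `float32`, 2 for `float16`, 8 for `int64`):

```python
from math import prod


def estimate_mb(shape, element_size):
    """Hypothetical helper: same arithmetic as tensor_meta_data, without torch.

    Bytes = product of all dimensions * bytes per element, reported in decimal MB.
    An empty shape (a 0-dim scalar tensor) has exactly one element, since prod([]) == 1.
    """
    return prod(shape) * element_size / 1_000_000


# A float32 batch of 64 RGB images at 224x224: 64*3*224*224 = 9,633,792 elements.
print(f"{estimate_mb([64, 3, 224, 224], 4):5.1f}MB")
```

With 4 bytes per element that batch occupies 38,535,168 bytes, so the line prints as ` 38.5MB`, the same column format `print_tensors` uses.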
@gothos-folly (author):

Helper functions that print the memory footprint, dtype, shape and device of every tensor tracked by Python's garbage collector in PyTorch, to help debug out-of-memory errors.
