Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Compute total memory consumed by PyTorch tensors
#!/usr/bin/env python3
"""
Compute total memory consumed by PyTorch tensors.
"""
import gc
import torch
def get_used_mem():
    """Return the number of bytes held by live PyTorch tensors.

    Walks every object tracked by Python's garbage collector and sums the
    size of the storage behind each tensor found. A storage shared by several
    tensors (e.g. a tensor and views/slices of it) is counted only once, so
    views do not inflate the total.

    Returns:
        A ``(host_mem, gpu_mem)`` tuple of byte counts. ``gpu_mem`` covers
        tensors whose ``is_cuda`` is true; ``host_mem`` covers all others.
    """
    host_mem = 0
    gpu_mem = 0
    # (device, data pointer) of every storage already counted.
    seen_storages = set()
    for obj in gc.get_objects():
        if not torch.is_tensor(obj):
            continue
        try:
            storage = obj.untyped_storage()
            ptr = storage.data_ptr()
            # Views share their base's storage; identify the allocation by
            # device and address so it is counted once, not once per view.
            # A null pointer does not identify an allocation, so don't dedupe.
            key = (obj.device, ptr) if ptr else None
            mem = storage.nbytes()
        except (AttributeError, RuntimeError, NotImplementedError):
            # Storage not accessible (or a torch build without
            # Tensor.untyped_storage): fall back to the tensor's logical size.
            key = None
            mem = obj.numel() * obj.element_size()
        if key is not None:
            if key in seen_storages:
                continue
            seen_storages.add(key)
        if obj.is_cuda:
            gpu_mem += mem
        else:
            host_mem += mem
    return host_mem, gpu_mem
if __name__ == '__main__':
    # Baseline: bytes already held by tensors before this demo allocates any.
    host_mem, gpu_mem = get_used_mem()
    print(host_mem, gpu_mem)
    # 10 x 10 float32 tensor on the host: 100 elements * 4 bytes = 400 bytes.
    x = torch.empty((10, 10), dtype=torch.float32)
    host_mem, gpu_mem = get_used_mem()
    print(host_mem, gpu_mem)
    # Allocate directly on the GPU (no temporary host tensor), and skip this
    # step instead of crashing on machines without CUDA.
    if torch.cuda.is_available():
        y = torch.empty((10, 10), dtype=torch.float32, device='cuda:0')
        host_mem, gpu_mem = get_used_mem()
        print(host_mem, gpu_mem)
    else:
        print('CUDA not available; skipping GPU allocation')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.