@Stonesjtu
Last active March 7, 2023 16:58
A simple PyTorch memory usage profiler
import gc
import torch

## MEM utils ##

def mem_report():
    '''Report the memory usage of tensor storages in PyTorch.
    Both CPU and GPU tensors are reported.'''

    def _mem_report(tensors, mem_type):
        '''Print the selected tensors of one storage type.
        There are two major storage types of concern:
            - GPU: tensors transferred to CUDA devices
            - CPU: tensors remaining in system memory (usually unimportant)
        Args:
            - tensors: the tensors of the specified type
            - mem_type: 'CPU' or 'GPU' in the current implementation'''
        print('Storage on %s' % (mem_type))
        print('-' * LEN)
        total_numel = 0
        total_mem = 0
        visited_data = []
        for tensor in tensors:
            if tensor.is_sparse:
                continue
            # a data_ptr identifies an allocated memory block; tensors
            # sharing a storage are counted only once
            data_ptr = tensor.storage().data_ptr()
            if data_ptr in visited_data:
                continue
            visited_data.append(data_ptr)

            numel = tensor.storage().size()
            total_numel += numel
            element_size = tensor.storage().element_size()
            mem = numel * element_size / 1024 / 1024  # bytes -> MBytes
            total_mem += mem
            element_type = type(tensor).__name__
            size = tuple(tensor.size())

            print('%s\t\t%s\t\t%.2f' % (
                element_type,
                size,
                mem))
        print('-' * LEN)
        print('Total Tensors: %d \tUsed Memory Space: %.2f MBytes' % (total_numel, total_mem))
        print('-' * LEN)

    LEN = 65
    print('=' * LEN)
    objects = gc.get_objects()
    print('%s\t%s\t\t\t%s' % ('Element type', 'Size', 'Used MEM(MBytes)'))
    tensors = [obj for obj in objects if torch.is_tensor(obj)]
    cuda_tensors = [t for t in tensors if t.is_cuda]
    host_tensors = [t for t in tensors if not t.is_cuda]
    _mem_report(cuda_tensors, 'GPU')
    _mem_report(host_tensors, 'CPU')
    print('=' * LEN)
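A quick usage sketch (the tensor shapes are illustrative): call mem_report() at any point to dump every live tensor the garbage collector can see.

import torch

x = torch.randn(1024, 1024)            # host tensor, ~4.00 MBytes
if torch.cuda.is_available():
    y = torch.randn(256, 256).cuda()   # CUDA tensor, ~0.25 MBytes
mem_report()                           # prints the GPU section, then the CPU section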
@kkonevets commented Apr 2, 2019

Hey man, use
visited_data = set()
and
visited_data.update([data_ptr])
for a much faster loop
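For illustration, a minimal sketch of the deduplication loop rewritten with a set; membership tests on a set are O(1) on average, versus O(n) on a list, which matters once many tensors are alive:

# Sketch: replace the visited list with a set for constant-time lookups.
visited_data = set()
for tensor in tensors:
    if tensor.is_sparse:
        continue
    data_ptr = tensor.storage().data_ptr()
    if data_ptr in visited_data:  # O(1) average lookup on a set
        continue
    visited_data.add(data_ptr)
    # ... accounting as in _mem_report above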

@Stonesjtu (Author) commented
@kkonevets, thanks, I didn't pay much attention to the scenario where the number of tensors gets large. But I don't think the in check takes much time in this script, what do you think?

@Stonesjtu (Author) commented
I recently wrote a more powerful, pip-installable tool; you can check it out:

https://github.com/stonesjtu/pytorch_memlab
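For context, a minimal usage sketch of that tool, assuming the MemReporter API shown in the pytorch_memlab README:

# Minimal sketch, assuming pytorch_memlab's MemReporter API (pip install pytorch_memlab).
import torch
from pytorch_memlab import MemReporter

linear = torch.nn.Linear(1024, 1024)
inp = torch.randn(512, 1024)
linear(inp).mean().backward()

reporter = MemReporter(linear)  # passing a module labels its parameters by name
reporter.report()               # per-tensor breakdown, similar to mem_report() above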
