Tool to monitor used GPU memory (bytes) and time (nanoseconds) in PyTorch
import gc
import math
import time
import datetime
from contextlib import contextmanager

import torch

class Monitor:
    """Tool to monitor used GPU memory (bytes) and time (nanoseconds)"""

    def __init__(self, device=None, max_time=0.2e9, bases=(1, 2, 5), power=10):
        self.bases = bases        # repetition multipliers (the 1-2-5 schedule)
        self.power = power        # factor applied after each pass over `bases`
        self.device = device      # CUDA device to monitor (None = current device)
        self.max_time = max_time  # total measurement budget in nanoseconds (0.2 s)
        self.count = self.time = self.memory = 0
        self.iterator = self.get_iterator()
    def get_iterator(self):
        """Get an iterator over the repetitions"""
        # Yields False once before each measured batch (giving __bool__ a chance
        # to stop) and True once per repetition; the batch size grows as
        # 1, 2, 5, 10, 20, 50, ... (each base scaled by the growing factor).
        factor = 1
        while True:
            for base in self.bases:
                iterations = base * factor
                yield False
                with self.monitored() as stats:
                    for _ in range(iterations):
                        self.count += 1
                        yield True
                # `stats` holds the finalized deltas once the context exits.
                self.memory += stats['memory'] * iterations
                self.time += stats['time']
            factor *= self.power
    @contextmanager
    def monitored(self):
        """Context manager that monitors GPU memory and elapsed time"""
        gc_old = gc.isenabled()
        gc.disable()  # keep garbage collection from skewing the measurements
        try:
            torch.cuda.synchronize(self.device)
            torch.cuda.reset_peak_memory_stats(self.device)
            # Store the negated starting values; the `finally` block adds the
            # ending values, leaving the deltas in `stats`.
            stats = dict(
                memory=-torch.cuda.max_memory_allocated(self.device),
                time=-time.time_ns(),
            )
            yield stats
        finally:
            torch.cuda.synchronize(self.device)
            stats['time'] += time.time_ns()
            stats['memory'] += torch.cuda.max_memory_allocated(self.device)
            if gc_old:
                gc.enable()
    def __bool__(self):
        # True while the time budget allows another repetition.
        if next(self.iterator):
            return True
        if self.time >= self.max_time:
            return False
        return next(self.iterator)
    @staticmethod
    def format_memory(num_bytes):
        """Format memory"""
        assert num_bytes >= 0
        units = ('B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB')
        power = min(int(math.log(max(num_bytes, 1), 1024)), len(units) - 1)
        return f'{num_bytes * 1024**-power:.3f} {units[power]}'
    @staticmethod
    def format_time(nanoseconds):
        """Format time"""
        return f'{datetime.timedelta(microseconds=nanoseconds * 1e-3)}'
    def __repr__(self):
        nice_time = self.format_time(self.time / self.count)
        nice_memory = self.format_memory(self.memory / self.count)
        return f'x{self.count}: {nice_memory} @ {nice_time}'
    @classmethod
    def model_size(cls, model, formatted=True):
        """Get the size of a PyTorch model"""
        total = 0
        for tensor in model.state_dict().values():
            # bytes = number of stored elements * bytes per element
            total += tensor.storage().size() * tensor.element_size()
        if formatted:
            return cls.format_memory(total)
        return total
def __test_monitor(device=0, batch_size=128, train=False, requires_grad=False):
    """Usage example for Monitor on ResNet50"""
    from torchvision.models import resnet50
    model = resnet50().cuda(device)
    model.train(train).requires_grad_(requires_grad)
    images = torch.randn(batch_size, 3, 224, 224).cuda(device)
    monitor = Monitor(device=device)
    while monitor:  # keeps repeating the body until the time budget is spent
        model(images)
    print(f'ResNet50 (params = {monitor.model_size(model)}) {monitor}')
    pytorch = f'PyTorch v{torch.__version__}'
    # deterministic = torch.are_deterministic_algorithms_enabled()
    cudnn = f'cuDNN v{torch.backends.cudnn.version()} '
    cudnn += 'enabled' if torch.backends.cudnn.enabled else 'disabled'
    # benchmark = torch.backends.cudnn.benchmark
    gpu = torch.cuda.get_device_properties(device).name
    threads = f'{torch.get_num_threads()} threads'
    # from torch.utils.collect_env import get_pretty_env_info
    # print(get_pretty_env_info())
    print(f'Using {pytorch} & {cudnn} on `{gpu}` with {threads}')


if __name__ == '__main__':
    __test_monitor()
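The same `while monitor:` idiom works for any GPU snippet, not just a model's forward pass. A minimal sketch, assuming the `Monitor` class above is importable and a CUDA GPU is present (the matrix size and the matmul workload are arbitrary placeholders):

import torch

monitor = Monitor(device=0)
x = torch.randn(1024, 1024, device='cuda:0')
while monitor:  # repeats the body in batches of 1, 2, 5, 10, 20, ... runs
    x @ x  # the snippet being measured
print(monitor)  # prints 'x<count>: <mean peak memory> @ <mean time per run>'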