Skip to content

Instantly share code, notes, and snippets.

@stas00
Created January 10, 2019 03:03
Show Gist options
  • Save stas00/c3cf248bdf6b0c846005faa055451560 to your computer and use it in GitHub Desktop.
Save stas00/c3cf248bdf6b0c846005faa055451560 to your computer and use it in GitHub Desktop.
PeakMemMetric - a custom fastai metric that prints GPU/CPU RAM consumption and peak info per training epoch
import tracemalloc, threading, torch, time, pynvml
from fastai.utils.mem import *
from fastai.vision import *
# Fail fast at import time: the helpers below query GPU memory through
# torch.cuda and NVML, so they cannot work without a CUDA device.
# The check is for CUDA availability (torch itself is already imported), so the
# message says so. RuntimeError is a subclass of Exception, so existing
# `except Exception` handlers keep working.
if not torch.cuda.is_available():
    raise RuntimeError("pytorch with CUDA support and an available GPU is required")
def preload_pytorch():
    """Touch the GPU once by moving a throwaway 1x1 tensor of ones onto it.

    Presumably done so that the one-time CUDA initialization cost is already
    paid before any memory measurement starts -- TODO confirm.
    """
    warmup = torch.ones((1, 1))
    warmup.cuda()
def gpu_mem_get_used_no_cache():
    """Return the `used` field of `gpu_mem_get()` after emptying pytorch's CUDA cache.

    The cache is released first so that memory pytorch merely holds on to is
    not counted. The unit is whatever `gpu_mem_get()` reports -- presumably
    MBs, since callers subtract it from a MB-scaled peak; verify against
    fastai.utils.mem.
    """
    torch.cuda.empty_cache()
    mem_info = gpu_mem_get()
    return mem_info.used
def gpu_mem_used_get_fast(gpu_handle):
    """Return the used memory of the device behind `gpu_handle`, in whole MBs.

    `gpu_handle` is an NVML device handle (as returned by
    `pynvml.nvmlDeviceGetHandleByIndex`). The byte count reported by NVML is
    divided by 2**20 and truncated to an int.
    """
    used_bytes = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle).used
    return int(used_bytes / 2**20)
def torch_mem_report():
    """Print pytorch's CUDA memory counters, in whole MBs, as a 4-item list.

    Order: allocated, max allocated, cached, max cached. The CUDA cache is
    emptied before the counters are read.
    """
    torch.cuda.empty_cache()
    counters = (
        torch.cuda.memory_allocated(),
        torch.cuda.max_memory_allocated(),
        torch.cuda.memory_cached(),
        torch.cuda.max_memory_cached(),
    )
    print([int(nbytes / 2**20) for nbytes in counters])
# Module-level side effects, executed once at import time.
# Put a tiny tensor on the GPU up front (see preload_pytorch).
preload_pytorch()
# Initialize NVML; the pynvml device-handle and memory-info queries used below
# are issued after this call.
pynvml.nvmlInit()
class PeakMemMetric(LearnerCallback):
    """Callback that reports CPU and GPU memory usage and peaks for each epoch.

    Four extra columns are added to the recorder: 'cpu used', 'peak',
    'gpu used', 'peak'. CPU numbers come from `tracemalloc` (started at the
    beginning of the epoch); GPU numbers are deltas relative to the GPU memory
    in use at the beginning of the epoch. Values are scaled by 2**20 (MBs).

    The GPU peak is obtained by a background thread that polls NVML roughly
    every millisecond while the epoch is running.

    NOTE(review): `torch.cuda.current_device()` is passed straight to NVML as
    the device index; this presumably does not hold when CUDA_VISIBLE_DEVICES
    remaps devices -- confirm.
    """

    _order = -20  # Needs to run before the recorder

    def peak_monitor_start(self):
        """Start RAM tracing and launch the GPU peak-sampling thread."""
        self.peak_monitoring = True
        # Make sure the attribute exists before the sampling thread gets
        # scheduled, so a reader in the main thread can never miss it.
        self.gpu_mem_used_peak = -1

        # start RAM tracing
        tracemalloc.start()

        # this thread samples GPU usage as long as the current epoch of the
        # fit loop is running
        monitor = threading.Thread(target=self.peak_monitor_func)
        monitor.daemon = True
        # Keep the handle so peak_monitor_stop() can wait for the thread.
        self.peak_monitor_thread = monitor
        monitor.start()

    def peak_monitor_stop(self):
        """Stop the sampling thread (waiting for it to exit) and stop RAM tracing."""
        self.peak_monitoring = False
        # Wait for the sampler to observe the flag and exit. Without this, the
        # next epoch could set the flag back to True before the old thread
        # notices, leaving one extra polling thread alive per epoch.
        monitor = vars(self).get("peak_monitor_thread")
        if monitor is not None:
            monitor.join()
            # Drop the reference so no Thread object lingers on the callback.
            self.peak_monitor_thread = None
        tracemalloc.stop()

    def peak_monitor_func(self):
        """Thread body: poll NVML and track the highest used-memory reading.

        Runs until `self.peak_monitoring` becomes false; the result is left in
        `self.gpu_mem_used_peak` (MBs).
        """
        # Reset here as well so the method also works when invoked directly.
        self.gpu_mem_used_peak = -1

        gpu_id = torch.cuda.current_device()
        gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)

        while True:
            gpu_mem_used = gpu_mem_used_get_fast(gpu_handle)
            self.gpu_mem_used_peak = max(gpu_mem_used, self.gpu_mem_used_peak)
            # The flag is checked after sampling, so at least one sample is
            # always taken.
            if not self.peak_monitoring:
                break
            time.sleep(0.001)  # 1msec

    def on_train_begin(self, **kwargs):
        """Register the four extra metric column names with the recorder."""
        self.learn.recorder.add_metric_names(['cpu used', 'peak', 'gpu used', 'peak'])

    def on_epoch_begin(self, **kwargs):
        """Start monitoring and remember the GPU memory in use as the baseline."""
        self.peak_monitor_start()
        self.gpu_before = gpu_mem_get_used_no_cache()

    def on_epoch_end(self, **kwargs):
        """Compute this epoch's memory numbers, stop monitoring, and report them."""
        cpu_current, cpu_peak = [int(nbytes / 2**20) for nbytes in tracemalloc.get_traced_memory()]
        gpu_current = gpu_mem_get_used_no_cache() - self.gpu_before
        gpu_peak = self.gpu_mem_used_peak - self.gpu_before
        self.peak_monitor_stop()
        # The numbers are deltas in MBs (beginning of the epoch and the end)
        self.learn.recorder.add_metrics([cpu_current, cpu_peak, gpu_current, gpu_peak])
# Example usage: attach PeakMemMetric through `callback_fns` and train for one
# cycle; the extra memory columns are reported alongside `accuracy`.
# NOTE(review): `data` and `model` are not defined in this snippet -- presumably
# a fastai DataBunch and a model architecture supplied by the user.
learn = create_cnn(data, model, metrics=[accuracy], callback_fns=PeakMemMetric)
learn.fit_one_cycle(1, max_lr=1e-2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment