@tvogels
Created December 22, 2021 18:48
PyTorch timer / profiler
import time
import contextlib
import queue
import threading
from collections import defaultdict
from typing import List, NamedTuple, Optional

import pandas as pd
import torch

@contextlib.contextmanager
def profiler():
    """Create a Profiler and make sure its worker thread is stopped on exit."""
    prof = Profiler()
    try:
        yield prof
    finally:
        prof.stop()

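# Example usage (illustrative sketch, not part of the original gist):
#
#     with profiler() as prof:
#         with prof.measure("forward"):
#             ...  # code to be timed
#     print(prof.results())  # pandas DataFrame, durations in milliseconds
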
class Profiler:
    def __init__(self):
        self.is_stopped = False
        self.queue = queue.SimpleQueue()
        # Background worker that synchronizes CUDA events and resolves the
        # corresponding futures with the measured durations.
        self.thread = threading.Thread(
            target=_profiling_worker, args=(self.queue,), daemon=True
        )
        self.thread.start()
        self.measurements = defaultdict(list)

    def _measure_cuda(self, name: Optional[str] = None, append_to: Optional[List] = None):
        if self.is_stopped:
            raise RuntimeError("Profiler was stopped already.")
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        elapsed_time = torch.futures.Future()
        start.record()
        yield elapsed_time
        end.record()
        if append_to is not None:
            append_to.append(elapsed_time)
        if name is not None:
            self.measurements[name].append(elapsed_time)
        # The worker thread synchronizes the events and fills in the future
        # with the elapsed time in milliseconds.
        self.queue.put(ProfilingQueueEntry(start, end, elapsed_time))

    def _measure_cpu(self, name: Optional[str] = None, append_to: Optional[List] = None):
        if self.is_stopped:
            raise RuntimeError("Profiler was stopped already.")
        elapsed_time = torch.futures.Future()
        start_time = time.time_ns()
        yield elapsed_time
        end_time = time.time_ns()
        if append_to is not None:
            append_to.append(elapsed_time)
        if name is not None:
            self.measurements[name].append(elapsed_time)
        # Milliseconds, to match torch.cuda.Event.elapsed_time on the CUDA path.
        elapsed_time.set_result((end_time - start_time) / 1_000_000)

    @contextlib.contextmanager
    def measure(self, name: Optional[str] = None, append_to: Optional[List] = None):
        """Time the wrapped block; yields a Future holding the duration in milliseconds."""
        # Both helpers are generator functions, so returning their generator
        # object here is compatible with @contextlib.contextmanager.
        if torch.cuda.is_available():
            return self._measure_cuda(name, append_to=append_to)
        else:
            return self._measure_cpu(name, append_to=append_to)

    def stop(self):
        self.is_stopped = True
        self.queue.put(None)  # Sentinel: tells the worker thread to exit.
        self.thread.join()

    def results(self):
        """Block until all measurements have resolved and return them as a DataFrame."""
        results = []
        for name in self.measurements:
            for i, duration in enumerate(torch.futures.wait_all(self.measurements[name])):
                results.append({"event": name, "occurrence": i, "duration": duration})
        return pd.DataFrame(results)

class ProfilingQueueEntry(NamedTuple):
    start: torch.cuda.Event
    end: torch.cuda.Event
    future: torch.futures.Future


class Measurement(NamedTuple):
    name: str
    duration: float

def _profiling_worker(task_queue):
    while True:
        task = task_queue.get()
        if task is None:  # Sentinel pushed by Profiler.stop().
            return
        start, end, future = task
        end.synchronize()
        future.set_result(start.elapsed_time(end))  # milliseconds
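
Below is a minimal usage sketch showing how the profiler() and measure() API fits together (not part of the gist itself). The module name pytorch_profiler stands in for wherever the file above is saved, and the matrix-multiply workload and sizes are arbitrary placeholders:

import torch

from pytorch_profiler import profiler  # hypothetical module name for the file above

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2048, 2048, device=device)

step_times = []
with profiler() as prof:
    for _ in range(10):
        # Named measurement, aggregated under "matmul" in prof.results().
        with prof.measure("matmul", append_to=step_times):
            y = x @ x  # arbitrary work to time

# By the time the with-block exits, stop() has joined the worker thread,
# so every future is resolved.
df = prof.results()  # columns: event, occurrence, duration (ms)
print(df.groupby("event")["duration"].mean())
print(step_times[0].value())  # a single duration as a plain float, in ms

Durations are in milliseconds on both the CUDA and the CPU path. The futures are guaranteed to be resolved after the profiler() block exits because stop() enqueues its sentinel last, so the worker drains all pending measurements before the join returns.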