measure_time_device.py — Gist by @vgoklani, created February 27, 2023
import torch
import torch.nn as nn


@torch.no_grad()
def measure_time_device(
    model: nn.Module,
    dtype: torch.dtype = torch.float32,
    num_repeats: int = 100,
    num_warmups: int = 10,
    synchronize: bool = True,
    continuous_measure: bool = True,
    **kwargs,
) -> float:
    """Measure the average forward-pass latency of `model` in milliseconds
    using CUDA events. Based on https://leimao.github.io/blog/PyTorch-Benchmark

    `kwargs` are forwarded to the model's `forward`. CUDA autocast is only
    engaged for reduced-precision dtypes (float16/bfloat16, the only dtypes
    CUDA autocast supports); any other dtype runs the model as-is.
    """
    autocast_enabled = dtype in (torch.float16, torch.bfloat16)

    # Warm up to keep one-time costs (CUDA context setup, cuDNN autotuning,
    # allocator growth) out of the measurement.
    for _ in range(num_warmups):
        with torch.autocast(device_type="cuda", dtype=dtype, enabled=autocast_enabled):
            _ = model(**kwargs)
    torch.cuda.synchronize()

    elapsed_time_ms = 0.0
    if continuous_measure:
        # Time all repeats in a single interval: lower overhead, but any
        # host-side gaps between iterations are included.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        for _ in range(num_repeats):
            with torch.autocast(device_type="cuda", dtype=dtype, enabled=autocast_enabled):
                _ = model(**kwargs)
        end_event.record()
        if synchronize:
            # The host must wait for both events to complete before reading
            # the elapsed time; otherwise elapsed_time() raises a runtime error.
            torch.cuda.synchronize()
        elapsed_time_ms = start_event.elapsed_time(end_event)
    else:
        # Time each repeat individually and sum the intervals: the per-iteration
        # synchronization adds overhead but excludes host-side gaps.
        for _ in range(num_repeats):
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            with torch.autocast(device_type="cuda", dtype=dtype, enabled=autocast_enabled):
                _ = model(**kwargs)
            end_event.record()
            if synchronize:
                # Wait for the events to complete before reading the elapsed
                # time; otherwise elapsed_time() raises a runtime error.
                torch.cuda.synchronize()
            elapsed_time_ms += start_event.elapsed_time(end_event)

    return elapsed_time_ms / num_repeats
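

# A minimal usage sketch, assuming a CUDA device is available. DummyModel and
# its input shape are hypothetical stand-ins for a real model and the keyword
# arguments its forward() expects.
if __name__ == "__main__":
    class DummyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = nn.Linear(1024, 1024)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)

    model = DummyModel().cuda().eval()
    x = torch.randn(64, 1024, device="cuda")

    # Keyword arguments (here `x=x`) are forwarded to the model's forward().
    fp32_ms = measure_time_device(model, dtype=torch.float32, x=x)
    fp16_ms = measure_time_device(model, dtype=torch.float16, x=x)
    print(f"fp32: {fp32_ms:.3f} ms/iter | fp16 (autocast): {fp16_ms:.3f} ms/iter")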