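# Benchmark: memory bandwidth of torch._foreach_copy_ on CUDA, copying a
# large list of float32 tensors (sizes below) and timing with CUDA events.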
from typing import Callable
import functools
import torch
SIZES = [
    torch.Size([256, 280]),
    torch.Size([256]),
    torch.Size([280, 256]),
    torch.Size([280]),
    torch.Size([280]),
    torch.Size([280]),
    torch.Size([128, 280]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([320, 280]),
    torch.Size([320]),
    torch.Size([256, 1320]),
    torch.Size([256]),
    torch.Size([1320, 256]),
    torch.Size([1320]),
    torch.Size([1320]),
    torch.Size([1320]),
    torch.Size([128, 1320]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([320, 1320]),
    torch.Size([320]),
    torch.Size([256, 360]),
    torch.Size([256]),
    torch.Size([360, 256]),
    torch.Size([360]),
    torch.Size([360]),
    torch.Size([360]),
    torch.Size([128, 360]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([320, 360]),
    torch.Size([320]),
    torch.Size([320, 400]),
    torch.Size([320]),
    torch.Size([320]),
    torch.Size([320]),
    torch.Size([320, 400]),
    torch.Size([320]),
    torch.Size([256, 300]),
    torch.Size([256]),
    torch.Size([300, 256]),
    torch.Size([300]),
    torch.Size([300]),
    torch.Size([300]),
    torch.Size([128, 300]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([320, 300]),
    torch.Size([320]),
    torch.Size([256, 1976]),
    torch.Size([256]),
    torch.Size([1976, 256]),
    torch.Size([1976]),
    torch.Size([1976]),
    torch.Size([1976]),
    torch.Size([128, 1976]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([320, 1976]),
    torch.Size([320]),
    torch.Size([256, 1992]),
    torch.Size([256]),
    torch.Size([1992, 256]),
    torch.Size([1992]),
    torch.Size([1992]),
    torch.Size([1992]),
    torch.Size([128, 1992]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([320, 1992]),
    torch.Size([320]),
    torch.Size([256, 1980]),
    torch.Size([256]),
    torch.Size([1980, 256]),
    torch.Size([1980]),
    torch.Size([1980]),
    torch.Size([1980]),
    torch.Size([128, 1980]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([320, 1980]),
    torch.Size([320]),
    torch.Size([256, 1976]),
    torch.Size([256]),
    torch.Size([1976, 256]),
    torch.Size([1976]),
    torch.Size([1976]),
    torch.Size([1976]),
    torch.Size([128, 1976]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([320, 1976]),
    torch.Size([320]),
    torch.Size([256, 1964]),
    torch.Size([256]),
    torch.Size([1964, 256]),
    torch.Size([1964]),
    torch.Size([1964]),
    torch.Size([1964]),
    torch.Size([128, 1964]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([320, 1964]),
    torch.Size([320]),
    torch.Size([256, 1968]),
    torch.Size([256]),
    torch.Size([1968, 256]),
    torch.Size([1968]),
    torch.Size([1968]),
    torch.Size([1968]),
    torch.Size([128, 1968]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([320, 1968]),
    torch.Size([320]),
    torch.Size([256, 1888]),
    torch.Size([256]),
    torch.Size([1888, 256]),
    torch.Size([1888]),
    torch.Size([1888]),
    torch.Size([1888]),
    torch.Size([128, 1888]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([320, 1888]),
    torch.Size([320]),
    torch.Size([256, 1824]),
    torch.Size([256]),
    torch.Size([1824, 256]),
    torch.Size([1824]),
    torch.Size([1824]),
    torch.Size([1824]),
    torch.Size([128, 1824]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([128]),
    torch.Size([320, 1824]),
    torch.Size([320]),
] * 10
# A smaller size set, kept for quick experiments:
# SIZES = [
# torch.Size([256, 280]),
# torch.Size([256]),
# torch.Size([280, 256]),
# torch.Size([280]),
# torch.Size([280]),
# torch.Size([280]),
# torch.Size([128, 280]),
# ]
src = [torch.rand(sz, device="cuda") for sz in SIZES]
dst = [torch.empty(sz, device="cuda") for sz in SIZES]
# Total float32 payload in GiB (4 bytes per element)
gb = sum(sz.numel() * 4 for sz in SIZES) / 1024 ** 3
def trace_handler(prof) -> None:
    # Export the chrome trace, hand it to a post-processing script, then clean up
    import subprocess, uuid, os

    name = f"{uuid.uuid4()}.json"
    prof.export_chrome_trace(name)
    subprocess.check_call(["python", "/home/yifu/trace.py", name])
    os.remove(name)

# Set to True to capture a profiler trace of each measurement run via trace_handler
enable_profiler = False
def benchmark_time(
    benchmark_fn: Callable,
    *benchmark_fn_args,
    **benchmark_fn_kwargs,
) -> tuple[float, float]:
    import time
    from contextlib import nullcontext
    from torch.testing._internal.common_utils import get_cycles_per_ms

    if not enable_profiler:
        prof = nullcontext()
    else:
        prof = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            on_trace_ready=trace_handler,
        )
    MEASURE_ITERS = 100
    evt_begin = [torch.cuda.Event(enable_timing=True) for _ in range(MEASURE_ITERS)]
    evt_end = [torch.cuda.Event(enable_timing=True) for _ in range(MEASURE_ITERS)]
    cpu_ms = 0.0
    # 256 MB scratch buffer; zeroing it each iteration flushes the L2 cache
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
    torch.cuda.synchronize()
    with prof:
        for i in range(MEASURE_ITERS):
            cache.zero_()
            cpu_begin = time.perf_counter_ns()
            # Delay the GPU by ~10 ms so the launches below are enqueued without
            # blocking: cpu_ms then captures pure launch overhead while the CUDA
            # events capture pure device time.
            torch.cuda._sleep(int(10 * get_cycles_per_ms()))
            evt_begin[i].record()
            benchmark_fn(*benchmark_fn_args, **benchmark_fn_kwargs)
            evt_end[i].record()
            cpu_ms += (time.perf_counter_ns() - cpu_begin) / 1e6
        torch.cuda.synchronize()
    device_ms = sum(begin.elapsed_time(end) for begin, end in zip(evt_begin, evt_end))
    return device_ms / MEASURE_ITERS, cpu_ms / MEASURE_ITERS

# Sanity check: dst has not been filled with src's values yet
for s, d in zip(src, dst):
    assert not torch.allclose(s, d)
# multiplier = 2: a copy both reads src and writes dst, so effective
# memory traffic is twice the payload
fn, multiplier = functools.partial(torch._foreach_copy_, dst, src), 2
for _ in range(100):
    fn()  # warmup
# Verify the foreach copy actually ran
for s, d in zip(src, dst):
    assert torch.allclose(s, d)
device_ms, cpu_ms = benchmark_time(fn)
print(f"device ms: {device_ms:.03f}, cpu ms: {cpu_ms:.03f}")
print(f"memory bandwidth: {gb / device_ms * 1000 * multiplier:.03f} GB/s")