Skip to content

Instantly share code, notes, and snippets.

Forked from malfet/
Created February 16, 2024 00:27
Show Gist options
  • Save stas00/b7e8da20ff999de0eb74f022ccf4fd00 to your computer and use it in GitHub Desktop.
Save stas00/b7e8da20ff999de0eb74f022ccf4fd00 to your computer and use it in GitHub Desktop.
Measure performance difference of `torch.mm` vs `torch.bmm`
# Benchmark relative performance of torch.mm and torch.bmm with single batch
import torch
import time
def benchmark_fn(fn, args, warmup=5, cycles=300, use_kineto=False) -> float:
    """Measure the cost of calling ``fn(*args)``.

    Args:
        fn: Callable to benchmark.
        args: Tuple of positional arguments, unpacked into ``fn``.
        warmup: Number of untimed calls made before the timed loop.
        cycles: Number of timed calls. Must be non-zero, since the elapsed
            time is divided by it.
        use_kineto: When true, run ``fn`` once under ``torch.profiler``
            (CUDA activity only) and return the summed ``cuda_time`` of the
            profiler's key averages instead of wall-clock timing. ``warmup``
            and ``cycles`` are ignored on this path.

    Returns:
        Without ``use_kineto``: the average wall-clock time per call, in
        microseconds. With ``use_kineto``: the sum of ``cuda_time`` over the
        profiler's aggregated events.
    """
    if use_kineto:
        with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as p:
            fn(*args)
        return sum(e.cuda_time for e in p.key_averages())

    # CUDA work is queued asynchronously, so the queue has to be drained
    # before reading the clock; otherwise only launch overhead is measured.
    # The sync is skipped without CUDA so CPU-only callables can be timed too.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    for _ in range(warmup):
        fn(*args)
    sync()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    begin = time.perf_counter()
    for _ in range(cycles):
        fn(*args)
    sync()
    dt = time.perf_counter() - begin

    # Truncate the total to whole microseconds, then average per call.
    dt_us = int(dt * 1000000) / cycles
    return dt_us
if __name__ == "__main__":
    # Report the environment so results can be compared across setups.
    print("torch: ", torch.__version__, " device: ", torch.cuda.get_device_name(0))

    # Problem sizes as (m, n, k): `a` is (m, k) and `b` is (k, n).
    msizes = [(1, 1, 2**x) for x in range(12, 18)]  # dot products of growing length
    msizes += [(2**x, 2**x, 2**x) for x in range(7, 12)]  # square, power-of-two
    msizes += [(2**x+1, 2**x-1, 2**x+1) for x in range(7, 12)]  # just off power-of-two
    # Large m and k with a very narrow n (3, 5, 7).
    msizes += [(2**x+1, 3, 2**x+1) for x in range(12, 17)]
    msizes += [(2**x+1, 5, 2**x+1) for x in range(12, 17)]
    msizes += [(2**x+1, 7, 2**x+1) for x in range(12, 17)]

    print("| Shape | bmm_time | mm_time | slow down (%) |")
    print("| -------------- | --------- | --------- | ------------- |")
    for (m, n, k) in msizes:
        a = torch.rand((m, k), device='cuda')
        b = torch.rand((k, n), device='cuda')
        # bmm needs a leading batch dimension; use a batch of one.
        bmm_time = benchmark_fn(torch.bmm, (a.unsqueeze(0), b.unsqueeze(0)))
        mm_time = benchmark_fn(torch.mm, (a, b))
        shape_str = f"{m}x{n}x{k}"
        print(f"| {shape_str :^14} | {bmm_time :^9.2f} | {mm_time :^9.2f} | {100.0*(bmm_time-mm_time)/mm_time :^13.2f} |")
        # Sanity check: both code paths must compute the same product.
        assert torch.allclose(torch.bmm(a.unsqueeze(0), b.unsqueeze(0)).squeeze(0), torch.mm(a, b))
# Running the above script on an A100 with torch-2.1.1+cu118 produces the following output:
# torch: 2.1.1+cu118 device: NVIDIA A100-SXM4-40GB
# | Shape | bmm_time | mm_time | slow down (%) |
# | -------------- | --------- | --------- | ------------- |
# | 1x1x4096 | 12.38 | 11.96 | 3.48 |
# | 1x1x8192 | 12.26 | 11.84 | 3.55 |
# | 1x1x16384 | 11.81 | 11.66 | 1.29 |
# | 1x1x32768 | 12.00 | 11.81 | 1.61 |
# | 1x1x65536 | 14.82 | 15.05 | -1.48 |
# | 1x1x131072 | 12.02 | 11.77 | 2.15 |
# | 128x128x128 | 9.47 | 9.69 | -2.24 |
# | 256x256x256 | 12.66 | 12.60 | 0.50 |
# | 512x512x512 | 27.34 | 27.31 | 0.10 |
# | 1024x1024x1024 | 129.59 | 129.48 | 0.08 |
# | 2048x2048x2048 | 973.63 | 973.04 | 0.06 |
# | 129x127x129 | 9.56 | 8.97 | 6.62 |
# | 257x255x257 | 12.85 | 12.78 | 0.52 |
# | 513x511x513 | 28.99 | 28.98 | 0.05 |
# | 1025x1023x1025 | 137.92 | 137.76 | 0.11 |
# | 2049x2047x2049 | 982.34 | 982.32 | 0.00 |
# | 4097x3x4097 | 86.94 | 86.91 | 0.03 |
# | 8193x3x8193 | 384.38 | 384.54 | -0.04 |
# | 16385x3x16385 | 1106.25 | 1107.35 | -0.10 |
# | 32769x3x32769 | 4736.79 | 4737.19 | -0.01 |
# | 65537x3x65537 | 17368.65 | 17371.21 | -0.01 |
# | 4097x5x4097 | 87.50 | 87.49 | 0.01 |
# | 8193x5x8193 | 302.27 | 302.29 | -0.00 |
# | 16385x5x16385 | 1107.69 | 1107.65 | 0.00 |
# | 32769x5x32769 | 4743.02 | 4743.13 | -0.00 |
# | 65537x5x65537 | 17393.08 | 17392.32 | 0.00 |
# | 4097x7x4097 | 87.58 | 87.60 | -0.02 |
# | 8193x7x8193 | 302.42 | 302.45 | -0.01 |
# | 16385x7x16385 | 1106.55 | 1107.34 | -0.07 |
# | 32769x7x32769 | 4746.99 | 4746.58 | 0.01 |
# | 65537x7x65537 | 17406.08 | 17424.31 | -0.10 |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment