Skip to content

Instantly share code, notes, and snippets.

@Chillee
Last active July 27, 2023 18:18
  • Star 27 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
Star You must be signed in to star a gist
Embed
What would you like to do?
PT 2.0 Benchmarks
import torch
import torch._inductor.config
import time
# Turn off CUDA-graph capture for Inductor's Triton kernels for this benchmark.
torch._inductor.config.triton.cudagraphs = False
# Request the "high" float32 matmul precision mode (see the PyTorch docs for
# the exact reduced-precision internals this permits on a given GPU).
torch.set_float32_matmul_precision("high")
def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    """Time ``f`` and report the mean wall-clock microseconds per call.

    Args:
        f: Zero-argument callable to benchmark.
        name: Optional label. When given, the result is the string
            ``"<name>: <us>us"`` and ``name`` is also the stem of the profiler
            trace file. When ``None``, the raw float is the result.
        iters: Number of timed calls; must be non-zero.
        warmup: Number of untimed calls made before measuring.
        display: Print the result when true.
        profile: When true, run ``f`` once under ``torch.profiler`` after the
            warm-up and export a Chrome trace to ``<name>.json``
            (``trace.json`` when ``name`` is ``None``).

    Returns:
        The mean microseconds per call as a float when ``name`` is ``None``,
        otherwise the labelled string.
    """
    for _ in range(warmup):
        f()

    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")

    # CUDA work is queued asynchronously, so the device must be drained before
    # starting and before stopping the clock. Skip the sync on machines without
    # CUDA, where it would raise, so CPU-only callables can be timed too.
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    begin = time.perf_counter()
    for _ in range(iters):
        f()
    if has_cuda:
        torch.cuda.synchronize()
    us_per_iter = (time.perf_counter() - begin) * 1e6 / iters

    if name is None:
        res = us_per_iter
    else:
        res = f"{name}: {us_per_iter}us"
    if display:
        print(res)
    return res
def f1(a, b, c, d):
    """Return ``(relu(a) * tanh(b) + cos(c + 2)) * d``, computed elementwise.

    A small chain of pointwise ops used as the workload for comparing eager
    execution against ``torch.compile``.
    """
    gated = a.relu().mul(b.tanh())
    shifted_cos = c.add(2).cos()
    return gated.add(shifted_cos).mul(d)
# Four random CUDA vectors of 2**24 elements, one per argument of f1.
inp = [torch.randn(1 << 24, device="cuda") for _ in range(4)]
f = f1
nf = torch.compile(f)

# Time the eager function first, then its compiled counterpart.
for label, fn in (("eager", f), ("PT 2.0", nf)):
    bench(lambda: fn(*inp), name=label)
import torch
from torch.nn import *
# Request the "high" float32 matmul precision mode for every benchmark below.
torch.set_float32_matmul_precision("high")
def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    """Time ``f`` and report the mean wall-clock microseconds per call.

    Args:
        f: Zero-argument callable to benchmark.
        name: Optional label. When given, the result is the string
            ``"<name>: <us>us"`` (two decimals) and ``name`` is also the stem
            of the profiler trace file. When ``None``, the raw float is the
            result.
        iters: Number of timed calls; must be non-zero.
        warmup: Number of untimed calls made before measuring.
        display: Print the result when true.
        profile: When true, run ``f`` once under ``torch.profiler`` after the
            warm-up and export a Chrome trace to ``<name>.json``
            (``trace.json`` when ``name`` is ``None``).

    Returns:
        The mean microseconds per call as a float when ``name`` is ``None``,
        otherwise the labelled string.
    """
    import time

    for _ in range(warmup):
        f()

    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")

    # CUDA work is queued asynchronously, so the device must be drained before
    # starting and before stopping the clock. Skip the sync on machines without
    # CUDA, where it would raise, so CPU-only callables can be timed too.
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    begin = time.perf_counter()
    for _ in range(iters):
        f()
    if has_cuda:
        torch.cuda.synchronize()
    us_per_iter = (time.perf_counter() - begin) * 1e6 / iters

    if name is None:
        res = us_per_iter
    else:
        res = f"{name}: {us_per_iter:.2f}us"
    if display:
        print(res)
    return res
import torchvision.models as models
# Build a ResNet-18, switch it to eval mode and move it to the GPU.
mod = models.resnet18()
mod.eval()
mod.cuda()
opt_mod = torch.compile(mod, mode="reduce-overhead")
# A single 3x224x224 input (batch size 1), created on the host then copied over.
inp = torch.randn(1, 3, 224, 224).cuda()

# Inference only: no autograd bookkeeping while timing.
with torch.no_grad():
    # Eager: 1938.18us
    bench(lambda: mod(inp), name="Eager")
    # torch.compile (default): 953.96us
    # torch.compile (reduce-overhead): 744.02us
    bench(lambda: opt_mod(inp), name="torch.compile (reduce-overhead)")
import torch
from triton.testing import do_bench
def get_flops(N, get_kernels=False):
    """Benchmark an ``N x N`` float16 CUDA matmul and print its throughput.

    Args:
        N: Side length of the two square operands.
        get_kernels: When true, first profile a single matmul and print the
            name of every profiler event whose name contains ``gemm``,
            ``triton`` or ``gemv``.

    Prints ``"<N>: <x>TF/s"`` and returns ``None``.
    """
    A = torch.randn(N, N, device='cuda', dtype=torch.float16)
    B = torch.randn(N, N, device='cuda', dtype=torch.float16)

    def f():
        return torch.mm(A, B)

    if get_kernels:
        with torch.profiler.profile() as prof:
            f()
        # Only the kernel names are reported here; the timing always comes
        # from do_bench below.
        for e in prof.events():
            if any(tag in e.name for tag in ("gemm", "triton", "gemv")):
                print(f"{N}: {e.name}")

    # The result is treated as milliseconds per call (hence the 1e3 factor).
    ms_per_iter = do_bench(f)
    iters_per_second = 1e3 / ms_per_iter
    # One multiply and one add per (row, inner, column) triple.
    flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
    flops_achieved = iters_per_second * flops / 1e12
    print(f"{N}: {flops_achieved:.2f}TF/s")
# Report matmul throughput for every square size from 1 through 4095.
for N in range(1, 2**12):
    get_flops(N)
import torch
import torch._inductor.config

# Request the "high" float32 matmul precision mode for the benchmarks below.
torch.set_float32_matmul_precision("high")
# Switch on Inductor's debug flag (presumably to dump generated code and
# other compile artifacts; confirm against the Inductor config docs).
torch._inductor.config.debug = True
def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    """Time ``f`` and report the mean wall-clock microseconds per call.

    Args:
        f: Zero-argument callable to benchmark.
        name: Optional label. When given, the result is the string
            ``"<name>: <us>us"`` (three decimals) and ``name`` is also the
            stem of the profiler trace file. When ``None``, the raw float is
            the result.
        iters: Number of timed calls; must be non-zero.
        warmup: Number of untimed calls made before measuring.
        display: Print the result when true.
        profile: When true, run ``f`` once under ``torch.profiler`` after the
            warm-up and export a Chrome trace to ``<name>.json``
            (``trace.json`` when ``name`` is ``None``).

    Returns:
        The mean microseconds per call as a float when ``name`` is ``None``,
        otherwise the labelled string.
    """
    import time

    for _ in range(warmup):
        f()

    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")

    # CUDA work is queued asynchronously, so the device must be drained before
    # starting and before stopping the clock. Skip the sync on machines without
    # CUDA, where it would raise, so CPU-only callables can be timed too.
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    begin = time.perf_counter()
    for _ in range(iters):
        f()
    if has_cuda:
        torch.cuda.synchronize()
    us_per_iter = (time.perf_counter() - begin) * 1e6 / iters

    if name is None:
        res = us_per_iter
    else:
        res = f"{name}: {us_per_iter:.3f}us"
    if display:
        print(res)
    return res
def get_bandwidth(name, f):
    """Benchmark ``f`` and print the memory bandwidth it achieves in GB/s.

    The traffic per call is taken to be ``N**2 * 4 * 3`` bytes, where ``N`` is
    the module-level size read at call time. That matches an op touching three
    ``N x N`` tensors of 4-byte elements (e.g. two inputs read, one output
    written); callables with a different access pattern will be misreported.

    Args:
        name: Label printed in front of the result.
        f: Zero-argument callable to time with ``bench``.
    """
    # bench returns microseconds per call when no name is passed.
    iters_per_second = 1e6 / bench(f, display=False)
    bytes_accessed = N**2 * 4 * 3
    # Bytes per second scaled to 1e9 is a rate, so label it GB/s rather than GB.
    print(f"{name}: {iters_per_second * bytes_accessed / 1e9:.2f}GB/s")


# Side length of the square tensors used by the bandwidth benchmarks.
N = 2**14
def f(a, b):
    """Return the sum of ``a`` and ``b`` (the contiguous-add workload)."""
    total = a + b
    return total
A = torch.randn(N, N, device='cuda')
B = torch.randn(N, N, device='cuda')
# Previously recorded result: eager: 1389.84 GB/s
get_bandwidth("eager", lambda: f(A, B))
# Build the compiled wrapper once, outside the timed callable, so that the
# cost of calling torch.compile is not charged to every measured iteration.
compiled_f = torch.compile(f)
# Previously recorded result: torch.compile: 1388.19 GB/s
get_bandwidth("torch.compile", lambda: compiled_f(A, B))
def f2(a, b):
    """Return ``a`` plus the transpose of ``b`` (the strided-add workload)."""
    b_transposed = b.t()
    return a + b_transposed
A = torch.randn(N, N, device='cuda')
B = torch.randn(N, N, device='cuda')
# Previously recorded result: eager: 904.01 GB/s
get_bandwidth("eager", lambda: f2(A, B))
# Build the compiled wrapper once, outside the timed callable, so that the
# cost of calling torch.compile is not charged to every measured iteration.
compiled_f2 = torch.compile(f2)
# Previously recorded result: torch.compile: 1334.89 GB/s
get_bandwidth("torch.compile", lambda: compiled_f2(A, B))
import torch
from triton.testing import do_bench
def get_flops(N, get_kernels=False):
    """Benchmark an ``N x N`` float16 CUDA matmul and print its throughput.

    Args:
        N: Side length of the two square operands.
        get_kernels: When true, first profile a single matmul and print the
            name of every profiler event whose name contains ``gemm``,
            ``triton`` or ``gemv``.

    Prints ``"<N>: <x>TF/s"`` and returns ``None``.
    """
    A = torch.randn(N, N, device='cuda', dtype=torch.float16)
    B = torch.randn(N, N, device='cuda', dtype=torch.float16)

    def f():
        return torch.mm(A, B)

    if get_kernels:
        with torch.profiler.profile() as prof:
            f()
        # Only the kernel names are reported here; the timing always comes
        # from do_bench below.
        for e in prof.events():
            if any(tag in e.name for tag in ("gemm", "triton", "gemv")):
                print(f"{N}: {e.name}")

    # The result is treated as milliseconds per call (hence the 1e3 factor).
    ms_per_iter = do_bench(f)
    iters_per_second = 1e3 / ms_per_iter
    # One multiply and one add per (row, inner, column) triple.
    flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
    flops_achieved = iters_per_second * flops / 1e12
    print(f"{N}: {flops_achieved:.2f}TF/s")
# Report matmul throughput for every square size from 1 through 4095.
for N in range(1, 2**12):
    get_flops(N)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment