Skip to content

Instantly share code, notes, and snippets.

@georg-bn
Created May 1, 2024 08:36
Show Gist options
  • Save georg-bn/438076758c30810e6c82dcef5a76a750 to your computer and use it in GitHub Desktop.
Save georg-bn/438076758c30810e6c82dcef5a76a750 to your computer and use it in GitHub Desktop.
Profiling matrix multiplication with small matrices in PyTorch
import torch
from torch.profiler import profile, record_function, ProfilerActivity
def fast_bmm(a, b):
    """Multiply matrices via a broadcast elementwise product and a reduction.

    For ``a`` of shape ``(..., I, K)`` and ``b`` of shape ``(..., K, J)`` the
    result has shape ``(..., I, J)`` and holds ``sum_k a[..., i, k] * b[..., k, j]``,
    i.e. the same values a (batched) matrix product would give.

    Note that the full ``(..., I, K, J)`` product is materialised before the
    reduction.
    """
    lhs = a[..., None]      # (..., I, K) -> (..., I, K, 1)
    rhs = b.unsqueeze(-3)   # (..., K, J) -> (..., 1, K, J)
    outer = lhs * rhs       # broadcasts to (..., I, K, J)
    return outer.sum(dim=-2)  # contract over K
# torch.compile-wrapped counterpart of fast_bmm, created once at module level.
# run_comparison profiles it next to the eager fast_bmm and torch.bmm.
fast_bmm_compiled = torch.compile(fast_bmm)
def run_comparison(batch_size, matrix_dim):
    """Profile torch.bmm, fast_bmm and fast_bmm_compiled on one problem size.

    Random inputs of shape ``[batch_size, matrix_dim, matrix_dim]`` are created
    and moved to the GPU with ``.cuda()``, so a CUDA device is required.

    Args:
        batch_size: Number of matrices in the batch.
        matrix_dim: Side length of each square matrix.

    Returns:
        The ``torch.profiler.profile`` object, containing one ``record_function``
        range for each of the labels ``"fast bmm"``, ``"compiled fast bmm"``
        and ``"bmm"``.

    Raises:
        AssertionError: If the compiled implementation disagrees with
            ``torch.bmm`` beyond ``atol=1e-4`` / ``rtol=1e-4``.
    """
    shape = [batch_size, matrix_dim, matrix_dim]

    def _random_cuda():
        # Sampled on the CPU first, then copied to the GPU.
        return torch.randn(shape).cuda()

    # Run both fast_bmm variants once on throwaway inputs before profiling
    # (presumably to keep one-time setup/compilation out of the measurement).
    fast_bmm_compiled(_random_cuda(), _random_cuda())
    fast_bmm(_random_cuda(), _random_cuda())

    mat0 = _random_cuda()
    mat1 = _random_cuda()

    # Sanity check: the compiled kernel must agree with the reference bmm.
    assert torch.allclose(
        torch.bmm(mat0, mat1), fast_bmm_compiled(mat0, mat1), atol=1e-4, rtol=1e-4
    )

    # Order matters for the recorded trace: eager, compiled, then torch.bmm.
    candidates = (
        ("fast bmm", fast_bmm),
        ("compiled fast bmm", fast_bmm_compiled),
        ("bmm", torch.bmm),
    )
    traced = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
    with profile(activities=traced) as prof:
        torch.bmm(mat0, mat1)  # unlabelled extra call; original author: "warm up?"
        for label, fn in candidates:
            with record_function(label):
                fn(mat0, mat1)
    return prof
if __name__ == "__main__":
    # Sweep the matrix dimension at a fixed batch size, collect the total
    # device time of each labelled profiler range, and plot the three curves
    # on a log-scaled y axis into "profile.png".

    # Resolve the plotting imports first so a missing matplotlib fails
    # immediately rather than after the whole GPU sweep has finished.
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    B = 2**16
    dimensions = range(2, 20)
    results = {"bmm": [], "fast bmm": [], "compiled fast bmm": []}

    for D in dimensions:
        prof = run_comparison(B, D)

        totals = {}
        for struct in prof.key_averages():
            if struct.key not in results:
                continue
            # Newer PyTorch releases expose the accumulated device time as
            # `device_time_total` and deprecate `cuda_time_total`; older
            # releases only provide the latter, so fall back to it.
            device_time = getattr(struct, "device_time_total", None)
            if device_time is None:
                device_time = struct.cuda_time_total
            totals[struct.key] = device_time

        # Append exactly one value per label for every D so each series stays
        # aligned with `dimensions`; a label missing from the profile becomes
        # NaN (a gap in the plot) instead of silently shifting later points.
        for label, series in results.items():
            series.append(totals.get(label, float("nan")))

    plt.semilogy(
        dimensions, results["bmm"],
        dimensions, results["fast bmm"],
        dimensions, results["compiled fast bmm"],
    )
    plt.legend(["torch.bmm", "fast_bmm", "torch.compile(fast_bmm)"])
    plt.title(f"Profiling batched square matrix multiplication [{B}, D, D] @ [{B}, D, D]")
    plt.xlabel("Dimension (D)")
    plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
    plt.ylabel("Total CUDA time [us]")
    plt.savefig("profile.png", dpi=200, bbox_inches="tight")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment