Skip to content

Instantly share code, notes, and snippets.

@Chillee
Last active April 8, 2024 04:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Chillee/abc38703f88fcb64683b6ccb0ae9d8ba to your computer and use it in GitHub Desktop.
What Shapes Do Matrix Multiplications Like?
# Time torch.mm for three near-2048 shapes, shrinking a different
# dimension to 2047 in each case to expose alignment/tiling effects.
import torch
from triton.testing import do_bench

torch.set_default_device('cuda')

shapes = [(2047, 2048, 2048), (2048, 2047, 2048), (2048, 2048, 2047)]
for m, k, n in shapes:
    lhs = torch.randn(m, k, dtype=torch.bfloat16)
    rhs = torch.randn(k, n, dtype=torch.bfloat16)
    print(f"M={m}, K={k}, N={n}")
    # do_bench reports the measured runtime of the matmul.
    print(do_bench(lambda: torch.mm(lhs, rhs)))
# Q2: run a single zero-filled matmul and inspect its wave occupancy with:
#   ncu --metrics launch__waves_per_multiprocessor python Q2.py
import torch
from triton.testing import do_bench

torch.set_default_device('cuda')

m = 108 * 256  # presumably 108 = SM count of the target GPU (A100) — TODO confirm
n = 3486
k = 4679

lhs = torch.zeros(m, k, dtype=torch.bfloat16)
# Allocate as (N, K) then transpose, so rhs is a non-contiguous (K, N) view.
rhs = torch.zeros(n, k, dtype=torch.bfloat16).t()

print(f"M={m}, K={k}, N={n}")
torch.mm(lhs, rhs)
# Same sweep as above, but the left operand is created as (K, M) and then
# transposed, so torch.mm sees a non-contiguous (M, K) view of A.
import torch
from triton.testing import do_bench

torch.set_default_device('cuda')

for m, k, n in [(2047, 2048, 2048), (2048, 2047, 2048), (2048, 2048, 2047)]:
    lhs = torch.randn(k, m, dtype=torch.bfloat16).t()
    rhs = torch.randn(k, n, dtype=torch.bfloat16)
    print(f"M={m}, K={k}, N={n}")
    print(do_bench(lambda: torch.mm(lhs, rhs)))
@Chillee
Copy link
Author

Chillee commented Apr 8, 2024

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment