Created January 2, 2024 08:32
Demonstrates the 2x FLOPS available on consumer ("gamer") GPUs when FP16 accumulators are used.
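The 2x speedup comes from keeping the matmul's running sum in fp16 instead of fp32, which costs precision. The sketch below is a scalar emulation and does not model how tensor cores actually order or round their accumulation. It rounds a dot product's running sum to fp16 after every add, using the stdlib `struct` half-precision format `'e'`. The names `to_fp16` and `dot_fp16_inputs` are hypothetical helpers.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE-754 half-precision value."""
    # struct's 'e' format packs/unpacks binary16 values
    return struct.unpack('e', struct.pack('e', x))[0]

def dot_fp16_inputs(a, b, accumulate_in_fp16: bool) -> float:
    """Dot product of fp16-rounded inputs; optionally keep the running sum in fp16."""
    acc = 0.0
    for x, y in zip(a, b):
        acc += to_fp16(x) * to_fp16(y)  # product of two fp16 values is exact in a double
        if accumulate_in_fp16:
            acc = to_fp16(acc)          # emulate an fp16 accumulator
    return acc

n = 4096                 # same inner dimension as the largest benchmark size
a = [1.0] * n
b = [0.1] * n            # 0.1 rounds to 0.0999755859375 in fp16
exact = dot_fp16_inputs(a, b, accumulate_in_fp16=False)
fp16 = dot_fp16_inputs(a, b, accumulate_in_fp16=True)
print(exact, fp16)
```

With wide accumulation the sum is 409.5. The fp16 running sum stalls at 256, because above 256 the fp16 spacing is 0.25 and each ~0.1 addend falls below half of that spacing, so it rounds away.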
from pathlib import Path
savepath = Path('mm')
import torch
import triton
import triton.testing
from triton.ops.matmul import matmul as triton_matmul
# nightly (also needs `from triton.ops.matmul import _matmul`):
#matmul = lambda a,b: _matmul.forward(a,b, acc_dtype=torch.float16, allow_tf32=True, output_dtype=torch.float16)
matmul = lambda a,b: triton_matmul(a,b, torch.float16) # stable: third arg makes the kernel accumulate in fp16

for size in (512,4096):
    a,b = [torch.randn((size,size), device='cuda', dtype=torch.float16) for _ in 'ab']

# stolen from
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],
        x_vals=[128 * i for i in range(2, 33)],
        line_arg='provider',  # benchmark() argument that selects cuBLAS vs Triton
        line_vals=['cublas', 'triton'],
        line_names=["cuBLAS", "Triton"],
        styles=[('green', '-'), ('blue', '-')],
        plot_name="Matmul perf (3090)", # NOTE: This will *not* add a graph plot header unless you edit triton/
        args={},  # no extra fixed arguments
    ))
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    f = torch.matmul if provider == 'cublas' else matmul # pytorch defers to cuBLAS.
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: f(a, b), quantiles=quantiles)
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)  # TFLOPS
    t = perf(ms), perf(max_ms), perf(min_ms)  # (median, low, high) throughput
    print(provider, t)
    return t

benchmark.run(print_data=True, save_path=str(savepath))
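The `perf` lambda turns a time in milliseconds into TFLOPS. It counts 2·M·N·K floating-point operations: M·N output elements, each an inner product of length K, with one multiply and one add per term. A standalone version follows, with `matmul_tflops` as a hypothetical name. It uses a made-up 1 ms timing for illustration, not a measured result.

```python
def matmul_tflops(M: int, N: int, K: int, ms: float) -> float:
    """Throughput in TFLOPS of an (M x K) @ (K x N) matmul that took `ms` milliseconds."""
    flops = 2 * M * N * K           # one multiply + one add per inner-product term
    seconds = ms * 1e-3             # milliseconds -> seconds
    return flops * 1e-12 / seconds  # FLOP/s -> TFLOP/s

# Hypothetical timing: a 4096 x 4096 x 4096 fp16 matmul finishing in 1 ms
print(matmul_tflops(4096, 4096, 4096, 1.0))  # about 137.4 TFLOPS
```

Throughput is inversely proportional to time. That is why the benchmark converts `max_ms` to the *low* end of its TFLOPS range and `min_ms` to the *high* end.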
To run this file, you should either:
* remove the `prune_configs_by` key in triton/ops/ (it prunes, early on, configs that turn out to be useful at larger block sizes), or
* import the matmul impl from the matmul tutorial in instead.
152334H commented Jan 2, 2024

It is currently impractical to drop triton.ops.matmul directly into existing torch code, for several reasons:

  1. The autotuner is extremely slow on first launch.
  2. It breaks internal torch optimizations that fuse layers etc. (these presumably rely on the exact `@` op being used).
  3. The autotuner re-tunes (recompiles, or something similar?) every time an input of a different shape is used. This is particularly bad for fine-tuning without a fixed global padding length.
  4. It won't work for anything other than fp16 (bf16 is a no-go).
  5. It will most likely be too imprecise for training out of the box. I'm less certain about this part, considering fp8...
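Points 1 and 3 follow from how a shape-keyed autotuner caches its choice. As far as I can tell, triton.ops.matmul's autotuner is keyed on `M`, `N` and `K`. The toy model below is not Triton's implementation. `ShapeKeyedAutotuner` and `fake_bench` are hypothetical, and `fake_bench` is a made-up cost model standing in for real GPU timing runs. Each unseen (M, N, K) triggers a full sweep over candidate configs, so dynamic padding repeats that sweep for every new sequence length.

```python
from typing import Callable, Dict, List, Tuple

Shape = Tuple[int, int, int]

class ShapeKeyedAutotuner:
    """Toy autotuner: benchmarks every candidate config once per new (M, N, K) key."""

    def __init__(self, configs: List[str], bench: Callable[[str, Shape], float]):
        self.configs = configs
        self.bench = bench                 # returns a (fake) time for (config, shape)
        self.cache: Dict[Shape, str] = {}  # best config found per shape
        self.tuning_runs = 0

    def pick(self, shape: Shape) -> str:
        if shape not in self.cache:        # unseen shape -> full tuning sweep
            self.tuning_runs += 1
            self.cache[shape] = min(self.configs, key=lambda c: self.bench(c, shape))
        return self.cache[shape]

def fake_bench(config: str, shape: Shape) -> float:
    # Hypothetical cost model; a real autotuner would launch and time the kernel.
    block = int(config)
    M, N, K = shape
    return (M * N * K) / block + block * 1000

seq_lens = [117, 243, 117, 511, 98]  # per-batch lengths under dynamic padding

dynamic = ShapeKeyedAutotuner(["32", "64", "128"], fake_bench)
for s in seq_lens:
    dynamic.pick((8 * s, 4096, 4096))      # M = batch * seq_len changes per batch

padded = ShapeKeyedAutotuner(["32", "64", "128"], fake_bench)
for s in seq_lens:
    padded.pick((8 * 512, 4096, 4096))     # everything padded to one global length

print(dynamic.tuning_runs, padded.tuning_runs)
```

Four distinct sequence lengths cost four tuning sweeps. A fixed global padding length pays for only one.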
