
@152334H
Created January 2, 2024 08:32
Demonstrating the 2x FLOPS on gamer GPUs when FP16 accumulators are used.
from pathlib import Path
savepath = Path('mm')
savepath.mkdir(exist_ok=True)
import torch
import triton
from triton.ops.matmul import matmul as triton_matmul
#from triton.ops.matmul import _matmul
#matmul = lambda a,b: _matmul.forward(a,b, acc_dtype=torch.float16, allow_tf32=True, output_dtype=torch.float16) # nightly (needs the _matmul import above)
matmul = lambda a,b: triton_matmul(a, b, torch.float16) # stable: third arg sets the fp16 accumulator dtype
torch.manual_seed(0)
for size in (512, 4096):
    a, b = [torch.randn((size, size), device='cuda', dtype=torch.float16) for _ in 'ab']
    print(f"triton_output={matmul(a, b)}")
    print(f"torch_output={torch.matmul(a, b)}")
# stolen from https://github.com/openai/triton/blob/main/python/tutorials/03-matrix-multiplication.py
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],
        x_vals=[128 * i for i in range(2, 33)],
        line_arg='provider',
        line_vals=['cublas', 'triton'],
        line_names=["cuBLAS", "Triton"],
        styles=[('green', '-'), ('blue', '-')],
        ylabel="TFLOPS",
        plot_name="Matmul perf (3090)", # NOTE: This will *not* add a graph plot header unless you edit triton/testing.py
        args={},
    ))
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    f = torch.matmul if provider == 'cublas' else matmul # pytorch defers to cuBLAS here.
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: f(a, b), quantiles=quantiles)
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) # TFLOPS: a GEMM does 2*M*N*K FLOPs
    t = perf(ms), perf(max_ms), perf(min_ms) # (median, min, max) TFLOPS
    print(provider, t)
    return t
benchmark.run(show_plots=True, print_data=True, save_path=str(savepath))
'''
To run this file, you should either:
* remove the `prune_configs_by` key in triton/ops/matmul.py (it prunes, early on, kernel configs that only pay off at larger block sizes), or
* import the matmul impl from the matmul tutorial at https://github.com/openai/triton/blob/main/python/tutorials/03-matrix-multiplication.py instead; see the sketch below.
'''
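
For the second option, a minimal sketch (assuming the tutorial file has been saved locally as tutorial_matmul.py, a hypothetical filename; note the tutorial kernel accumulates in fp32, so reproducing the 2x effect also requires switching its accumulator to tl.float16):

# sketch only: swap the triton.ops import for the tutorial implementation.
# `tutorial_matmul.py` is an assumed local copy of the linked tutorial file,
# with its `accumulator = tl.zeros(..., dtype=tl.float32)` changed to tl.float16.
from tutorial_matmul import matmul as tutorial_matmul
matmul = lambda a, b: tutorial_matmul(a, b)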

152334H commented Jan 2, 2024

[image: Matmul perf (3090) benchmark plot]


152334H commented Jan 2, 2024

All tests were done on torch==2.1.1+cu121, triton==2.1.0. The GPU used is a Zotac 3090 under ~no load (only some minimal memory allocated to the window manager).


152334H commented Jan 2, 2024

It is currently impractical to just directly import triton.ops.matmul into existing torch code, for a variety of reasons, including:

  1. the autotuner is extremely slow on launch
  2. it breaks internal torch optimizations (which presumably rely on the exact @ op used) for fusing layers etc.
  3. the autotuner recompiles (or something?) every time a differently-shaped input is used. This is particularly bad for fine-tuning without a global fixed padding length; see the padding sketch after this list
  4. it will not work for anything other than fp16 (bf16 is a no-go)
  5. it will most likely be too imprecise for training out-of-the-box; see the error check after this list. Less certain about this part, considering fp8...
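
On point 3, one possible mitigation, as a rough sketch (not from the gist; padded_matmul and the bucket size of 128 are illustrative): zero-pad the shared K dimension up to a bucket multiple, so repeated calls only ever present a few distinct shapes to the autotuner:

import torch.nn.functional as F

def padded_matmul(a, b, bucket=128):
    # zero rows/columns contribute nothing to the product, so padding K is exact
    K = a.shape[-1]
    pad = (-K) % bucket
    if pad:
        a = F.pad(a, (0, pad))        # (M, K) -> (M, K+pad)
        b = F.pad(b, (0, 0, 0, pad))  # (K, N) -> (K+pad, N)
    return matmul(a, b)               # the gist's fp16-accumulate matmul

And on point 5, the precision loss can at least be measured empirically against an fp32-accumulated reference (again just a sketch, reusing the gist's matmul):

a = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
b = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
ref = (a.float() @ b.float()).half()           # fp32-accumulated reference
print((matmul(a, b) - ref).abs().max())        # fp16-accumulator error
print((torch.matmul(a, b) - ref).abs().max())  # cuBLAS error for comparison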
