Created
January 2, 2024 08:32
-
-
Save 152334H/181357c847830b1bd4f33a18aa205e08 to your computer and use it in GitHub Desktop.
Demonstrating the 2x FLOPs in gamer GPUs when FP16 accumulators are used.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path

# Directory where triton's perf_report will save plots/CSVs (see benchmark.run below).
savepath = Path('mm')
savepath.mkdir(exist_ok=True)

import torch
import triton
from triton.ops.matmul import matmul as triton_matmul

# Passing torch.float16 as the accumulator dtype is the whole point of this
# demo: FP16 accumulation is what unlocks the 2x FLOPs on gamer GPUs, whereas
# FP32 accumulation runs at half rate on those cards.
#matmul = lambda a,b: _matmul.forward(a,b, acc_dtype=torch.float16, allow_tf32=True, output_dtype=torch.float16) # nightly
matmul = lambda a, b: triton_matmul(a, b, torch.float16)  # stable

torch.manual_seed(0)  # reproducible random inputs

# Sanity check: print the triton (fp16-accumulated) result next to the
# cuBLAS (torch.matmul) result at a small and a large size, so numerical
# drift from fp16 accumulation can be eyeballed before benchmarking.
for size in (512, 4096):
    a, b = [torch.randn((size, size), device='cuda', dtype=torch.float16) for _ in 'ab']
    print(f"triton_output={matmul(a,b)}")
    print(f"torch_output={torch.matmul(a,b)}")
# stolen from https://github.com/openai/triton/blob/main/python/tutorials/03-matrix-multiplication.py
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],
        x_vals=[128 * i for i in range(2, 33)],  # square problems, 256..4096 in steps of 128
        line_arg='provider',
        line_vals=['cublas', 'triton'],
        line_names=["cuBLAS", "Triton"],
        styles=[('green', '-'), ('blue', '-')],
        ylabel="TFLOPS",
        plot_name="Matmul perf (3090)", # NOTE: This will *not* add a graph plot header unless you edit triton/testing.py
        args={},
    ))
def benchmark(M, N, K, provider):
    """Time one (M, N, K) fp16 matmul for *provider* ('cublas' or 'triton').

    Returns a (median, low, high) triple of throughput in TFLOPS, the shape
    triton.testing.perf_report expects for its error bars.
    """
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    f = torch.matmul if provider == 'cublas' else matmul # pytorch defers to cuBLAS.
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: f(a, b), quantiles=quantiles)
    # A matmul does 2*M*N*K flops; 1e-12 converts to Tflops, 1e-3 converts ms to s.
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    # max_ms (slowest run) maps to the LOWEST TFLOPS, hence the min/max swap.
    t = perf(ms), perf(max_ms), perf(min_ms)
    print(provider, t)
    return t

benchmark.run(show_plots=True, print_data=True, save_path=str(savepath))
# Usage note kept as a module-level string so it travels with the script.
'''
To run this file, you should either:
* remove the `prune_configs_by` key in triton/ops/matmul.py (which prunes kernels that are useful in larger blocks early on)
* import the matmul impl from the matmul tutorial in https://github.com/openai/triton/blob/main/python/tutorials/03-matrix-multiplication.py instead.
'''
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
It is currently impractical to directly import triton.ops.matmul into existing torch code for a variety of reasons, including the inability (depending on the op used) to fuse layers, etc.