Skip to content

Instantly share code, notes, and snippets.

View jagadish-amd's full-sized avatar

Jagadish Krishnamoorthy jagadish-amd

View GitHub Profile
@jagadish-amd
jagadish-amd / gist:d057cfb4c8167d8ae2449a59c460d3a3
Created May 15, 2026 07:32
TorchAO patch implementing swizzling for MX-FP4 on gfx950
diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py
index 01d4577b4..e0f058a6a 100644
--- a/benchmarks/float8/bench_matmul.py
+++ b/benchmarks/float8/bench_matmul.py
@@ -138,13 +138,9 @@ def run(
scale_a = to_blocked(scale_a)
scale_b = to_blocked(scale_b)
elif recipe == "mxfp4_cutlass":
- # Use the blockwise scales from to_mx
- if is_ROCM():
@jagadish-amd
jagadish-amd / gist:050687591d89524edab3ca30f7b74174
Created May 15, 2026 07:27
PyTorch patch implementing swizzling for MX-FP4 on gfx950
diff --git a/aten/src/ATen/BlasBackend.h b/aten/src/ATen/BlasBackend.h
index 03b00cc2156..ebcb591eddc 100644
--- a/aten/src/ATen/BlasBackend.h
+++ b/aten/src/ATen/BlasBackend.h
@@ -37,6 +37,8 @@ enum class ScalingType : std::uint8_t {
BlockWise1x32, // fp8_e8m0fnu scales
BlockWise1x128, // fp32 scales
BlockWise128x128, // fp32 scales
+ // ROCm gfx950: hipBLASLt Block_32_UE8M0_32_8_EXT pre-swizzled E8M0 scales (see HIPBLASLT_MATMUL_MATRIX_SCALE_BLK32_UE8M0_32_8_EXT)
+ BlockWiseBlk32Ue8m0_32_8_EXT,
fast_accum name M K N ref_pct_top_peak pct_top_peak ref_time_s time_s fp8_speedup
0 True 0 1024 1024 1024 0.015986 0.010914 0.000058 0.000043 1.365439
1 True 1 1536 1536 1536 0.047957 0.018315 0.000066 0.000086 0.763805
2 True 2 2048 2048 2048 0.075174 0.036426 0.000099 0.000103 0.969102
3 True 3 3072 3072 3072 0.098859 0.073937 0.000255 0.000170 1.495804
4 True 4 4096 4096 4096 0.127724 0.104420 0.000468 0.000286 1.635096
5 True 5 6144 6144 6144 0.118588 0.087502 0.001701 0.001152 1.475736
6 True 6 8192 8192 8192 0.187609 0.071746 0.002548 0.003332 0.764848
7 True 7 12288 12288 12288 0.153743 0.056482 0.010494 0.014283 0.734754
8 True 8 16384 16384 16384