- https://jaglinux.github.io
- in/jagadish-krishnamoorthy
- @JAGsPOSTs
- https://github.com/jaglinux
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py | |
| index 01d4577b4..e0f058a6a 100644 | |
| --- a/benchmarks/float8/bench_matmul.py | |
| +++ b/benchmarks/float8/bench_matmul.py | |
| @@ -138,13 +138,9 @@ def run( | |
| scale_a = to_blocked(scale_a) | |
| scale_b = to_blocked(scale_b) | |
| elif recipe == "mxfp4_cutlass": | |
| - # Use the blockwise scales from to_mx | |
| - if is_ROCM(): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/aten/src/ATen/BlasBackend.h b/aten/src/ATen/BlasBackend.h | |
| index 03b00cc2156..ebcb591eddc 100644 | |
| --- a/aten/src/ATen/BlasBackend.h | |
| +++ b/aten/src/ATen/BlasBackend.h | |
| @@ -37,6 +37,8 @@ enum class ScalingType : std::uint8_t { | |
| BlockWise1x32, // fp8_e8m0fnu scales | |
| BlockWise1x128, // fp32 scales | |
| BlockWise128x128, // fp32 scales | |
| + // ROCm gfx950: hipBLASLt Block_32_UE8M0_32_8_EXT pre-swizzled E8M0 scales (see HIPBLASLT_MATMUL_MATRIX_SCALE_BLK32_UE8M0_32_8_EXT) | |
| + BlockWiseBlk32Ue8m0_32_8_EXT, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| fast_accum name M K N ref_pct_top_peak pct_top_peak ref_time_s time_s fp8_speedup | |
| 0 True 0 1024 1024 1024 0.015986 0.010914 0.000058 0.000043 1.365439 | |
| 1 True 1 1536 1536 1536 0.047957 0.018315 0.000066 0.000086 0.763805 | |
| 2 True 2 2048 2048 2048 0.075174 0.036426 0.000099 0.000103 0.969102 | |
| 3 True 3 3072 3072 3072 0.098859 0.073937 0.000255 0.000170 1.495804 | |
| 4 True 4 4096 4096 4096 0.127724 0.104420 0.000468 0.000286 1.635096 | |
| 5 True 5 6144 6144 6144 0.118588 0.087502 0.001701 0.001152 1.475736 | |
| 6 True 6 8192 8192 8192 0.187609 0.071746 0.002548 0.003332 0.764848 | |
| 7 True 7 12288 12288 12288 0.153743 0.056482 0.010494 0.014283 0.734754 | |
| 8 True 8 16384 16384 16384 |