// | |
// main.swift | |
// UnifiedGEMMKernel | |
// | |
// Created by Philip Turner on 5/29/24. | |
// | |
import Metal | |
#if os(macOS) | |
import IOKit | |
#endif | |
// Single shader source that supports every hardware architecture, problem | |
// size, and precision. | |
// | |
// ========================================================================== // | |
// Introduction | |
// ========================================================================== // | |
// | |
// There should not be a tradeoff between: | |
// - Reaching theoretical maximum ALU performance (within a factor of 1.01) | |
// - Reaching theoretical minimum compilation latency (provided you have an | |
// engine for caching shader variants) | |
// - Having short, legible, and portable source code | |
// - Supporting every GPU generation and data type | |
// | |
// The existing solutions compromise on at least one of these. For example,
// MPS and Mojo are closed source, meaning the code is not legible. Other ML
// frameworks are ergonomic and easy to use, but sacrifice performance.
// Performance is most often neglected on older chips (M1, M2) or in edge
// cases (when matrix dimensions are not divisible by the block size).
// | |
// Recently, I found a way to access SIMD async copy instructions from the JIT | |
// compiler*. This removes the need to use Xcode 14.2 in the shader compilation
// process. Instead of encoding all arguments in function constants, some of | |
// them can be encoded directly into a JIT-compiled shader source. This | |
// freedom allows for simpler shader code and a simpler interface for running | |
// the kernel. For example, the client no longer needs to allocate threadgroup | |
// memory at runtime. | |
// | |
// > *This and some additional context are organized at:
// > https://gist.github.com/philipturner/0dd518be6705544778474c4e9f8fd68d | |
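//
// As a concrete illustration of that compilation path, here is a minimal,
// hypothetical sketch (not the kernel generated later in this file) of
// injecting arguments directly into a JIT-compiled source. The kernel name
// 'gemm' and the constants are placeholders.
//
// ```swift
// import Metal
//
// // Values that would normally come from a kernel descriptor.
// let blockM = 48, blockN = 48, blockK = 24
//
// // Bake the block dimensions into the source string. The threadgroup
// // allocation is declared inside the kernel, so the client never calls
// // setThreadgroupMemoryLength(_:index:).
// let source = """
// #include <metal_stdlib>
// using namespace metal;
//
// #define BLOCK_M \(blockM)
// #define BLOCK_N \(blockN)
// #define BLOCK_K \(blockK)
//
// kernel void gemm(/* device buffers ... */) {
//   threadgroup float blockA[BLOCK_M * BLOCK_K];
//   threadgroup float blockB[BLOCK_K * BLOCK_N];
//   // ... matrix multiplication ...
// }
// """
//
// let device = MTLCreateSystemDefaultDevice()!
// let library = try! device.makeLibrary(source: source, options: nil)
// let pipeline = try! device.makeComputePipelineState(
//   function: library.makeFunction(name: "gemm")!)
// ```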
// | |
// The code should be robust against worst-case situations. Imagine a PyTorch | |
// workflow where someone explores a hyperparameter space. They try changing | |
// the size of an MLP from 128 neurons, to 129 neurons, 130, etc. Each call | |
// into a matrix multiplication has a different problem size. If a new kernel | |
// was cached/compiled from scratch for each call, the latency would become a | |
// bottleneck. Yet, knowing the problem size beforehand is critical to reaching | |
// maximum performance (and outperforming MPS). The code should meet the | |
// standards of something that replaces MPS across thousands of end-user | |
// workflows. | |
// | |
// The first attempt will not be perfect. I will learn things, then try again
// with new implementations written from scratch, continuing to distill and
// simplify the code and make it more robust. I will compare against
// alternatives (MPS, MLX) with a data-driven, evidence-based approach.
// | |
// ========================================================================== // | |
// Methods | |
// ========================================================================== // | |
// | |
// The following design specifications were drafted for the source file. | |
// | |
// Anything with a fixed number of options, which would otherwise require
// adding control flow blocks to the shader.
// - Device architecture | |
// - Blocking setup, threadgroup memory allocation | |
// - Operand precisions, accumulator precision | |
// Convention: Injected into shader source. | |
// | |
// Remaining parts that must be known ahead of time for reasonable performance. | |
// - Remainder of the integer division: problem size / block size | |
// - Transpose state of operands | |
// Convention: Choice between injection into shader source, or function | |
// constants. | |
// | |
// Parts where inlining a constant into assembly would maximize performance. | |
// - Actual problem size | |
// Convention: Choice between injection into shader source, function constants, | |
// or runtime data structure. | |
// | |
// Draft a kernel that does all of these things. It doesn't need optimal
// performance; it just needs to implement them correctly.
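//
// As a sketch of the second and third conventions, here is how the transpose
// state and problem size could be bound through function constants. The
// constant names and indices are hypothetical, not the ones used later in
// this file.
//
// ```swift
// import Metal
//
// let shaderSource = """
// #include <metal_stdlib>
// using namespace metal;
//
// constant bool transpose_state_A [[function_constant(0)]];
// constant bool transpose_state_B [[function_constant(1)]];
// constant uint matrix_dimension_M [[function_constant(2)]];
//
// kernel void gemm(/* ... */) { /* ... */ }
// """
//
// let device = MTLCreateSystemDefaultDevice()!
// let library = try! device.makeLibrary(source: shaderSource, options: nil)
//
// let constants = MTLFunctionConstantValues()
// var transposeA = false
// var transposeB = true
// var matrixDimensionM: UInt32 = 1_537
// constants.setConstantValue(&transposeA, type: .bool, index: 0)
// constants.setConstantValue(&transposeB, type: .bool, index: 1)
// constants.setConstantValue(&matrixDimensionM, type: .uint, index: 2)
// let function = try! library.makeFunction(
//   name: "gemm", constantValues: constants)
// ```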
// | |
// ========================================================================== // | |
// Lab Notes | |
// ========================================================================== // | |
// | |
// First issue: | |
// - The optimized BF16 decoding function from a previous experiment did not
//   support transposed representations in memory.
// - It also did not support loading/storing from threadgroup memory.
// - The previous experiment did not implement all the components of the
//   BF16 -> FP32 decoding optimization for M1.
// | |
// Second issue: | |
// - simdgroup_matrix_storage.load/store doesn't support unaligned inputs. | |
// - Previous design delegated that to the SIMD async copy unit in hardware. | |
// - Likely need a better alternative on M3/M4, where async copy is emulated | |
// and hence very slow. | |
// | |
// Solution: | |
// - Three code paths for loading/storing from memory, until I prove the Apple | |
// GPU 'device_load' instruction natively supports alignments smaller than | |
// the vector size (e.g. alignment is the scalar size). | |
// - Transposed (2 instructions) | |
// - Untransposed, not aligned to multiple of 2 (2 instructions) | |
// - Untransposed, aligned to multiple of 2 (1 instruction) | |
// - Use code generation to spawn the header with compact Swift code. | |
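//
// The code generation mentioned above could look roughly like this: a Swift
// function that emits the Metal source for one load utility with the three
// paths. The name mirrors 'simdgroup_matrix_storage.load', but the signature
// and addressing here are illustrative rather than the generated header
// verbatim.
//
// ```swift
// func createLoadFunction() -> String {
//   """
//   template <typename T>
//   METAL_FUNC void load_checked(
//     thread simdgroup_matrix_storage<T> *dst, const device T *src,
//     uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix
//   ) {
//     if (transpose_matrix) {
//       // Transposed: two scalar loads, strided by one row each.
//       uint address = uint(matrix_origin.x) * elements_per_row + matrix_origin.y;
//       T lo = src[address];
//       T hi = src[address + elements_per_row];
//       *(dst->thread_elements()) = vec<T, 2>(lo, hi);
//     } else {
//       uint address = uint(matrix_origin.y) * elements_per_row + matrix_origin.x;
//       if (elements_per_row % 2 != 0) {
//         // Untransposed, row size not a multiple of 2: two scalar loads.
//         *(dst->thread_elements()) = vec<T, 2>(src[address], src[address + 1]);
//       } else {
//         // Untransposed and aligned to a multiple of 2: one vectorized load.
//         auto src2 = reinterpret_cast<const device vec<T, 2>*>(src + address);
//         *(dst->thread_elements()) = *src2;
//       }
//     }
//   }
//   """
// }
// ```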
// | |
// Third issue: | |
// - While trying to get rid of async copies, I kept finding edge cases. For
//   example, when the matrix dimension is an odd number, RAM accesses go out
//   of bounds. Fixing these requires either incurring immense overhead at
//   matrix edges, or baking if/else statements into the source code.
// | |
// I found something interesting. If I avoid async copies for most of the inner | |
// loop iterations, I can get decent performance on M4. This is a slightly | |
// modified version of the M1 kernel, which divides the "k" accumulator into | |
// two sections. The first section reads the inputs directly from device memory. | |
// The last section reads from threadgroup memory. | |
// | |
// This modified kernel frequently causes IOCommandBuffer errors on M1. I need | |
// to understand why that is happening. | |
// | |
// ========================================================================== // | |
// Tuning Performance on M4 | |
// ========================================================================== // | |
// | |
// Avoiding regressions on M1 Max: | |
// | |
// The kernel must achieve 8100 GFLOPS @ 1535x1535, 48x48x24. | |
// The kernel must achieve 8150 GFLOPS @ 1536x1536, 48x48x24. | |
// The kernel must achieve 7530 GFLOPS @ 1537x1537, 48x48x24. | |
// | |
// Reference statistics for M1 Max: | |
// | |
// - problemSize = 256 | 913 GFLOPS (32x32x8) | |
// - problemSize = 384 | 2931 GFLOPS (32x32x8) | |
// - problemSize = 512 | 5342 GFLOPS (32x32x8) | |
// - problemSize = 640 | 5463 GFLOPS (32x32x8) 6440 GFLOPS (async copy) | |
// - problemSize = 768 | 6160 GFLOPS (48x48x8) 7017 GFLOPS (async copy) | |
// - problemSize = 896 | 6643 GFLOPS (48x48x8) 7136 GFLOPS (async copy) | |
// - problemSize = 1024 | 7596 GFLOPS (48x48x8) 6966 GFLOPS (async copy) | |
// - problemSize = 1152 | 7676 GFLOPS (48x48x8) 8144 GFLOPS (async copy) | |
// - problemSize = 1280 | 7712 GFLOPS (48x48x8) 7813 GFLOPS (async copy) | |
// - problemSize = 1408 | 7747 GFLOPS (48x48x8) | |
// - problemSize = 1536 | 8392 GFLOPS (48x48x8) | |
// | |
// Performance target on M4: | |
// | |
// - problemSize = 256 | 1195 GFLOPS (32x32x8) 590 GFLOPS (MPS) | |
// - problemSize = 384 | 1729 GFLOPS (32x32x8) 1105 GFLOPS (MPS) | |
// - problemSize = 512 | 2549 GFLOPS (32x32x8) 2051 GFLOPS (MPS) | |
// - problemSize = 640 | 2983 GFLOPS (32x32x8) 3028 GFLOPS (MPS) | |
// - problemSize = 768 | 3036 GFLOPS (32x32x8) 3087 GFLOPS (MPS) | |
// - problemSize = 896 | 3044 GFLOPS (32x32x8) 3086 GFLOPS (MPS) | |
// - problemSize = 1024 | 3074 GFLOPS (32x32x8) 3125 GFLOPS (MPS) | |
// - problemSize = 1152 | 3123 GFLOPS (32x32x8) 3152 GFLOPS (MPS) | |
// - problemSize = 1280 | 3134 GFLOPS (32x32x8) 3134 GFLOPS (MPS) | |
// - problemSize = 1408 | 3167 GFLOPS (32x32x8) 3150 GFLOPS (MPS) | |
// - problemSize = 1536 | 3174 GFLOPS (32x32x8) 3129 GFLOPS (MPS) | |
// | |
// Performance deterioration for odd problem sizes: | |
// M1 Max (32x32, async copy), M4 (32x32, no async copy) | |
// | |
// - problemSize = 254 | 1888 GFLOPS 950 GFLOPS | |
// - problemSize = 255 | 1950 GFLOPS 971 GFLOPS | |
// - problemSize = 256 | 2087 GFLOPS 1210 GFLOPS | |
// - problemSize = 257 | 1744 GFLOPS 907 GFLOPS | |
// - problemSize = 258 | 1754 GFLOPS 921 GFLOPS | |
// | |
// - problemSize = 510 | 5296 GFLOPS 2614 GFLOPS | |
// - problemSize = 511 | 5266 GFLOPS 2624 GFLOPS | |
// - problemSize = 512 | 5390 GFLOPS 2765 GFLOPS | |
// - problemSize = 513 | 5180 GFLOPS 2365 GFLOPS | |
// - problemSize = 514 | 5208 GFLOPS 2377 GFLOPS | |
// | |
// - problemSize = 1022 | 5989 GFLOPS 3054 GFLOPS | |
// - problemSize = 1023 | 5905 GFLOPS 3059 GFLOPS | |
// - problemSize = 1024 | 7164 GFLOPS 3051 GFLOPS | |
// - problemSize = 1025 | 5618 GFLOPS 2880 GFLOPS | |
// - problemSize = 1026 | 5770 GFLOPS 2905 GFLOPS | |
// | |
// Overall scaling for aligned problem sizes: | |
// M4 (32x32, no async copy) vs M4 (MPS) | |
// | |
// - problemSize = 256 | 1205 GFLOPS vs 590 GFLOPS (MPS) | |
// - problemSize = 384 | 1711 GFLOPS vs 1105 GFLOPS (MPS) | |
// - problemSize = 512 | 2747 GFLOPS vs 2051 GFLOPS (MPS) | |
// - problemSize = 640 | 2939 GFLOPS vs 3028 GFLOPS (MPS) | |
// - problemSize = 768 | 3010 GFLOPS vs 3087 GFLOPS (MPS) | |
// - problemSize = 896 | 3024 GFLOPS vs 3086 GFLOPS (MPS) | |
// - problemSize = 1024 | 3040 GFLOPS vs 3125 GFLOPS (MPS) | |
// - problemSize = 1152 | 3101 GFLOPS vs 3152 GFLOPS (MPS) | |
// - problemSize = 1280 | 3101 GFLOPS vs 3134 GFLOPS (MPS) | |
// - problemSize = 1408 | 3130 GFLOPS vs 3150 GFLOPS (MPS) | |
// - problemSize = 1536 | 3140 GFLOPS vs 3129 GFLOPS (MPS) | |
// | |
// The above results omit some potential optimizations: for example, fusing
// device_load instructions for consecutive matrix elements when the matrix
// dimension is not divisible by 2, and eliding the async copies when writing
// the C matrix to memory on M4.
// | |
// ========================================================================== // | |
// Tuning Precisions | |
// ========================================================================== // | |
// | |
// ## M1 Max | |
// | |
// Configuration: | |
// - maximum of 32x32x32 and 48x48x24/32 | |
// - inputs are not transposed | |
// | |
// memA | memB | memC | regA | regB | regC | 512 | 768 | 1024 | 1280 | 1536 | | |
// ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | | |
// FP16 | FP32 | FP32 | FP16 | FP32 | FP32 | 6754 | 7274 | 7604 | 7891 | 8300 | | |
// FP16 | FP32 | FP32 | FP32 | FP32 | FP32 | 6794 | 7244 | 7578 | 7864 | 8299 | | |
// FP32 | FP16 | FP32 | FP32 | FP16 | FP32 | 6651 | 7224 | 7630 | 7920 | 8322 | | |
// FP32 | FP16 | FP32 | FP32 | FP32 | FP32 | 6560 | 7231 | 7632 | 7924 | 8318 | | |
// BF16 | FP32 | FP32 | BF16 | FP32 | FP32 | 6118 | 6624 | 6912 | 7381 | 7590 | | |
// BF16 | FP32 | FP32 | FP32 | FP32 | FP32 | 6398 | 7159 | 7031 | 7632 | 8232 | | |
// FP32 | BF16 | FP32 | FP32 | BF16 | FP32 | 5556 | 6779 | 6529 | 7297 | 7680 | | |
// FP32 | BF16 | FP32 | FP32 | FP32 | FP32 | 5223 | 7149 | 6898 | 7753 | 8112 | | |
// | |
// FP16 | FP16 | FP32 | FP16 | FP16 | FP32 | 6796 | 7879 | 8331 | 8396 | 8492 | | |
// FP16 | FP16 | FP32 | FP32 | FP32 | FP32 | 6727 | 7848 | 8306 | 8396 | 8490 | | |
// FP16 | BF16 | FP32 | FP16 | BF16 | FP32 | 6156 | 6935 | 7221 | 7535 | 7841 | | |
// FP16 | BF16 | FP32 | FP16 | FP32 | FP32 | 6304 | 7157 | 7617 | 7894 | 8142 | | |
// FP16 | BF16 | FP32 | FP32 | FP32 | FP32 | 6347 | 7130 | 7628 | 7904 | 8152 | | |
// BF16 | FP16 | FP32 | BF16 | FP16 | FP32 | 6246 | 7348 | 7722 | 7758 | 7878 | | |
// BF16 | FP16 | FP32 | FP32 | FP16 | FP32 | 6435 | 7328 | 7324 | 7865 | 8260 | | |
// BF16 | FP16 | FP32 | FP32 | FP32 | FP32 | 6352 | 7339 | 7312 | 7865 | 8259 | | |
// BF16 | BF16 | FP32 | BF16 | BF16 | FP32 | 5787 | 6598 | 6993 | 6967 | 7207 | | |
// BF16 | BF16 | FP32 | FP32 | FP32 | FP32 | 6075 | 6967 | 6955 | 7515 | 7888 | | |
// | |
// FP16 | FP16 | FP16 | FP16 | FP16 | FP16 | 7077 | 8535 | 8096 | 8660 | 9136 | | |
// FP16 | FP16 | FP16 | FP16 | FP16 | FP32 | 6946 | 8561 | 8322 | 8696 | 9103 | | |
// FP16 | FP16 | FP16 | FP16 | FP32 | FP16 | 6384 | 7742 | 7496 | 7879 | 8254 | | |
// FP16 | FP16 | FP16 | FP32 | FP16 | FP16 | 6350 | 7747 | 7476 | 7875 | 8263 | | |
// FP16 | FP16 | FP16 | FP32 | FP32 | FP32 | 7124 | 8505 | 8330 | 8702 | 9091 | | |
// | |
// BF16 | BF16 | BF16 | BF16 | BF16 | FP32 | 5861 | 7356 | 7084 | 7426 | 7720 | | |
// BF16 | BF16 | BF16 | BF16 | FP32 | FP32 | 5676 | 7805 | 7415 | 7724 | 8250 | | |
// BF16 | BF16 | BF16 | FP32 | BF16 | FP32 | 6243 | 6998 | 7031 | 7355 | 7724 | | |
// BF16 | BF16 | BF16 | FP32 | FP32 | FP32 | 6367 | 7210 | 7086 | 7544 | 7930 | | |
// | |
// FP32 | FP32 | FP16 | FP32 | FP32 | FP16 | 5738 | 6450 | 6130 | 6741 | 7312 | | |
// FP32 | FP32 | FP16 | FP32 | FP32 | FP32 | 5420 | 7084 | 7171 | 7739 | 8223 | | |
// FP32 | FP32 | BF16 | FP32 | FP32 | FP32 | 6452 | 7173 | 7200 | 7804 | 8243 | | |
// FP32 | FP32 | FP32 | FP32 | FP32 | FP32 | 5368 | 7074 | 7165 | 7740 | 8225 | | |
// | |
// FP16 | BF16 | FP16 | FP16 | FP32 | FP16 | 5908 | 6873 | 7076 | 7417 | 7745 | | |
// FP16 | BF16 | FP16 | FP16 | FP32 | FP32 | 6566 | 7598 | 7617 | 7993 | 8489 | | |
// FP16 | BF16 | FP16 | FP32 | FP32 | FP16 | 5896 | 6873 | 7070 | 7419 | 7739 | | |
// FP16 | BF16 | FP16 | FP32 | FP32 | FP32 | 5891 | 7602 | 7616 | 8011 | 8494 | | |
// FP16 | BF16 | BF16 | FP16 | FP32 | FP32 | 6650 | 7581 | 7572 | 7987 | 8462 | | |
// FP16 | BF16 | BF16 | FP32 | FP32 | FP32 | 6809 | 7534 | 7562 | 8027 | 8467 | | |
// | |
// BF16 | FP16 | FP16 | FP32 | FP16 | FP16 | 5937 | 6923 | 7108 | 7527 | 7930 | | |
// BF16 | FP16 | FP16 | FP32 | FP16 | FP32 | 6414 | 7554 | 7333 | 7981 | 8459 | | |
// BF16 | FP16 | FP16 | FP32 | FP32 | FP16 | 5976 | 6924 | 7083 | 7525 | 7929 | | |
// BF16 | FP16 | FP16 | FP32 | FP32 | FP32 | 5383 | 7559 | 7310 | 8008 | 8456 | | |
// BF16 | FP16 | BF16 | FP32 | FP16 | FP32 | 6575 | 7573 | 7512 | 8028 | 8469 | | |
// BF16 | FP16 | BF16 | FP32 | FP32 | FP32 | 6699 | 7530 | 7525 | 8000 | 8470 | | |
// | |
// BF16 | BF16 | FP16 | FP32 | FP32 | FP16 | 5756 | 6721 | 6780 | 7220 | 7583 | | |
// BF16 | BF16 | FP16 | FP32 | FP32 | FP32 | 6424 | 7206 | 6954 | 7547 | 7923 | | |
// | |
// Optimal register precisions for each memory precision. | |
// | |
// Truth Table: | |
// | |
// memA | memB | memC | regA | regB | regC | | |
// ---- | ---- | ---- | ---- | ---- | ---- | | |
// FP16 | FP16 | FP16 | FP16 | FP16 | FP16 | | |
// FP16 | FP16 | BF16 | FP16 | FP16 | FP32 | | |
// FP16 | FP16 | FP32 | FP16 | FP16 | FP32 | | |
// FP16 | BF16 | FP16 | FP32 | FP32 | FP32 | | |
// FP16 | BF16 | BF16 | FP32 | FP32 | FP32 | | |
// FP16 | BF16 | FP32 | FP32 | FP32 | FP32 | | |
// FP16 | FP32 | FP16 | FP32 | FP32 | FP32 | | |
// FP16 | FP32 | BF16 | FP32 | FP32 | FP32 | | |
// FP16 | FP32 | FP32 | FP32 | FP32 | FP32 | | |
// | |
// BF16 | FP16 | FP16 | FP32 | FP32 | FP32 | | |
// BF16 | FP16 | BF16 | FP32 | FP32 | FP32 | | |
// BF16 | FP16 | FP32 | FP32 | FP32 | FP32 | | |
// BF16 | BF16 | FP16 | FP32 | FP32 | FP32 | | |
// BF16 | BF16 | BF16 | FP32 | FP32 | FP32 | | |
// BF16 | BF16 | FP32 | FP32 | FP32 | FP32 | | |
// BF16 | FP32 | FP16 | FP32 | FP32 | FP32 | | |
// BF16 | FP32 | BF16 | FP32 | FP32 | FP32 | | |
// BF16 | FP32 | FP32 | FP32 | FP32 | FP32 | | |
// | |
// FP32 | FP16 | FP16 | FP32 | FP32 | FP32 | | |
// FP32 | FP16 | BF16 | FP32 | FP32 | FP32 | | |
// FP32 | FP16 | FP32 | FP32 | FP32 | FP32 | | |
// FP32 | BF16 | FP16 | FP32 | FP32 | FP32 | | |
// FP32 | BF16 | BF16 | FP32 | FP32 | FP32 | | |
// FP32 | BF16 | FP32 | FP32 | FP32 | FP32 | | |
// FP32 | FP32 | FP16 | FP32 | FP32 | FP32 | | |
// FP32 | FP32 | BF16 | FP32 | FP32 | FP32 | | |
// FP32 | FP32 | FP32 | FP32 | FP32 | FP32 | | |
// | |
// Optimized form of the logic: | |
// | |
// If memA and memB are FP16, | |
// regA is FP16 | |
// regB is FP16 | |
// else | |
// regA is FP32 | |
// regB is FP32 | |
// If memA, memB, and memC are FP16, | |
// regC is FP16 | |
// else | |
// regC is FP32 | |
// | |
// ## M4 | |
// | |
// Configuration: | |
// - 32x32x8 | |
// - inputs are not transposed | |
// | |
// memA | memB | memC | regA | regB | regC | 512 | 768 | 1024 | 1280 | 1536 | | |
// ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | | |
// FP16 | FP32 | FP32 | FP16 | FP32 | FP32 | 2864 | 3093 | 3128 | 3201 | 3237 | | |
// FP16 | FP32 | FP32 | FP32 | FP32 | FP32 | 2823 | 3093 | 3128 | 3197 | 3236 | | |
// FP32 | FP16 | FP32 | FP32 | FP16 | FP32 | 2871 | 3136 | 3180 | 3238 | 3276 | | |
// FP32 | FP16 | FP32 | FP32 | FP32 | FP32 | 2853 | 3133 | 3180 | 3239 | 3275 | | |
// BF16 | FP32 | FP32 | BF16 | FP32 | FP32 | 2832 | 3089 | 3122 | 3199 | 3237 | | |
// BF16 | FP32 | FP32 | FP32 | FP32 | FP32 | 2735 | 2981 | 2999 | 3067 | 3102 | | |
// FP32 | BF16 | FP32 | FP32 | BF16 | FP32 | 2852 | 3134 | 3189 | 3241 | 3276 | | |
// FP32 | BF16 | FP32 | FP32 | FP32 | FP32 | 2742 | 3007 | 3053 | 3104 | 3139 | | |
// | |
// FP16 | FP16 | FP32 | FP16 | FP16 | FP32 | 2985 | 3285 | 3342 | 3417 | 3458 | | |
// FP16 | FP16 | FP32 | FP32 | FP32 | FP32 | 2987 | 3285 | 3348 | 3416 | 3458 | | |
// FP16 | BF16 | FP32 | FP16 | BF16 | FP32 | 3017 | 3290 | 3344 | 3417 | 3460 | | |
// FP16 | BF16 | FP32 | FP16 | FP32 | FP32 | 2861 | 3130 | 3179 | 3244 | 3282 | | |
// FP16 | BF16 | FP32 | FP32 | BF16 | FP32 | 2988 | 3281 | 3349 | 3419 | 3459 | | |
// FP16 | BF16 | FP32 | FP32 | FP32 | FP32 | 2887 | 3127 | 3177 | 3244 | 3281 | | |
// BF16 | FP16 | FP32 | BF16 | FP16 | FP32 | 2990 | 3284 | 3339 | 3418 | 3458 | | |
// BF16 | FP16 | FP32 | BF16 | FP32 | FP32 | 2712 | 3277 | 3332 | 3413 | 3459 | | |
// BF16 | FP16 | FP32 | FP32 | FP16 | FP32 | 2836 | 3108 | 3168 | 3221 | 3258 | | |
// BF16 | FP16 | FP32 | FP32 | FP32 | FP32 | 2855 | 3108 | 3162 | 3220 | 3257 | | |
// BF16 | BF16 | FP32 | BF16 | BF16 | FP32 | 2978 | 3285 | 3347 | 3417 | 3458 | | |
// BF16 | BF16 | FP32 | BF16 | FP32 | FP32 | 2868 | 3119 | 3182 | 3245 | 3281 | | |
// BF16 | BF16 | FP32 | FP32 | BF16 | FP32 | 2850 | 3114 | 3164 | 3218 | 3257 | | |
// BF16 | BF16 | FP32 | FP32 | FP32 | FP32 | 2707 | 2943 | 2991 | 3038 | 3070 | | |
// | |
// FP16 | FP16 | FP16 | FP16 | FP16 | FP16 | 3069 | 3334 | 3410 | 3471 | 3512 | | |
// FP16 | FP16 | FP16 | FP16 | FP16 | FP32 | 3025 | 3264 | 3318 | 3394 | 3430 | | |
// FP16 | FP16 | FP16 | FP16 | FP32 | FP16 | 2708 | 2932 | 2990 | 3039 | 3072 | | |
// FP16 | FP16 | FP16 | FP32 | FP16 | FP16 | 2714 | 2928 | 2987 | 3041 | 3074 | | |
// FP16 | FP16 | FP16 | FP32 | FP32 | FP32 | 3016 | 3260 | 3315 | 3392 | 3432 | | |
// | |
// BF16 | BF16 | BF16 | BF16 | BF16 | FP32 | 3025 | 3280 | 3324 | 3409 | 3445 | | |
// BF16 | BF16 | BF16 | BF16 | FP32 | FP32 | 2876 | 3111 | 3164 | 3233 | 3268 | | |
// BF16 | BF16 | BF16 | FP32 | BF16 | FP32 | 2857 | 3093 | 3138 | 3206 | 3241 | | |
// BF16 | BF16 | BF16 | FP32 | FP32 | FP32 | 2707 | 2883 | 2939 | 2982 | 3014 | | |
// | |
// FP32 | FP32 | FP16 | FP32 | FP32 | FP16 | 2535 | 2791 | 2769 | 2828 | 2861 | | |
// FP32 | FP32 | FP16 | FP32 | FP32 | FP32 | 2781 | 2999 | 3022 | 3077 | 3115 | | |
// FP32 | FP32 | BF16 | FP32 | FP32 | FP32 | 2809 | 3052 | 3090 | 3141 | 3178 | | |
// FP32 | FP32 | FP32 | FP32 | FP32 | FP32 | 2772 | 3017 | 3044 | 3101 | 3141 | | |
// | |
// FP16 | BF16 | FP16 | FP16 | BF16 | FP16 | 2702 | 2937 | 2984 | 3040 | 3072 | | |
// FP16 | BF16 | FP16 | FP16 | BF16 | FP32 | 3010 | 3258 | 3317 | 3392 | 3431 | | |
// FP16 | BF16 | FP16 | FP32 | BF16 | FP16 | 2712 | 2938 | 2990 | 3040 | 3073 | | |
// FP16 | BF16 | FP16 | FP32 | BF16 | FP32 | 2997 | 3260 | 3314 | 3394 | 3431 | | |
// FP16 | BF16 | BF16 | FP16 | BF16 | FP32 | 3028 | 3275 | 3324 | 3403 | 3447 | | |
// FP16 | BF16 | BF16 | FP32 | BF16 | FP32 | 3008 | 3274 | 3322 | 3407 | 3443 | | |
// | |
// BF16 | FP16 | FP16 | BF16 | FP16 | FP16 | 2705 | 2938 | 2987 | 3042 | 3072 | | |
// BF16 | FP16 | FP16 | BF16 | FP16 | FP32 | 3011 | 3255 | 3310 | 3392 | 3433 | | |
// BF16 | FP16 | FP16 | BF16 | FP32 | FP16 | 2698 | 2939 | 2988 | 3039 | 3073 | | |
// BF16 | FP16 | FP16 | BF16 | FP32 | FP32 | 3015 | 3261 | 3313 | 3391 | 3431 | | |
// BF16 | FP16 | BF16 | BF16 | FP16 | FP32 | 3027 | 3276 | 3323 | 3406 | 3445 | | |
// BF16 | FP16 | BF16 | BF16 | FP32 | FP32 | 3040 | 3275 | 3327 | 3407 | 3445 | | |
// | |
// BF16 | BF16 | FP16 | BF16 | BF16 | FP16 | 2709 | 2938 | 2987 | 3038 | 3073 | | |
// BF16 | BF16 | FP16 | BF16 | BF16 | FP32 | 3005 | 3260 | 3320 | 3394 | 3433 | | |
// | |
// Optimal register precisions for each memory precision. | |
// | |
// Truth Table: | |
// | |
// memA | memB | memC | regA | regB | regC | | |
// ---- | ---- | ---- | ---- | ---- | ---- | | |
// FP16 | FP16 | FP16 | FP16 | FP16 | FP16 | | |
// FP16 | FP16 | BF16 | FP16 | FP16 | FP32 | | |
// FP16 | FP16 | FP32 | FP16 | FP16 | FP32 | | |
// FP16 | BF16 | FP16 | FP16 | BF16 | FP32 | | |
// FP16 | BF16 | BF16 | FP16 | BF16 | FP32 | | |
// FP16 | BF16 | FP32 | FP16 | BF16 | FP32 | | |
// FP16 | FP32 | FP16 | FP16 | FP32 | FP32 | | |
// FP16 | FP32 | BF16 | FP16 | FP32 | FP32 | | |
// FP16 | FP32 | FP32 | FP16 | FP32 | FP32 | | |
// | |
// BF16 | FP16 | FP16 | BF16 | FP16 | FP32 | | |
// BF16 | FP16 | BF16 | BF16 | FP16 | FP32 | | |
// BF16 | FP16 | FP32 | BF16 | FP16 | FP32 | | |
// BF16 | BF16 | FP16 | BF16 | BF16 | FP32 | | |
// BF16 | BF16 | BF16 | BF16 | BF16 | FP32 | | |
// BF16 | BF16 | FP32 | BF16 | BF16 | FP32 | | |
// BF16 | FP32 | FP16 | BF16 | FP32 | FP32 | | |
// BF16 | FP32 | BF16 | BF16 | FP32 | FP32 | | |
// BF16 | FP32 | FP32 | BF16 | FP32 | FP32 | | |
// | |
// FP32 | FP16 | FP16 | FP32 | FP16 | FP32 | | |
// FP32 | FP16 | BF16 | FP32 | FP16 | FP32 | | |
// FP32 | FP16 | FP32 | FP32 | FP16 | FP32 | | |
// FP32 | BF16 | FP16 | FP32 | BF16 | FP32 | | |
// FP32 | BF16 | BF16 | FP32 | BF16 | FP32 | | |
// FP32 | BF16 | FP32 | FP32 | BF16 | FP32 | | |
// FP32 | FP32 | FP16 | FP32 | FP32 | FP32 | | |
// FP32 | FP32 | BF16 | FP32 | FP32 | FP32 | | |
// FP32 | FP32 | FP32 | FP32 | FP32 | FP32 | | |
// | |
// Optimized form of the logic: | |
// | |
// regA is identical to memA | |
// regB is identical to memB | |
// If memA, memB, and memC are FP16, | |
// regC is FP16 | |
// else | |
// regC is FP32 | |
// | |
// ## Conclusion | |
// | |
// Use a common set of logic for both architectures: | |
// | |
// ``` | |
// regA is identical to memA | |
// regB is identical to memB | |
// If memA, memB, and memC are FP16, | |
// regC is FP16 | |
// else | |
// regC is FP32 | |
// | |
// If earlier than M3 | |
// If memA is BF16, | |
// regA is FP32 | |
// If memB is BF16, | |
// regB is FP32 | |
// ``` | |
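//
// Translated into Swift (using the GEMMOperandPrecision enum defined later in
// this file), the common logic could be sketched as follows. The parameter
// 'isApple9OrLater' is a placeholder for however the caller detects the
// M3/M4 generation.
//
// ```swift
// func selectRegisterPrecisions(
//   memA: GEMMOperandPrecision,
//   memB: GEMMOperandPrecision,
//   memC: GEMMOperandPrecision,
//   isApple9OrLater: Bool
// ) -> (regA: GEMMOperandPrecision,
//       regB: GEMMOperandPrecision,
//       regC: GEMMOperandPrecision) {
//   var regA = memA
//   var regB = memB
//   let regC: GEMMOperandPrecision =
//     (memA == .FP16 && memB == .FP16 && memC == .FP16) ? .FP16 : .FP32
//
//   // Apple7/Apple8 lack native BF16; keep BF16 operands as FP32 in registers.
//   if !isApple9OrLater {
//     if memA == .BF16 { regA = .FP32 }
//     if memB == .BF16 { regB = .FP32 }
//   }
//   return (regA, regB, regC)
// }
// ```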
// | |
// ========================================================================== // | |
// Tuning Transposes | |
// ========================================================================== // | |
// | |
// I am considering a simplification to the code, which merges two control | |
// blocks in simdgroup_matrix_storage.load. | |
// | |
// M1 Max, FP32xFP32->FP32 | |
// | |
// Original code: | |
// problemSize = 1023 | A B | 6941 GFLOPS | |
// problemSize = 1023 | A B^T | 6781 GFLOPS | |
// problemSize = 1023 | A^T B | 7082 GFLOPS | |
// problemSize = 1023 | A^T B^T | 6763 GFLOPS | |
// problemSize = 1024 | A B | 6995 GFLOPS | |
// problemSize = 1024 | A B^T | 6901 GFLOPS | |
// problemSize = 1024 | A^T B | 7173 GFLOPS | |
// problemSize = 1024 | A^T B^T | 6926 GFLOPS | |
// problemSize = 1025 | A B | 7035 GFLOPS | |
// problemSize = 1025 | A B^T | 6699 GFLOPS | |
// problemSize = 1025 | A^T B | 7269 GFLOPS | |
// problemSize = 1025 | A^T B^T | 6927 GFLOPS | |
// | |
// With packed_vec for threadgroup loads: | |
// problemSize = 1023 | A B | 6925 GFLOPS | |
// problemSize = 1023 | A B^T | 6786 GFLOPS | |
// problemSize = 1023 | A^T B | 6977 GFLOPS | |
// problemSize = 1023 | A^T B^T | 6770 GFLOPS | |
// problemSize = 1024 | A B | 6979 GFLOPS | |
// problemSize = 1024 | A B^T | 6894 GFLOPS | |
// problemSize = 1024 | A^T B | 7109 GFLOPS | |
// problemSize = 1024 | A^T B^T | 6931 GFLOPS | |
// problemSize = 1025 | A B | 7005 GFLOPS | |
// problemSize = 1025 | A B^T | 6707 GFLOPS | |
// problemSize = 1025 | A^T B | 7225 GFLOPS | |
// problemSize = 1025 | A^T B^T | 6943 GFLOPS | |
// | |
// Significant decrease (-39 GFLOPS) for 1023. | |
// Significant decrease (-29 GFLOPS) for 1024. | |
// Significant decrease (-22 GFLOPS) for 1025. | |
// | |
// Conclusion: | |
// | |
// Originally, I thought the code change was safe. I had gathered evidence and
// saw no detectable performance delta after the change. However, I had failed
// to toggle a few source lines between commented out and active. The
// corrected test showed a major regression when packed_vec<T, 2> was used for
// threadgroup memory accesses with power-of-two row sizes.
//
// For odd-numbered row sizes (and by extension, odd-numbered problem sizes in
// device memory), there is no change. Still, I am not going to switch any more
// lines of code to packed vectors. The evidence is too murky, and the only
// benefit of the change would be making the code more legible.
// | |
// This is a good example of the thought process when writing high-performance
// code. You need reliable methods to gather data (without statistical noise),
// comparing whether each tiny change increases or decreases performance. Do
// not include optimizations that harm performance or provide no net gain.
// Assume there is always a regression hiding in the combinatorial explosion
// of problem configurations. Find that regression, or prove it doesn't exist.
// | |
// ========================================================================== // | |
// Tuning Async Copies on M3+ | |
// ========================================================================== // | |
// | |
// Tasks: | |
// - In the device loading iterations, some garbage values are being read | |
// from out-of-bounds addresses. Fix this without harming performance. | |
// - Does the fix also remove the IOCommandBuffer errors on M1? | |
// - Try to further minimize the overhead of async copies on M4. | |
// | |
// I removed the alpha and beta constants, which simplifies the shader. Some of
// the fixes to the problems above would require rearchitecting the code to
// handle non-zero beta. Use cases that require accumulating into an existing
// accumulator would likely need more flexibility than a scalar argument in a
// pre-written kernel (e.g. one would write their own fused activation shader
// from scratch). Alpha/beta are technical debt from the decades-old BLAS
// interface and out of scope for this reference implementation.
// | |
// Next, I am changing how the code handles matrix edges. | |
// | |
// Correctness tests: | |
// - Precision: FP32xFP32->FP32 | |
// - Problem size: (8, 16, 24, 32, 48, 64, 104, 128)(±1) | |
// - Transpose state: AB, AB^T, A^T B | |
// | |
// Performance tests:
// - Avoid any regressions from the following statistics.
// - I executed the program 3 times and took the maximum value of every row of
//   the tables below. This approach minimizes statistical noise. The maximum
//   performance achieved by two similar programs is more amenable to rigorous
//   analysis than the average performance, especially when the performance
//   delta of a change is so small that it would be drowned out by statistical
//   noise.
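//
// A minimal sketch of the measurement arithmetic behind these tables,
// assuming some 'runTrial' closure encodes and waits on one command buffer
// and returns the elapsed time in seconds:
//
// ```swift
// func maxGFLOPS(
//   problemSize: Int,
//   trials: Int = 3,
//   runTrial: () -> Double
// ) -> Double {
//   var best = Double.zero
//   for _ in 0..<trials {
//     let seconds = runTrial()
//
//     // A square GEMM performs 2 * N^3 floating-point operations.
//     let n = Double(problemSize)
//     let operations = 2 * n * n * n
//     best = max(best, operations / seconds / 1e9)
//   }
//   return best
// }
// ```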
// | |
// M1 Max: | |
// FP32 | problemSize = 1535 | A B | 8170 GFLOPS | |
// FP32 | problemSize = 1535 | A B^T | 7876 GFLOPS | |
// FP32 | problemSize = 1536 | A B | 8229 GFLOPS | |
// FP32 | problemSize = 1536 | A B^T | 8085 GFLOPS | |
// FP32 | problemSize = 1537 | A B | 7607 GFLOPS | |
// FP32 | problemSize = 1537 | A B^T | 7237 GFLOPS | |
// FP16 | problemSize = 1535 | A B | 9064 GFLOPS | |
// FP16 | problemSize = 1535 | A B^T | 8978 GFLOPS | |
// FP16 | problemSize = 1536 | A B | 9163 GFLOPS | |
// FP16 | problemSize = 1536 | A B^T | 9063 GFLOPS | |
// FP16 | problemSize = 1537 | A B | 8558 GFLOPS | |
// FP16 | problemSize = 1537 | A B^T | 8555 GFLOPS | |
// BF16 | problemSize = 1535 | A B | 8261 GFLOPS | |
// BF16 | problemSize = 1535 | A B^T | 8283 GFLOPS | |
// BF16 | problemSize = 1536 | A B | 8338 GFLOPS | |
// BF16 | problemSize = 1536 | A B^T | 8394 GFLOPS | |
// BF16 | problemSize = 1537 | A B | 7730 GFLOPS | |
// BF16 | problemSize = 1537 | A B^T | 7893 GFLOPS | |
// | |
// M4: | |
// FP32 | problemSize = 1023 | A B | 3030 GFLOPS | |
// FP32 | problemSize = 1023 | A B^T | 2783 GFLOPS | |
// FP32 | problemSize = 1024 | A B | 3042 GFLOPS | |
// FP32 | problemSize = 1024 | A B^T | 3048 GFLOPS | |
// FP32 | problemSize = 1025 | A B | 2859 GFLOPS | |
// FP32 | problemSize = 1025 | A B^T | 2650 GFLOPS | |
// FP16 | problemSize = 1023 | A B | 3352 GFLOPS | |
// FP16 | problemSize = 1023 | A B^T | 3288 GFLOPS | |
// FP16 | problemSize = 1024 | A B | 3407 GFLOPS | |
// FP16 | problemSize = 1024 | A B^T | 3397 GFLOPS | |
// FP16 | problemSize = 1025 | A B | 3182 GFLOPS | |
// FP16 | problemSize = 1025 | A B^T | 3117 GFLOPS | |
// BF16 | problemSize = 1023 | A B | 3319 GFLOPS | |
// BF16 | problemSize = 1023 | A B^T | 3266 GFLOPS | |
// BF16 | problemSize = 1024 | A B | 3324 GFLOPS | |
// BF16 | problemSize = 1024 | A B^T | 3317 GFLOPS | |
// BF16 | problemSize = 1025 | A B | 3151 GFLOPS | |
// BF16 | problemSize = 1025 | A B^T | 3088 GFLOPS | |
// | |
// I got correctness working with the change to how edges are processed. | |
// Next, I need to get performance working. | |
// | |
// M1 Max: | |
// FP32 | problemSize = 1535 | A B | 8195 GFLOPS | |
// FP32 | problemSize = 1535 | A B^T | 7830 GFLOPS | |
// FP32 | problemSize = 1536 | A B | 8231 GFLOPS | |
// FP32 | problemSize = 1536 | A B^T | 8090 GFLOPS | |
// FP32 | problemSize = 1537 | A B | 7559 GFLOPS | |
// FP32 | problemSize = 1537 | A B^T | 7148 GFLOPS | |
// FP16 | problemSize = 1535 | A B | 9087 GFLOPS | |
// FP16 | problemSize = 1535 | A B^T | 8993 GFLOPS | |
// FP16 | problemSize = 1536 | A B | 9154 GFLOPS | |
// FP16 | problemSize = 1536 | A B^T | 9071 GFLOPS | |
// FP16 | problemSize = 1537 | A B | 8547 GFLOPS | |
// FP16 | problemSize = 1537 | A B^T | 8464 GFLOPS | |
// BF16 | problemSize = 1535 | A B | 8209 GFLOPS | |
// BF16 | problemSize = 1535 | A B^T | 8237 GFLOPS | |
// BF16 | problemSize = 1536 | A B | 8343 GFLOPS | |
// BF16 | problemSize = 1536 | A B^T | 8395 GFLOPS | |
// BF16 | problemSize = 1537 | A B | 7704 GFLOPS | |
// BF16 | problemSize = 1537 | A B^T | 7748 GFLOPS | |
// | |
// -14 GFLOPS for 1535. | |
// +2 GFLOPS for 1536. | |
// -68 GFLOPS for 1537. | |
// The larger performance drop for 1537 is expected: almost every element in
// the edge block is redundantly computed and written.
// | |
// M4: | |
// FP32 | problemSize = 1023 | A B | 3068 GFLOPS | |
// FP32 | problemSize = 1023 | A B^T | 2800 GFLOPS | |
// FP32 | problemSize = 1024 | A B | 3071 GFLOPS | |
// FP32 | problemSize = 1024 | A B^T | 3040 GFLOPS | |
// FP32 | problemSize = 1025 | A B | 2887 GFLOPS | |
// FP32 | problemSize = 1025 | A B^T | 2647 GFLOPS | |
// FP16 | problemSize = 1023 | A B | 3398 GFLOPS | |
// FP16 | problemSize = 1023 | A B^T | 3323 GFLOPS | |
// FP16 | problemSize = 1024 | A B | 3421 GFLOPS | |
// FP16 | problemSize = 1024 | A B^T | 3402 GFLOPS | |
// FP16 | problemSize = 1025 | A B | 3220 GFLOPS | |
// FP16 | problemSize = 1025 | A B^T | 3158 GFLOPS | |
// BF16 | problemSize = 1023 | A B | 3329 GFLOPS | |
// BF16 | problemSize = 1023 | A B^T | 3276 GFLOPS | |
// BF16 | problemSize = 1024 | A B | 3317 GFLOPS | |
// BF16 | problemSize = 1024 | A B^T | 3327 GFLOPS | |
// BF16 | problemSize = 1025 | A B | 3154 GFLOPS | |
// BF16 | problemSize = 1025 | A B^T | 3099 GFLOPS | |
// | |
// +26 GFLOPS for 1023. | |
// +8 GFLOPS for 1024. | |
// +20 GFLOPS for 1025. | |
// | |
// If every write goes through threadgroup memory (async copy) instead of
// device memory on M1, here is the alternative performance delta.
// | |
// FP32 | problemSize = 1535 | A B | 8226 GFLOPS | |
// FP32 | problemSize = 1535 | A B^T | 7862 GFLOPS | |
// FP32 | problemSize = 1536 | A B | 8251 GFLOPS | |
// FP32 | problemSize = 1536 | A B^T | 8124 GFLOPS | |
// FP32 | problemSize = 1537 | A B | 7615 GFLOPS | |
// FP32 | problemSize = 1537 | A B^T | 7202 GFLOPS | |
// FP16 | problemSize = 1535 | A B | 9011 GFLOPS | |
// FP16 | problemSize = 1535 | A B^T | 9012 GFLOPS | |
// FP16 | problemSize = 1536 | A B | 9162 GFLOPS | |
// FP16 | problemSize = 1536 | A B^T | 9049 GFLOPS | |
// FP16 | problemSize = 1537 | A B | 8575 GFLOPS | |
// FP16 | problemSize = 1537 | A B^T | 8539 GFLOPS | |
// BF16 | problemSize = 1535 | A B | 7980 GFLOPS | |
// BF16 | problemSize = 1535 | A B^T | 7954 GFLOPS | |
// BF16 | problemSize = 1536 | A B | 8352 GFLOPS | |
// BF16 | problemSize = 1536 | A B^T | 8399 GFLOPS | |
// BF16 | problemSize = 1537 | A B | 7716 GFLOPS | |
// BF16 | problemSize = 1537 | A B^T | 7906 GFLOPS | |
// | |
// -98 GFLOPS for 1535. | |
// +11 GFLOPS for 1536. | |
// -5 GFLOPS for 1537. | |
// | |
// The evidence reveals this: | |
// - write directly to device memory: | |
// -14 GFLOPS for 1535 | |
// +2 GFLOPS for 1536 | |
// - write indirectly via threadgroup memory: | |
// -5 GFLOPS for 1537 | |
// | |
// Focus on three optimizations: | |
// - Allowing extraneous writes to device or threadgroup memory to be elided. | |
// - Minimizing the overhead of pointer addressing arithmetic, restoring the | |
// performance of the kernel before this change. | |
// - Avoiding bank conflicts when writing the accumulator on M1 (if | |
// threadgroup memory is a common execution path). | |
// | |
// Eliding unnecessary writes, and going through device memory: | |
// | |
// M1 Max | |
// FP32 | problemSize = 1535 | A B | 8199 GFLOPS | |
// FP32 | problemSize = 1535 | A B^T | 7828 GFLOPS | |
// FP32 | problemSize = 1536 | A B | 8239 GFLOPS | |
// FP32 | problemSize = 1536 | A B^T | 8100 GFLOPS | |
// FP32 | problemSize = 1537 | A B | 7535 GFLOPS | |
// FP32 | problemSize = 1537 | A B^T | 7145 GFLOPS | |
// FP16 | problemSize = 1535 | A B | 9092 GFLOPS | |
// FP16 | problemSize = 1535 | A B^T | 8995 GFLOPS | |
// FP16 | problemSize = 1536 | A B | 9175 GFLOPS | |
// FP16 | problemSize = 1536 | A B^T | 9077 GFLOPS | |
// FP16 | problemSize = 1537 | A B | 8508 GFLOPS | |
// FP16 | problemSize = 1537 | A B^T | 8488 GFLOPS | |
// BF16 | problemSize = 1535 | A B | 8224 GFLOPS | |
// BF16 | problemSize = 1535 | A B^T | 8245 GFLOPS | |
// BF16 | problemSize = 1536 | A B | 8349 GFLOPS | |
// BF16 | problemSize = 1536 | A B^T | 8402 GFLOPS | |
// BF16 | problemSize = 1537 | A B | 7708 GFLOPS | |
// BF16 | problemSize = 1537 | A B^T | 7866 GFLOPS | |
// | |
// -8 GFLOPS for 1535. | |
// +12 GFLOPS for 1536. | |
// -60 GFLOPS for 1537. | |
// | |
// M4 | |
// FP32 | problemSize = 1023 | A B | 3064 GFLOPS | |
// FP32 | problemSize = 1023 | A B^T | 2783 GFLOPS | |
// FP32 | problemSize = 1024 | A B | 3064 GFLOPS | |
// FP32 | problemSize = 1024 | A B^T | 3039 GFLOPS | |
// FP32 | problemSize = 1025 | A B | 2886 GFLOPS | |
// FP32 | problemSize = 1025 | A B^T | 2647 GFLOPS | |
// FP16 | problemSize = 1023 | A B | 3405 GFLOPS | |
// FP16 | problemSize = 1023 | A B^T | 3333 GFLOPS | |
// FP16 | problemSize = 1024 | A B | 3417 GFLOPS | |
// FP16 | problemSize = 1024 | A B^T | 3400 GFLOPS | |
// FP16 | problemSize = 1025 | A B | 3217 GFLOPS | |
// FP16 | problemSize = 1025 | A B^T | 3163 GFLOPS | |
// BF16 | problemSize = 1023 | A B | 3331 GFLOPS | |
// BF16 | problemSize = 1023 | A B^T | 3277 GFLOPS | |
// BF16 | problemSize = 1024 | A B | 3318 GFLOPS | |
// BF16 | problemSize = 1024 | A B^T | 3325 GFLOPS | |
// BF16 | problemSize = 1025 | A B | 3154 GFLOPS | |
// BF16 | problemSize = 1025 | A B^T | 3097 GFLOPS | |
// | |
// +26 GFLOPS for 1023. | |
// +5 GFLOPS for 1024. | |
// +20 GFLOPS for 1025. | |
// | |
// Eliding unnecessary writes, and going through threadgroup memory: | |
// | |
// M1 Max | |
// FP32 | problemSize = 1535 | A B | 8221 GFLOPS | |
// FP32 | problemSize = 1535 | A B^T | 7861 GFLOPS | |
// FP32 | problemSize = 1536 | A B | 8249 GFLOPS | |
// FP32 | problemSize = 1536 | A B^T | 8126 GFLOPS | |
// FP32 | problemSize = 1537 | A B | 7611 GFLOPS | |
// FP32 | problemSize = 1537 | A B^T | 7203 GFLOPS | |
// FP16 | problemSize = 1535 | A B | 8995 GFLOPS | |
// FP16 | problemSize = 1535 | A B^T | 9012 GFLOPS | |
// FP16 | problemSize = 1536 | A B | 9190 GFLOPS | |
// FP16 | problemSize = 1536 | A B^T | 9046 GFLOPS | |
// FP16 | problemSize = 1537 | A B | 8574 GFLOPS | |
// FP16 | problemSize = 1537 | A B^T | 8512 GFLOPS | |
// BF16 | problemSize = 1535 | A B | 7988 GFLOPS | |
// BF16 | problemSize = 1535 | A B^T | 7962 GFLOPS | |
// BF16 | problemSize = 1536 | A B | 8358 GFLOPS | |
// BF16 | problemSize = 1536 | A B^T | 8405 GFLOPS | |
// BF16 | problemSize = 1537 | A B | 7709 GFLOPS | |
// BF16 | problemSize = 1537 | A B^T | 7872 GFLOPS | |
// | |
// -99 GFLOPS for 1535. | |
// +17 GFLOPS for 1536. | |
// -17 GFLOPS for 1537. | |
// | |
// M4 | |
// FP32 | problemSize = 1023 | A B | 3065 GFLOPS | |
// FP32 | problemSize = 1023 | A B^T | 2795 GFLOPS | |
// FP32 | problemSize = 1024 | A B | 3065 GFLOPS | |
// FP32 | problemSize = 1024 | A B^T | 3030 GFLOPS | |
// FP32 | problemSize = 1025 | A B | 2874 GFLOPS | |
// FP32 | problemSize = 1025 | A B^T | 2648 GFLOPS | |
// FP16 | problemSize = 1023 | A B | 3375 GFLOPS | |
// FP16 | problemSize = 1023 | A B^T | 3313 GFLOPS | |
// FP16 | problemSize = 1024 | A B | 3403 GFLOPS | |
// FP16 | problemSize = 1024 | A B^T | 3384 GFLOPS | |
// FP16 | problemSize = 1025 | A B | 3192 GFLOPS | |
// FP16 | problemSize = 1025 | A B^T | 3139 GFLOPS | |
// BF16 | problemSize = 1023 | A B | 3334 GFLOPS | |
// BF16 | problemSize = 1023 | A B^T | 3256 GFLOPS | |
// BF16 | problemSize = 1024 | A B | 3344 GFLOPS | |
// BF16 | problemSize = 1024 | A B^T | 3308 GFLOPS | |
// BF16 | problemSize = 1025 | A B | 3156 GFLOPS | |
// BF16 | problemSize = 1025 | A B^T | 3093 GFLOPS | |
// | |
// +17 GFLOPS for 1023. | |
// -0 GFLOPS for 1024. | |
// +9 GFLOPS for 1025. | |
// | |
// I want to test whether any performance changes are simply due to changes | |
// to the pointer addressing code. I will modify the code so that no memory | |
// writes are elided. Then, benchmark the following configurations: | |
// - M1 Max | |
// - 1535, to device memory | |
// - 1536, to device memory | |
// - 1537, to threadgroup memory | |
// - M4 | |
// - 1023, to device memory | |
// - 1024, to device memory | |
// - 1025, to device memory | |
// | |
// M1 Max | |
// FP32 | problemSize = 1535 | A B | 8192 GFLOPS | |
// FP32 | problemSize = 1535 | A B^T | 7834 GFLOPS | |
// FP32 | problemSize = 1536 | A B | 8245 GFLOPS | |
// FP32 | problemSize = 1536 | A B^T | 8097 GFLOPS | |
// FP32 | problemSize = 1537 | A B | 7618 GFLOPS | |
// FP32 | problemSize = 1537 | A B^T | 7206 GFLOPS | |
// FP16 | problemSize = 1535 | A B | 9079 GFLOPS | |
// FP16 | problemSize = 1535 | A B^T | 9007 GFLOPS | |
// FP16 | problemSize = 1536 | A B | 9183 GFLOPS | |
// FP16 | problemSize = 1536 | A B^T | 9078 GFLOPS | |
// FP16 | problemSize = 1537 | A B | 8581 GFLOPS | |
// FP16 | problemSize = 1537 | A B^T | 8541 GFLOPS | |
// BF16 | problemSize = 1535 | A B | 8219 GFLOPS | |
// BF16 | problemSize = 1535 | A B^T | 8247 GFLOPS | |
// BF16 | problemSize = 1536 | A B | 8351 GFLOPS | |
// BF16 | problemSize = 1536 | A B^T | 8399 GFLOPS | |
// BF16 | problemSize = 1537 | A B | 7723 GFLOPS | |
// BF16 | problemSize = 1537 | A B^T | 7908 GFLOPS | |
// | |
// -9 GFLOPS for 1535. | |
// +14 GFLOPS for 1536. | |
// -1 GFLOPS for 1537. | |
// | |
// M4 | |
// FP32 | problemSize = 1023 | A B | 3067 GFLOPS | |
// FP32 | problemSize = 1023 | A B^T | 2780 GFLOPS | |
// FP32 | problemSize = 1024 | A B | 3072 GFLOPS | |
// FP32 | problemSize = 1024 | A B^T | 3038 GFLOPS | |
// FP32 | problemSize = 1025 | A B | 2888 GFLOPS | |
// FP32 | problemSize = 1025 | A B^T | 2644 GFLOPS | |
// FP16 | problemSize = 1023 | A B | 3394 GFLOPS | |
// FP16 | problemSize = 1023 | A B^T | 3325 GFLOPS | |
// FP16 | problemSize = 1024 | A B | 3418 GFLOPS | |
// FP16 | problemSize = 1024 | A B^T | 3400 GFLOPS | |
// FP16 | problemSize = 1025 | A B | 3220 GFLOPS | |
// FP16 | problemSize = 1025 | A B^T | 3158 GFLOPS | |
// BF16 | problemSize = 1023 | A B | 3330 GFLOPS | |
// BF16 | problemSize = 1023 | A B^T | 3275 GFLOPS | |
// BF16 | problemSize = 1024 | A B | 3329 GFLOPS | |
// BF16 | problemSize = 1024 | A B^T | 3325 GFLOPS | |
// BF16 | problemSize = 1025 | A B | 3153 GFLOPS | |
// BF16 | problemSize = 1025 | A B^T | 3098 GFLOPS | |
// | |
// +22 GFLOPS for 1023. | |
// +8 GFLOPS for 1024. | |
// +19 GFLOPS for 1025. | |
// | |
// The write elision optimization has failed, so I need to revert it. I did | |
// succeed at shifting the matrix origin in-bounds, allowing M4 to avoid async | |
// stores for almost every matrix in practical use cases. The new code has | |
// minimal changes to M1 performance and non-negligible improvements to M4 | |
// performance. Averaging the ALU utilization for all precisions: | |
// | |
// M1 Max (max 10617 GFLOPS) | |
// 1535: 79.5% -> 79.4% | |
// 1536: 80.5% -> 80.6% | |
// 1537: 74.7% -> 74.7% | |
// | |
// M4 (max ~3580 GFLOPS) | |
// 1023: 88.6% -> 89.3% | |
// 1024: 90.9% -> 91.2% | |
// 1025: 84.0% -> 84.5% | |
// | |
// The M1 architecture changes only by a rounding error (0.1%). The M3+
// generation (measured here on M4) increases by anywhere from +0.3% to +0.7%.
// This evidence seems high-quality enough to justify the recent refactoring
// of the shader.
// | |
// There is one slight complication. To avoid a regression on M1, I need to
// switch between device and threadgroup writing based on the result of
// (problem size / block size). I'm going to investigate whether there are
// bank conflicts on M1. If there are, removing them could make the
// 'threadgroup' path always the fastest, and I would not need to engineer
// logic for switching between shader variants.
// | |
// ========================================================================== // | |
// Tuning BF16 on M1-M2 | |
// ========================================================================== // | |
// | |
// I gathered some evidence and noticed strange patterns in which store path
// is faster for M1 + BF16: device or threadgroup.
// | |
// https://gist.github.com/philipturner/1c157bf87702420e51c6deb68e70d078 | |
// | |
// Here are some raw chat messages from Discord: | |
// | |
// > Today I found out a lot of strange things about my shader, that may be the | |
// > reason it's faster than MPS. | |
// > | |
// > The K dimension of the block (accumulator) is being unrolled. | |
// > | |
// > But all the registers for K are being allocated. Normally a SIMD allocates | |
// > 32 x 32 for accumulator, 32 x 8 for A, 32 x 8 for B. This is the way MPS | |
// > does it, and how any other Apple GPU (e.g. M3) does it. How the very old | |
// > Tinygrad reference implementation did it. Also the fastest when SIMD async | |
// > copy is not used. | |
// > | |
// > MFA has 24 x 24 for accumulator, 24 x 24 for A, and 24 x 24 for B. So at | |
// > first glance, almost 3 times the register pressure if the accumulator only | |
// > was stored. That means occupancy is terribly low. Revealed itself for BF16 | |
// > where although threadgroup memory is BF16, registers are FP32. | |
// > | |
// > I tried fixing it, but every alternative lowered performance. Apparently | |
// > the biggest bottleneck is either addressing overhead or latency of waiting | |
// > on threadgroup -> thread loads (somehow the GPU needs a heck ton of | |
// > latency hiding and wants to load ~24-40 inner loop iterations of data up | |
// > front). | |
// > | |
// > MPS would not have found this strange performance maximum for a number of | |
// > reasons. They don't have good support for matrices indivisible by block | |
// > size. 32 x 32 + 32 x 16 + 32 x 16 would cause severe enough register | |
// > pressure to be nonviable. Only 24 x 24 + 24 x 24 + 24 x 24 can work with | |
// > this. And it needs to be paired with async copy, because the small M/N in | |
// > registers requires high bandwidth. Only possible when reading from | |
// > threadgroup memory (higher bandwidth than device memory). | |
// > | |
// > For FP16 the maximum is 24 x 24 + 24 x 32 + 24 x 32 (registers are FP16) | |
// > and BF16 is strangely 24 x 24 + 24 x 32 + 24 x 32 when the registers are | |
// > FP32. Extremely low occupancy for this performance maximum. That's why the | |
// > heuristics for kernel selection logic will be tricky. | |
// > | |
// > I think I can design the logic for M1 + BF16 in a day. | |
// | |
// That is exactly what I will do to complete the shader. Compile a large | |
// volume of performance data, create truth tables, and solve the inverse | |
// problem of predicting the fastest configuration. | |
// | |
// Data sheet: | |
// | |
// https://gist.github.com/philipturner/45781f4515145106fc0d4e598dd5f13b | |
// | |
// Insights from the data: | |
// | |
// Always allocate an accumulator's worth of threadgroup memory, even if it
// will never be used. This speeds up M3 for some reason, except for a -50 to
// -100 GFLOPS regression when all operands are FP16 and the matrix dimension
// is indivisible by 8. Most importantly, the choice will simplify the code.
// This is an example of weighing the pros and cons of an optimization.
// | |
// > Gathering and interpreting data (should be) science, but deciding how to | |
// > act on that data is engineering. A general rule: choose simpler logic | |
// > that will extrapolate beyond the training data. Favor design choices | |
// > that are robust against edge cases. Even if it sacrifices performance on | |
// > benchmarks. | |
// | |
// Next, I found the primary reason for inconsistent performance with M1 + | |
// BF16. I had not finished optimizing the FP32 -> BF16 encoding when storing | |
// data to memory. The fastest path is the one that uses two separate | |
// store instructions. There is no evidence this change should be applied to | |
// precisions besides BF16. I will modify the codegen for 'store_bfloat' in | |
// 'simdgroup_matrix_storage'. | |
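//
// A hedged sketch of the 'store_bfloat' body that such codegen might emit:
// truncate each FP32 register to its upper 16 bits (the BF16 encoding) and
// issue two separate scalar stores instead of one packed 32-bit store. The
// names follow the spirit of 'simdgroup_matrix_storage', not its verbatim
// source, and truncation (rather than rounding) is an assumption here.
//
// ```swift
// let storeBF16Body = """
//   // thread_elements() holds vec<float, 2>; dst points to a BF16 (ushort) matrix.
//   vec<float, 2> registerForm = *(src->thread_elements());
//   ushort lo = ushort(as_type<uint>(registerForm[0]) >> 16);
//   ushort hi = ushort(as_type<uint>(registerForm[1]) >> 16);
//
//   uint address = uint(matrix_origin.y) * elements_per_row + matrix_origin.x;
//   dst[address] = lo;     // first store instruction
//   dst[address + 1] = hi; // second store instruction
// """
// ```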
// | |
// Finally, I need to decide whether to store the accumulator with an async
// copy.
// | |
// ``` | |
// if the matrix dimensions M and/or N are smaller than the corresponding | |
// block size | |
// use async copy | |
// else if M3+ | |
// do not use async copy | |
// else | |
// continue to next statement | |
// | |
// if the matrix is small (32 x 32 x 32 block) | |
// decouple the concern of optimizing small matrices from the concern of | |
// async stores | |
// | |
// use async copy by default | |
// else if any operand is FP32 | |
// decouple the concern of eliminating bank conflicts from the concern of | |
// async stores | |
// | |
// use async copy by default | |
// else | |
// continue to next statement | |
// | |
// assert that block size is (M = 48, N = 48, K = 32) | |
// compile four kernel variants: | |
// - store directly to device, K = 32 | |
// - store directly to device, K = 40 | |
// - store indirectly to threadgroup, K = 32 | |
// - store indirectly to threadgroup, K = 40 | |
// create a MTLComputePipelineState for each variant | |
// | |
// sort the pipelines by occupancy | |
// if two pipelines have a different occupancy | |
// the higher occupancy wins | |
// else | |
// choose the last pipeline (they're ordered by increasing performance) | |
// ``` | |
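//
// A minimal sketch of the variant-selection step at the end of the logic
// above. The caller passes the four generated shader sources in the same
// order as the list above (so ties favor the later entry), and the kernel
// name 'gemm' is a placeholder. Occupancy is approximated here by
// 'maxTotalThreadsPerThreadgroup', which the Metal compiler lowers as
// register pressure rises; treat that proxy as an assumption, not as the
// method used to gather the numbers below.
//
// ```swift
// func selectPipeline(
//   device: MTLDevice, variantSources: [String]
// ) throws -> MTLComputePipelineState {
//   var pipelines: [MTLComputePipelineState] = []
//   for source in variantSources {
//     let library = try device.makeLibrary(source: source, options: nil)
//     let function = library.makeFunction(name: "gemm")!
//     pipelines.append(try device.makeComputePipelineState(function: function))
//   }
//
//   // Highest occupancy wins; on a tie, the later (faster-ordered) variant wins.
//   var best = pipelines[0]
//   for pipeline in pipelines.dropFirst() {
//     if pipeline.maxTotalThreadsPerThreadgroup >= best.maxTotalThreadsPerThreadgroup {
//       best = pipeline
//     }
//   }
//   return best
// }
// ```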
// | |
// Data for validating correct implementation of the logic: | |
// | |
// M1 Max, BF16 | |
// | |
// K = 32 | |
// problemSize = 1488 | A B | 896 threads/core | 8360 GFLOPS | |
// problemSize = 1488 | A B^T | 1024 threads/core | 8680 GFLOPS | |
// problemSize = 1488 | A^T B | 1024 threads/core | 8790 GFLOPS | |
// problemSize = 1488 | A^T B^T | 1024 threads/core | 9036 GFLOPS | |
// problemSize = 1489 | A B | 768 threads/core | 7975 GFLOPS | |
// problemSize = 1489 | A B^T | 832 threads/core | 8086 GFLOPS | |
// problemSize = 1489 | A^T B | 768 threads/core | 8177 GFLOPS | |
// problemSize = 1489 | A^T B^T | 832 threads/core | 8454 GFLOPS | |
// | |
// K = 40 | |
// problemSize = 1488 | A B | 640 threads/core | 8183 GFLOPS | |
// problemSize = 1488 | A B^T | 704 threads/core | 8421 GFLOPS | |
// problemSize = 1488 | A^T B | 768 threads/core | 8654 GFLOPS | |
// problemSize = 1488 | A^T B^T | 768 threads/core | 8890 GFLOPS | |
// problemSize = 1489 | A B | 768 threads/core | 8018 GFLOPS | |
// problemSize = 1489 | A B^T | 832 threads/core | 8382 GFLOPS | |
// problemSize = 1489 | A^T B | 768 threads/core | 8320 GFLOPS | |
// problemSize = 1489 | A^T B^T | 832 threads/core | 8637 GFLOPS | |
// | |
// Correctness was confirmed: | |
// | |
// problemSize = 1488 | A B | 896 threads/core | 8358 GFLOPS | |
// problemSize = 1488 | A B^T | 1024 threads/core | 8674 GFLOPS | |
// problemSize = 1488 | A^T B | 1024 threads/core | 8794 GFLOPS | |
// problemSize = 1488 | A^T B^T | 1024 threads/core | 9024 GFLOPS | |
// problemSize = 1489 | A B | 768 threads/core | 8028 GFLOPS | |
// problemSize = 1489 | A B^T | 832 threads/core | 8374 GFLOPS | |
// problemSize = 1489 | A^T B | 832 threads/core | 8358 GFLOPS | |
// problemSize = 1489 | A^T B^T | 832 threads/core | 8645 GFLOPS | |
// | |
// This concludes the development of the unified GEMM kernel. The only further | |
// changes to this GitHub gist will be bug fixes and typo fixes. | |
// | |
// ========================================================================== // | |
// ========================================================================== // | |
// ========================================================================== // | |
// ## Addendum | |
// | |
// I fixed the issues with the shader caching mechanism. It now has a nearly
// 100% hit rate for the library cache. Interestingly, MTLLibrary creation and
// MTLComputePipelineState specialization each take roughly 30 ms. So the
// latency only decreased from 60 ms to 30 ms in my tests.
// | |
// Expected performance: | |
// | |
// problemSize = 511 | A B | 1024 threads/core | 6706 GFLOPS | |
// problemSize = 511 | A B^T | 896 threads/core | 5631 GFLOPS | |
// problemSize = 511 | A^T B | 896 threads/core | 5824 GFLOPS | |
// problemSize = 511 | A^T B^T | 1024 threads/core | 6991 GFLOPS | |
// problemSize = 512 | A B | 1024 threads/core | 6842 GFLOPS | |
// problemSize = 512 | A B^T | 1024 threads/core | 6938 GFLOPS | |
// problemSize = 512 | A^T B | 896 threads/core | 5933 GFLOPS | |
// problemSize = 512 | A^T B^T | 1024 threads/core | 7208 GFLOPS | |
// | |
// problemSize = 1488 | A B | 896 threads/core | 8375 GFLOPS | |
// problemSize = 1488 | A B^T | 1024 threads/core | 8685 GFLOPS | |
// problemSize = 1488 | A^T B | 1024 threads/core | 8804 GFLOPS | |
// problemSize = 1488 | A^T B^T | 1024 threads/core | 9039 GFLOPS | |
// problemSize = 1489 | A B | 768 threads/core | 8049 GFLOPS | |
// problemSize = 1489 | A B^T | 832 threads/core | 8387 GFLOPS | |
// problemSize = 1489 | A^T B | 832 threads/core | 8381 GFLOPS | |
// problemSize = 1489 | A^T B^T | 832 threads/core | 8652 GFLOPS | |
// | |
// Performance after the change: | |
// | |
// problemSize = 511 | A B | 1024 threads/core | 6784 GFLOPS | |
// problemSize = 511 | A B^T | 896 threads/core | 5651 GFLOPS | |
// problemSize = 511 | A^T B | 896 threads/core | 5818 GFLOPS | |
// problemSize = 511 | A^T B^T | 1024 threads/core | 6997 GFLOPS | |
// problemSize = 512 | A B | 1024 threads/core | 6820 GFLOPS | |
// problemSize = 512 | A B^T | 1024 threads/core | 6915 GFLOPS | |
// problemSize = 512 | A^T B | 896 threads/core | 5966 GFLOPS | |
// problemSize = 512 | A^T B^T | 1024 threads/core | 7168 GFLOPS | |
// | |
// problemSize = 1488 | A B | 896 threads/core | 8369 GFLOPS | |
// problemSize = 1488 | A B^T | 1024 threads/core | 8678 GFLOPS | |
// problemSize = 1488 | A^T B | 1024 threads/core | 8808 GFLOPS | |
// problemSize = 1488 | A^T B^T | 1024 threads/core | 9040 GFLOPS | |
// problemSize = 1489 | A B | 768 threads/core | 8048 GFLOPS | |
// problemSize = 1489 | A B^T | 832 threads/core | 8391 GFLOPS | |
// problemSize = 1489 | A^T B | 832 threads/core | 8381 GFLOPS | |
// problemSize = 1489 | A^T B^T | 832 threads/core | 8648 GFLOPS | |
// MARK: - GEMM Kernel | |
/// An enumeration of the precisions supported by the kernel. | |
/// | |
/// If you wish to support quantized precisions, copy/translate the source code
/// and integrate a modified version into your app. Something similar to a Swift
/// `enum` (e.g. C++ `enum class`) could enumerate the quantization formats
/// used by application code (see the sketch after this enum). An example set
/// could be:
/// - FP32 | |
/// - FP16 | |
/// - BF16 | |
/// - signed 8-bit integer | |
/// - s1ezm7 | |
/// - FP8 | |
/// - palletized | |
/// | |
/// If you support non-floating-point formats, you have the responsibility of
/// authoring correct and performant GPU code for them. A general rule of thumb
/// is to keep the data compressed in `device` or `threadgroup` memory, and
/// transform it into a floating-point type while loading into registers. Keep
/// the accumulator in floating point until the output needs to be written.
/// If the output is quantized, it will be compressed when writing back to
/// `device` memory (or `threadgroup` before the async copy in edge cases).
/// | |
/// For example, the reference implementation treats BF16 like a quantized | |
/// integer type on Apple7 and Apple8 GPUs. It is decompressed to FP32 in | |
/// registers. | |
enum GEMMOperandPrecision: UInt16 { | |
case FP32 = 0 | |
case FP16 = 1 | |
case BF16 = 2 | |
} | |
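// A hypothetical extension of the enumeration for an app that adds quantized
// formats, as described in the doc comment above. None of these extra cases
// are implemented by the kernel in this file; they only illustrate the shape
// of the type an application might define.
//
// ```swift
// enum AppGEMMOperandPrecision: UInt16 {
//   case FP32 = 0
//   case FP16 = 1
//   case BF16 = 2
//   case INT8 = 3       // signed 8-bit integer
//   case FP8 = 4
//   case palletized = 5 // lookup-table compressed
// }
// ```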
/// A configuration for a GEMM kernel. | |
/// | |
/// The information in this data structure is enough to uniquely identify the | |
/// kernel. It can be used as a key in a key-value cache. | |
/// | |
/// ## Usage | |
/// | |
/// The code for generating the GEMM kernel does not include any assumptions | |
/// about performance. It should only be responsible for correctly generating | |
/// a shader source, provided a configuration. The user is responsible for | |
/// choosing that configuration. | |
struct GEMMKernelDescriptor { | |
/// Required. The number of matrix elements spanned by each threadgroup. | |
/// - Parameter M: Number of output columns spanned. | |
/// - Parameter N: Number of output rows spanned. | |
/// - Parameter K: Number of loop iterations unrolled. | |
/// | |
/// Optimal values: | |
/// - Apple7 and Apple8: 48x48x24 | |
/// - Apple9 and later: 32x32x8 | |
/// | |
/// To reach optimal performance on Apple7 and Apple8, the recommended default | |
/// value needs to be modified conditionally. When all three operands have | |
/// 16-bit memory precisions, change `K` to 32. When the matrix is too small | |
/// to saturate all of the GPU cores, change all dimensions to 32x32x32. Even | |
/// smaller blocks can be exploited in low-occupancy cases, but 32x32 and | |
/// 48x48 are sufficient for general use. | |
/// | |
/// For simplicity or an out-of-the-box performance test, one can assume | |
/// occupancy is always high. But to match the performance of MPS, one must | |
/// optimize for small problem sizes on large GPUs. | |
/// | |
/// ## Choosing Block Size by Precision | |
/// | |
/// Legend: | |
/// - memA: precision for left input matrix, in memory | |
/// - memB: precision for right input matrix, in memory | |
/// - memC: precision for output matrix, in memory | |
/// - regA: precision for left input matrix, in registers | |
/// - regB: precision for right input matrix, in registers | |
/// - regC: precision for output matrix, in registers | |
/// - M1: optimal block size on Apple7 and Apple8 | |
/// - M3: optimal block size on Apple9 and later | |
/// | |
/// memA | memB | memC | regA | regB | regC | M1 | M3 | | |
/// ---- | ---- | ---- | ---- | ---- | ---- | -------- | ------- | | |
/// FP16 | FP16 | FP16 | any | any | any | 48x48x32 | 32x32x8 | | |
/// BF16 | BF16 | BF16 | any | any | any | 48x48x32 | 32x32x8 | | |
/// FP16 | FP16 | FP32 | any | any | any | 48x48x24 | 32x32x8 | | |
/// BF16 | BF16 | FP32 | any | any | any | 48x48x24 | 32x32x8 | | |
/// FP16 | FP32 | FP16 | any | any | any | 48x48x24 | 32x32x8 | | |
/// BF16 | FP32 | BF16 | any | any | any | 48x48x24 | 32x32x8 | | |
/// FP32 | FP32 | FP32 | any | any | any | 48x48x24 | 32x32x8 | | |
/// | |
/// ## Detecting Low-Occupancy Cases | |
/// | |
/// To determine whether the matrix saturates the GPU, divide the output | |
/// matrix's dimensions by 48x48. Round up to the nearest integer. Then, | |
/// multiply the number of row blocks by the number of column blocks. The | |
/// result is the number of threadgroups dispatched. For example, a C matrix | |
/// with dimensions 768x768 would dispatch 256 threadgroups. If you are | |
/// batching multiple matrix multiplications into one shader call, multiply | |
/// the number of threadgroups by the batch count. | |
/// | |
/// Next, calculate the target occupancy. Start by finding the GPU core count. | |
/// This can be accomplished in many ways; there is a heavily tested reference | |
/// implementation [here](https://github.com/philipturner/applegpuinfo). On | |
/// macOS, you can query the core count through IORegistry. On iOS, go with a | |
/// conservative (meaning more likely to overestimate) estimate of 5 cores on | |
/// A14 - A16, 10 cores on M1 - M2. | |
/// | |
/// When one of the operands is 32-bit, the target occupancy is 6 threadgroups | |
/// per core. When all three operands are 16-bit, the target increases to 9 | |
/// per core. Multiply the number of cores by the number of threadgroups per | |
/// core. If the total GPU occupancy is greater than or equal to the number of | |
/// matrix blocks, use the smaller blocking scheme. | |
/// | |
/// For example, the following decision tree would be used on an M1 Max | |
/// (32 cores). | |
/// | |
/// ``` | |
/// is device Apple9 or later? | |
/// yes: use block size 32x32x8 | |
/// no: continue decision tree [selected decision] | |
/// unsure: use block size 48x48x24-32 | |
/// | |
/// compute number of matrix blocks | |
/// 768x768 / 48x48 = 16.0 x 16.0 | |
/// round floating point (16.0 x 16.0) | |
/// to next greatest integer (16 x 16) | |
/// 16 x 16 x (batch size of 1) = 256 threadgroups | |
/// | |
/// compute target occupancies with 48x48 scheme | |
/// 32 x 6 = 192 [selected when A, B, or C is FP32] | |
/// 32 x 9 = 288 [selected when every matrix is FP16/BF16] | |
/// | |
/// prefer 32x32 when 48x48 has low occupancy | |
/// if 256 ≤ 192 | |
/// choose small block size (32x32x32xFP32) | |
/// else | |
/// choose large block size (48x48x24xFP32) [selected] | |
/// if 256 ≤ 288 | |
/// choose small block size (32x32x32xFP16) [selected] | |
/// else | |
/// choose large block size (48x48x32xFP16) | |
/// ``` | |
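///
/// A minimal Swift sketch of the same heuristic (the function and argument
/// names are illustrative, not part of the reference API):
///
/// ```
/// func shouldUseSmallBlocks(
///   outputDimensions: (M: UInt32, N: UInt32),
///   batchCount: Int,
///   coreCount: Int,
///   allOperands16Bit: Bool
/// ) -> Bool {
///   let blocksM = (Int(outputDimensions.M) + 47) / 48
///   let blocksN = (Int(outputDimensions.N) + 47) / 48
///   let threadgroups = blocksM * blocksN * batchCount
///   let targetOccupancy = coreCount * (allOperands16Bit ? 9 : 6)
///   return threadgroups <= targetOccupancy
/// }
/// ```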
var blockDimensions: (M: UInt16, N: UInt16, K: UInt16)? | |
var memoryPrecisions: ( | |
A: GEMMOperandPrecision, B: GEMMOperandPrecision, C: GEMMOperandPrecision)? | |
/// The device to create the kernel on. | |
var device: MTLDevice? | |
/// Optional. The layout of elements in threadgroup memory. | |
/// | |
/// If not specified, the default value matches the actual block dimensions. | |
/// | |
/// This property can be used to avoid bank conflicts. For example, if one
/// operand will have 16 FP32 elements per row, there is a good chance of
/// increased bank conflicts on M1. One may pad that threadgroup memory | |
/// allocation to 20 FP32 elements per row. | |
/// | |
/// Note that the assignment of M/N/K to row dimensions varies based on which | |
/// operand is discussed, and what its transpose state is. | |
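///
/// A sketch of padding, assuming a `GEMMKernelDescriptor` named `kernelDesc`
/// and a hypothetical 48x48x16 block where the row-major `A` block has 16
/// FP32 elements per row (illustrative numbers, not a recommended
/// configuration):
///
/// ```
/// kernelDesc.blockDimensions = (M: 48, N: 48, K: 16)
/// kernelDesc.paddedBlockDimensions = (
///   A: (M: 48, K: 20), // pad rows of A from 16 to 20 FP32 elements
///   B: (K: 16, N: 48),
///   C: (M: 48, N: 48))
/// ```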
var paddedBlockDimensions: ( | |
A: (M: UInt16, K: UInt16), | |
B: (K: UInt16, N: UInt16), | |
C: (M: UInt16, N: UInt16))? | |
/// Required. Whether async copies will improve performance during the | |
/// matrix multiplication loop. | |
/// | |
/// The default value is `true`. Async copies improve performance on Apple7 | |
/// and Apple8, but harm performance on Apple9 and later. However, they are | |
/// essential for correctness when reading from the edges of unaligned | |
/// matrices. Setting the value to `false` means skipping async copies when | |
/// doing so will not change the final result. | |
var preferAsyncLoad: Bool = true | |
/// Required. Whether async copies will improve performance when storing the | |
/// accumulator to main memory. | |
/// | |
/// There is no default value that will reliably yield consistent performance. | |
var preferAsyncStore: Bool? | |
/// Set the register precision based on the GPU architecture, and your choice | |
/// for memory precision. The following set of logic statements should provide | |
/// optimal performance for all permutations of operand precisions. | |
/// | |
/// ``` | |
/// regA is identical to memA | |
/// regB is identical to memB | |
/// If memA, memB, and memC are FP16, | |
/// regC is FP16 | |
/// else | |
/// regC is FP32 | |
/// | |
/// If earlier than M3 | |
/// If memA is BF16, | |
/// regA is FP32 | |
/// If memB is BF16, | |
/// regB is FP32 | |
/// ``` | |
var registerPrecisions: ( | |
A: GEMMOperandPrecision, B: GEMMOperandPrecision, C: GEMMOperandPrecision)? | |
/// Required. The array of SIMDs to divide the threadgroup into. | |
/// | |
/// Optimal values: | |
/// - Apple7 and Apple8: 2x2 | |
/// - Apple9 and later: 1x1 | |
var splits: (M: UInt16, N: UInt16)? | |
/// Required. Whether each of the inputs deviates from row-major order. | |
var transposeState: (A: Bool, B: Bool)? | |
} | |
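// A minimal sketch of manually filling out the descriptor for an Apple7/8
// GPU with FP32 operands. The values follow the guidance in the doc comments
// above; 'preferAsyncStore = true' is an assumption, not a universally
// optimal choice.
//
//   var kernelDesc = GEMMKernelDescriptor()
//   kernelDesc.blockDimensions = (M: 48, N: 48, K: 24)
//   kernelDesc.device = MTLCreateSystemDefaultDevice()
//   kernelDesc.memoryPrecisions = (A: .FP32, B: .FP32, C: .FP32)
//   kernelDesc.preferAsyncStore = true
//   kernelDesc.registerPrecisions = (A: .FP32, B: .FP32, C: .FP32)
//   kernelDesc.splits = (M: 2, N: 2)
//   kernelDesc.transposeState = (A: false, B: false)
//   let kernel = GEMMKernel(descriptor: kernelDesc)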
struct GEMMKernelKey: Equatable, Hashable { | |
var blockDimensions: SIMD3<UInt16> | |
var memoryPrecisions: SIMD3<UInt16> | |
var paddedBlockDimensions: SIMD8<UInt16> | |
var preferAsyncLoad: UInt8 | |
var preferAsyncStore: UInt8 | |
var registerPrecisions: SIMD3<UInt16> | |
var splits: SIMD2<UInt16> | |
var transposeState: SIMD2<UInt8> | |
init(copying source: GEMMKernelDescriptor) { | |
blockDimensions = Self.createBlockDimensions(source.blockDimensions) | |
memoryPrecisions = Self.createPrecisions(source.memoryPrecisions) | |
paddedBlockDimensions = SIMD8(repeating: .max) | |
if let (A, B, C) = source.paddedBlockDimensions { | |
paddedBlockDimensions[0] = A.0 | |
paddedBlockDimensions[1] = A.1 | |
paddedBlockDimensions[2] = B.0 | |
paddedBlockDimensions[3] = B.1 | |
paddedBlockDimensions[4] = C.0 | |
paddedBlockDimensions[5] = C.1 | |
} | |
preferAsyncLoad = Self.createBoolean(source.preferAsyncLoad) | |
preferAsyncStore = Self.createBoolean(source.preferAsyncStore) | |
registerPrecisions = Self.createPrecisions(source.registerPrecisions) | |
splits = SIMD2(repeating: .max) | |
if let (M, N) = source.splits { | |
splits[0] = M | |
splits[1] = N | |
} | |
transposeState = Self.createTransposeState(source.transposeState) | |
} | |
@_transparent // performance in -Ounchecked | |
static func createBlockDimensions( | |
_ input: (UInt16, UInt16, UInt16)? | |
) -> SIMD3<UInt16> { | |
if let input { | |
return SIMD3(input.0, input.1, input.2) | |
} else { | |
return SIMD3(repeating: .max) | |
} | |
} | |
@_transparent // performance in -Ounchecked | |
static func createBoolean( | |
_ input: Bool? | |
) -> UInt8 { | |
if let input { | |
return input ? 1 : 0 | |
} else { | |
return UInt8.max | |
} | |
} | |
@_transparent // performance in -Ounchecked | |
static func createPrecisions( | |
_ input: ( | |
GEMMOperandPrecision, GEMMOperandPrecision, GEMMOperandPrecision)? | |
) -> SIMD3<UInt16> { | |
if let input { | |
return SIMD3(input.0.rawValue, input.1.rawValue, input.2.rawValue) | |
} else { | |
return SIMD3(repeating: .max) | |
} | |
} | |
@_transparent // performance in -Ounchecked | |
static func createTransposeState( | |
_ input: (Bool, Bool)? | |
) -> SIMD2<UInt8> { | |
if let input { | |
return SIMD2(input.0 ? 1 : 0, | |
input.1 ? 1 : 0) | |
} else { | |
return SIMD2(repeating: .max) | |
} | |
} | |
} | |
extension GEMMKernelDescriptor: Hashable, Equatable { | |
static func == (lhs: GEMMKernelDescriptor, rhs: GEMMKernelDescriptor) -> Bool { | |
let lhsKey = GEMMKernelKey(copying: lhs) | |
let rhsKey = GEMMKernelKey(copying: rhs) | |
return lhsKey == rhsKey | |
} | |
func hash(into hasher: inout Hasher) { | |
let key = GEMMKernelKey(copying: self) | |
hasher.combine(key) | |
} | |
} | |
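// Because the descriptor is Hashable, it can serve directly as the key in a
// shader cache. A minimal sketch (illustrative, not part of the reference
// implementation):
//
//   var kernelCache: [GEMMKernelDescriptor: GEMMKernel] = [:]
//
//   func cachedKernel(for descriptor: GEMMKernelDescriptor) -> GEMMKernel {
//     if let kernel = kernelCache[descriptor] {
//       return kernel
//     }
//     let kernel = GEMMKernel(descriptor: descriptor)
//     kernelCache[descriptor] = kernel
//     return kernel
//   }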
struct GEMMKernel { | |
var library: MTLLibrary | |
var source: String = "" | |
// A copy of the block dimensions from the descriptor. | |
var blockDimensions: (M: UInt16, N: UInt16, K: UInt16) | |
// Allocating threadgroup memory at encode time (after compiling the kernel),
// rather than reserving it statically inside the shader, yields higher
// performance.
var threadgroupMemoryAllocation: UInt16 | |
// The number of threads per group. | |
var threadgroupSize: UInt16 | |
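// A minimal sketch of how a client might encode a compiled kernel. The
// 'pipeline', 'encoder', buffer, and M/N names are assumptions; creating the
// pipeline with the M/N/K function constants is omitted. Threadgroup memory
// is set at encode time from the value computed below.
//
//   encoder.setComputePipelineState(pipeline)
//   encoder.setThreadgroupMemoryLength(
//     Int(kernel.threadgroupMemoryAllocation), index: 0)
//   encoder.setBuffer(bufferA, offset: 0, index: 0)
//   encoder.setBuffer(bufferB, offset: 0, index: 1)
//   encoder.setBuffer(bufferC, offset: 0, index: 2)
//   let groupsM = (Int(M) + Int(kernel.blockDimensions.M) - 1)
//     / Int(kernel.blockDimensions.M)
//   let groupsN = (Int(N) + Int(kernel.blockDimensions.N) - 1)
//     / Int(kernel.blockDimensions.N)
//   encoder.dispatchThreadgroups(
//     MTLSize(width: groupsN, height: groupsM, depth: 1),
//     threadsPerThreadgroup: MTLSize(
//       width: Int(kernel.threadgroupSize), height: 1, depth: 1))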
init(descriptor: GEMMKernelDescriptor) { | |
guard let blockDimensions = descriptor.blockDimensions, | |
let device = descriptor.device, | |
let memoryPrecisions = descriptor.memoryPrecisions, | |
let preferAsyncStore = descriptor.preferAsyncStore, | |
let registerPrecisions = descriptor.registerPrecisions, | |
let splits = descriptor.splits, | |
let transposeState = descriptor.transposeState else { | |
fatalError("Descriptor was incomplete: \(descriptor)") | |
} | |
self.blockDimensions = blockDimensions | |
self.threadgroupSize = 32 * splits.M * splits.N | |
// Validate the correctness of register precisions. | |
func checkOperandPair( | |
memory: GEMMOperandPrecision, | |
register: GEMMOperandPrecision | |
) -> Bool { | |
// Truth table: | |
// | |
// memory | register | valid | | |
// ------ | -------- | ----- | | |
// FP32 | FP32 | yes | | |
// FP32 | FP16 | no | | |
// FP32 | BF16 | no | | |
// FP16 | FP32 | yes | | |
// FP16 | FP16 | yes | | |
// FP16 | BF16 | no | | |
// BF16 | FP32 | yes | | |
// BF16 | FP16 | no | | |
// BF16 | BF16 | yes | | |
// | |
// Optimized form of the logic: | |
// | |
// If the register precision matches the memory precision, | |
// return true | |
// If the register precision equals FP32, | |
// return true | |
// Otherwise, | |
// return false | |
// | |
// The logic statements will change if you introduce custom quantized | |
// formats. The truth table will grow exponentially. You'll need to add | |
// more restrictions on accepted pairs to overcome the combinatorial | |
// explosion. | |
if register == memory { | |
return true | |
} else if register == .FP32 { | |
return true | |
} else { | |
return false | |
} | |
} | |
guard checkOperandPair( | |
memory: memoryPrecisions.A, register: registerPrecisions.A) else { | |
fatalError("Operand A had an invalid register precision.") | |
} | |
guard checkOperandPair( | |
memory: memoryPrecisions.B, register: registerPrecisions.B) else { | |
fatalError("Operand B had an invalid register precision.") | |
} | |
guard checkOperandPair( | |
memory: memoryPrecisions.C, register: registerPrecisions.C) else { | |
fatalError("Operand C had an invalid register precision.") | |
} | |
if registerPrecisions.C == .BF16 { | |
// BF16 has too few mantissa bits to be an accurate accumulator. In | |
// addition, switching from FP32 accumulator to BF16 accumulator slows | |
// down execution speed on both M1/M2 and M3+. | |
fatalError("BF16 cannot be used as the register precision for C.") | |
} | |
// Inject the contents of the headers. | |
source += """ | |
\(createMetalSimdgroupEvent()) | |
\(createMetalSimdgroupMatrixStorage()) | |
using namespace metal; | |
""" | |
// Declare the size of M and N within a register allocation. | |
let registerM: UInt16 = blockDimensions.M / splits.M | |
let registerN: UInt16 = blockDimensions.N / splits.N | |
// Retrieve the "padded" block dimensions if they were specified; otherwise,
// fall back to the true block dimensions.
var paddedBlockDimensionsA: (M: UInt16, K: UInt16) | |
var paddedBlockDimensionsB: (K: UInt16, N: UInt16) | |
var paddedBlockDimensionsC: (M: UInt16, N: UInt16) | |
if let paddedBlockDimensions = descriptor.paddedBlockDimensions { | |
paddedBlockDimensionsA = paddedBlockDimensions.A | |
paddedBlockDimensionsB = paddedBlockDimensions.B | |
paddedBlockDimensionsC = paddedBlockDimensions.C | |
} else { | |
paddedBlockDimensionsA = (blockDimensions.M, blockDimensions.K) | |
paddedBlockDimensionsB = (blockDimensions.K, blockDimensions.N) | |
paddedBlockDimensionsC = (blockDimensions.M, blockDimensions.N) | |
} | |
// Determine the leading dimensions from the transpose state.
var leadingDimensionA: String | |
var leadingDimensionB: String | |
var leadingBlockDimensionA: UInt16 | |
var leadingBlockDimensionB: UInt16 | |
if transposeState.A { | |
leadingDimensionA = "M" | |
leadingBlockDimensionA = paddedBlockDimensionsA.M | |
} else { | |
leadingDimensionA = "K" | |
leadingBlockDimensionA = paddedBlockDimensionsA.K | |
} | |
if transposeState.B { | |
leadingDimensionB = "K" | |
leadingBlockDimensionB = paddedBlockDimensionsB.K | |
} else { | |
leadingDimensionB = "N" | |
leadingBlockDimensionB = paddedBlockDimensionsB.N | |
} | |
// Add the function constants. | |
do { | |
source += """ | |
// Dimensions of each matrix. | |
// - Limitations to matrix size: | |
// - 2^32 in each dimension (M/N/K). | |
// - Extending to 2^64 may require changing 'uint' to 'ulong'. There is a | |
// good chance this will significantly degrade performance, and require | |
// changing the data type of several variables that process addresses. The | |
// client is responsible for ensuring correctness and performance with | |
// matrices spanning several billion elements in one direction. | |
// - The matrix dimensions must be known at compile time, via function | |
// constants. Dynamic matrix shapes are beyond the scope of this reference | |
// implementation. Dynamic shapes cause a non-negligible regression to | |
// shader execution speed. However, they could minimize a compilation | |
// latency bottleneck in some use cases. | |
// - Limitations to batch size: | |
// - Dictated by how the client modifies the code to implement batching. | |
// - Dynamic batch shapes would likely not harm performance much. For example, | |
// someone could enter an array of pointers/memory offsets to different | |
// matrices in the batch. Each slice of a 3D thread grid could read a | |
// different pointer from memory, and use that pointer as the A/B/C matrix. | |
// Another approach is to restrict the input format, so all matrices are | |
// stored contiguously in memory. Then, the memory offset could be computed | |
// analytically from matrix size and the Z dimension in a 3D thread grid. | |
// | |
// Another note: | |
// - The rows of the matrix must be contiguous in memory. Supporting strides | |
// that differ from the actual matrix dimensions should not be difficult, but | |
// it is out of scope for this reference kernel. | |
constant uint M [[function_constant(0)]]; | |
constant uint N [[function_constant(1)]]; | |
constant uint K [[function_constant(2)]]; | |
// Whether each matrix is transposed. | |
constant bool A_trans = \(transposeState.A); | |
constant bool B_trans = \(transposeState.B); | |
// Define the memory layout of the matrix block. | |
constant ushort M_group = \(blockDimensions.M); | |
constant ushort N_group = \(blockDimensions.N); | |
constant ushort K_group = \(blockDimensions.K); | |
// Thresholds that mark the matrix edge. | |
constant uint M_edge = M - (M % M_group); | |
constant uint N_edge = N - (N % N_group); | |
// Find the number of elements in the final block. If the matrix | |
// dimensions are perfectly divisible by block dimensions, we don't want
// this value to be zero. The final block is a full block. | |
constant ushort M_remainder = (M % \(registerM) == 0) | |
? \(registerM) : M % \(registerM); | |
constant ushort N_remainder = (N % \(registerN) == 0) | |
? \(registerN) : N % \(registerN); | |
constant ushort K_remainder = (K % K_group == 0) | |
? K_group : K % K_group; | |
constant ushort K_remainder_padded = (K_remainder + 7) / 8 * 8; | |
// Shift the final block, so it doesn't access out-of-bounds memory. | |
constant ushort M_shift = (M < M_group) ? 0 : \(registerM) - M_remainder; | |
constant ushort N_shift = (N < N_group) ? 0 : \(registerN) - N_remainder; | |
""" | |
} | |
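// Worked example of the edge constants above (illustrative): with M = 1489,
// M_group = 48, and 24 rows of C per SIMD, M_edge = 1489 - (1489 % 48) = 1488,
// M_remainder = 1489 % 24 = 1, and M_shift = 24 - 1 = 23. Threadgroups at or
// past M_edge shift their row offset up by 23, so their accesses stay within
// the 1489-row matrix.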
// Allocate threadgroup memory, using the 'memory precision'. This memory | |
// is allocated at runtime, either by the user (explicit API call) or by | |
// the driver (behind the scenes). | |
func createPrecisionSize(_ precision: GEMMOperandPrecision) -> UInt16 { | |
// NOTE: Exotic precisions like some LLaMA quantization formats and ezm8 | |
// have the exponent deinterleaved from the mantissa. Such precisions | |
// would require careful consideration of the meaning of per-scalar | |
// memory footprint. | |
switch precision { | |
case .FP32: return 4 | |
case .FP16: return 2 | |
case .BF16: return 2 | |
} | |
} | |
// Allocate thread memory, using the 'register precision'. This memory | |
// is allocated by embedding the precision into the assembly code. | |
func createPrecisionName(_ precision: GEMMOperandPrecision) -> String { | |
// Exotic precisions would not require any special handling here. Good | |
// practices dictate that you decode to floating point while filling | |
// up the registers. Therefore, the registers will always be floating | |
// point. | |
switch precision { | |
case .FP32: return "float" | |
case .FP16: return "half" | |
case .BF16: return "bfloat" | |
} | |
} | |
// Determine the names of the operands. | |
let memoryNameA = createPrecisionName(memoryPrecisions.A) | |
let memoryNameB = createPrecisionName(memoryPrecisions.B) | |
let memoryNameC = createPrecisionName(memoryPrecisions.C) | |
let registerNameA = createPrecisionName(registerPrecisions.A) | |
let registerNameB = createPrecisionName(registerPrecisions.B) | |
let registerNameC = createPrecisionName(registerPrecisions.C) | |
// Add the utility functions. | |
source += """ | |
// The layout of threads within a SIMD matrix. | |
// | |
// 0 0 1 1 8 8 9 9 | |
// 2 2 3 3 10 10 11 11 | |
// 4 4 5 5 12 12 13 13 | |
// 6 6 7 7 14 14 15 15 | |
// 16 16 17 17 24 24 25 25 | |
// 18 18 19 19 26 26 27 27 | |
// 20 20 21 21 28 28 29 29 | |
// 22 22 23 23 30 30 31 31 | |
// | |
// This is Morton order, a method for coalescing data accesses. It is used | |
// in a variety of contexts, from ray tracing acceleration structures, to | |
// nodal-point Laplacians, to sorting large lattices of atoms. | |
// | |
// Source: https://patents.google.com/patent/US11256518B2 | |
METAL_FUNC ushort2 morton_order(ushort thread_index_in_simdgroup) { | |
ushort lane_id = thread_index_in_simdgroup; | |
ushort quad_id = lane_id / 4; | |
constexpr ushort QUADRANT_SPAN_M = 4; | |
constexpr ushort THREADS_PER_QUADRANT = 8; | |
ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M; | |
ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2); | |
ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant; | |
ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4 | |
ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2 | |
ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant; | |
return ushort2(N_in_simd, M_in_simd); | |
} | |
// Indexes into an array of registers. | |
// | |
// Calls to this function are expected to be evaluated at compile time. The | |
// array indices transform into register offsets, which are embedded into the | |
// assembly code. | |
template <typename T> | |
METAL_FUNC thread simdgroup_matrix_storage<T>* get_sram( | |
thread simdgroup_matrix_storage<T> *sram, | |
ushort sram_leading_dim, | |
ushort2 matrix_origin | |
) { | |
return sram + (matrix_origin.y / 8) * (sram_leading_dim / 8) + (matrix_origin.x / 8); | |
} | |
""" | |
struct MultiplyDescriptor { | |
var addressSpace: String? | |
var leadingDimensionA: String? | |
var leadingDimensionB: String? | |
var loadFunctionA: String? | |
var loadFunctionB: String? | |
} | |
func createMultiply(descriptor: MultiplyDescriptor) -> String { | |
guard let addressSpace = descriptor.addressSpace, | |
let leadingDimensionA = descriptor.leadingDimensionA, | |
let leadingDimensionB = descriptor.leadingDimensionB, | |
let loadFunctionA = descriptor.loadFunctionA, | |
let loadFunctionB = descriptor.loadFunctionB else { | |
fatalError("Descriptor was incomplete.") | |
} | |
return """ | |
// One multiply-accumulate loop iteration, or 8 dot products. | |
METAL_FUNC void multiply_accumulate( | |
const \(addressSpace) \(memoryNameA) *A_src, | |
const \(addressSpace) \(memoryNameB) *B_src, | |
thread simdgroup_matrix_storage<\(registerNameA)> *A_sram, | |
thread simdgroup_matrix_storage<\(registerNameB)> *B_sram, | |
thread simdgroup_matrix_storage<\(registerNameC)> *C_sram, | |
ushort k | |
) { | |
#pragma clang loop unroll(full) | |
for (ushort m = 0; m < \(registerM); m += 8) { | |
ushort2 origin(0, m); | |
auto A = get_sram(A_sram, 8, origin); | |
A->\(loadFunctionA)(A_src, \(leadingDimensionA), ushort2(k, m), A_trans); | |
} | |
#pragma clang loop unroll(full) | |
for (ushort n = 0; n < \(registerN); n += 8) { | |
ushort2 origin(n, 0); | |
auto B = get_sram(B_sram, \(registerN), origin); | |
B->\(loadFunctionB)(B_src, \(leadingDimensionB), ushort2(n, k), B_trans); | |
} | |
#pragma clang loop unroll(full) | |
for (ushort m = 0; m < \(registerM); m += 8) { | |
#pragma clang loop unroll(full) | |
for (ushort n = 0; n < \(registerN); n += 8) { | |
auto A = get_sram(A_sram, 8, ushort2(0, m)); | |
auto B = get_sram(B_sram, \(registerN), ushort2(n, 0)); | |
auto C = get_sram(C_sram, \(registerN), ushort2(n, m)); | |
C->multiply(*A, *B); | |
} | |
} | |
} | |
""" | |
} | |
// Add the utility functions for the multiply-accumulate inner loop. | |
do { | |
var multiplyDesc = MultiplyDescriptor() | |
if memoryPrecisions.A == .BF16, registerPrecisions.A == .FP32 { | |
multiplyDesc.loadFunctionA = "load_bfloat" | |
} else { | |
multiplyDesc.loadFunctionA = "load" | |
} | |
if memoryPrecisions.B == .BF16, registerPrecisions.B == .FP32 { | |
multiplyDesc.loadFunctionB = "load_bfloat" | |
} else { | |
multiplyDesc.loadFunctionB = "load" | |
} | |
multiplyDesc.addressSpace = "device" | |
multiplyDesc.leadingDimensionA = leadingDimensionA | |
multiplyDesc.leadingDimensionB = leadingDimensionB | |
source += createMultiply(descriptor: multiplyDesc) | |
multiplyDesc.addressSpace = "threadgroup" | |
multiplyDesc.leadingDimensionA = "\(leadingBlockDimensionA)" | |
multiplyDesc.leadingDimensionB = "\(leadingBlockDimensionB)" | |
source += createMultiply(descriptor: multiplyDesc) | |
} | |
// Add the setup portion where the addresses are prepared. | |
do { | |
var blockBytesA = paddedBlockDimensionsA.M * paddedBlockDimensionsA.K | |
var blockBytesB = paddedBlockDimensionsB.K * paddedBlockDimensionsB.N | |
var blockBytesC = paddedBlockDimensionsC.M * paddedBlockDimensionsC.N | |
blockBytesA *= createPrecisionSize(memoryPrecisions.A) | |
blockBytesB *= createPrecisionSize(memoryPrecisions.B) | |
blockBytesC *= createPrecisionSize(memoryPrecisions.C) | |
threadgroupMemoryAllocation = max(blockBytesA + blockBytesB, blockBytesC) | |
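// Worked example (illustrative): with the default 48x48x24 blocks and FP32
// operands, blockBytesA = 48 * 24 * 4 = 4608, blockBytesB = 24 * 48 * 4 =
// 4608, and blockBytesC = 48 * 48 * 4 = 9216. The allocation is
// max(4608 + 4608, 9216) = 9216 bytes (9 KB). With 48x48x32 blocks and
// 16-bit operands, the same arithmetic yields max(6144, 4608) = 6144 bytes
// (6 KB).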
source += """ | |
// Metal function arguments. | |
// | |
// A: the left-hand side matrix | |
// - dimensions: M x K | |
// K x M (transposed) | |
// - memory precision: memA | |
// - register precision: regA | |
// | |
// B: the right-hand side matrix | |
// - dimensions: K x N | |
// N x K (transposed) | |
// - memory precision: memB | |
// - register precision: regB | |
// | |
// C: the output matrix, alternatively the dot product accumulator | |
// - dimensions: M x N | |
// - memory precision: memC | |
// - register precision: regC | |
// | |
// threadgroup_block: the chunk of threadgroup memory allocated at runtime | |
// - ideally 10 KB or less | |
// - precision: void/8-bit integer to make the pointer arithmetic more legible | |
kernel void gemm(device \(memoryNameA) *A [[buffer(0)]], | |
device \(memoryNameB) *B [[buffer(1)]], | |
device \(memoryNameC) *C [[buffer(2)]], | |
threadgroup uchar *threadgroup_block [[threadgroup(0)]], | |
uint3 gid [[threadgroup_position_in_grid]], | |
ushort sidx [[simdgroup_index_in_threadgroup]], | |
ushort lane_id [[thread_index_in_simdgroup]]) | |
{ | |
auto A_block = (threadgroup \(memoryNameA)*)(threadgroup_block); | |
auto B_block = (threadgroup \(memoryNameB)*)(threadgroup_block + \(blockBytesA)); | |
ushort2 sid(sidx % \(splits.N), sidx / \(splits.N)); | |
ushort2 morton_offset = morton_order(lane_id); | |
// Return early if the SIMD is out of bounds. | |
// | |
// There could be some threadgroups where the matrix edge cuts straight | |
// through the middle of the block. SIMDs on the right or bottom of the | |
// dividing line must be stopped from causing out-of-bounds accesses. This is | |
// the reason for the early exit. | |
uint M_offset = gid.y * M_group; | |
uint N_offset = gid.x * N_group; | |
{ | |
if (M_offset + sid.y * \(registerM) >= M || | |
N_offset + sid.x * \(registerN) >= N) { | |
return; | |
} | |
} | |
ushort2 offset_in_group(sid.x * \(registerN) + morton_offset.x, | |
sid.y * \(registerM) + morton_offset.y); | |
// Shift the matrix block within bounds, if possible. | |
if ((M_shift != 0) && (gid.y * M_group >= M_edge)) { | |
M_offset -= M_shift; | |
} | |
if ((N_shift != 0) && (gid.x * N_group >= N_edge)) { | |
N_offset -= N_shift; | |
} | |
""" | |
} | |
// Add the setup of the accumulator. | |
do { | |
let arrayElementsC: UInt16 = (registerM / 8) * (registerN / 8) | |
source += """ | |
simdgroup_matrix_storage<\(registerNameC)> C_sram[\(arrayElementsC)]; | |
// Initialize the accumulator. | |
#pragma clang loop unroll(full) | |
for (ushort m = 0; m < \(registerM); m += 8) { | |
#pragma clang loop unroll(full) | |
for (ushort n = 0; n < \(registerN); n += 8) { | |
ushort2 origin(n, m); | |
auto C = get_sram(C_sram, \(registerN), origin); | |
*C = simdgroup_matrix_storage<\(registerNameC)>(0); | |
} | |
} | |
""" | |
} | |
// Add the matrix multiplication iterations. | |
// | |
// Async copies are required for correct behavior in edge cases. We attempt | |
// to execute most iterations without async copy, and only the necessary | |
// ones with async copy. | |
do { | |
var asyncIterationsStart: String | |
if descriptor.preferAsyncLoad { | |
asyncIterationsStart = "0" | |
} else { | |
asyncIterationsStart = "(K - (K % K_group))" | |
} | |
let paddedCeilingK = "(K + K_remainder_padded - K_remainder)" | |
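// Worked example of the loop split (illustrative): with K = 1489,
// K_group = 24, and preferAsyncLoad = false, asyncIterationsStart becomes
// K - (K % K_group) = 1488. Iterations k = 0, 8, ..., 1480 read directly
// from device memory; the final partial block (K_remainder = 1, padded to 8)
// runs through the async-copy path so the matrix edge is read safely.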
source += """ | |
// Perform the iterations where async copy is avoided. | |
for (uint k = 0; k < \(asyncIterationsStart); k += 8) { | |
uint2 A_offset(k, M_offset); | |
uint2 B_offset(N_offset, k); | |
A_offset += uint2(morton_offset.x, offset_in_group.y); | |
B_offset += uint2(offset_in_group.x, morton_offset.y); | |
auto A_src = simdgroup_matrix_storage<\(memoryNameA)>::apply_offset( | |
A, \(leadingDimensionA), A_offset, A_trans); | |
auto B_src = simdgroup_matrix_storage<\(memoryNameB)>::apply_offset( | |
B, \(leadingDimensionB), B_offset, B_trans); | |
simdgroup_matrix_storage<\(registerNameA)> A_sram[\(registerM / 8) * (8 / 8)]; | |
simdgroup_matrix_storage<\(registerNameB)> B_sram[(8 / 8) * \(registerN / 8)]; | |
multiply_accumulate(A_src, B_src, | |
A_sram, B_sram, C_sram, 0); | |
} | |
// Perform the iterations where async copy is used. | |
for (uint k = \(asyncIterationsStart); k < K; k += K_group) { | |
// Launch an async copy from device to threadgroup memory. | |
if (sidx == 0) { | |
uint2 A_offset(k, M_offset); | |
uint2 B_offset(N_offset, k); | |
auto A_src = simdgroup_matrix_storage<\(memoryNameA)>::apply_offset( | |
A, \(leadingDimensionA), A_offset, A_trans); | |
auto B_src = simdgroup_matrix_storage<\(memoryNameB)>::apply_offset( | |
B, \(leadingDimensionB), B_offset, B_trans); | |
ushort M_tile_dimension = min(uint(M_group), M - M_offset); | |
ushort N_tile_dimension = min(uint(N_group), N - N_offset); | |
ushort K_tile_dimension = min(uint(K_group), K - k); | |
ushort K_tile_padded = min(uint(K_group), \(paddedCeilingK) - k); | |
ushort2 A_tile_src(K_tile_dimension, M_tile_dimension); | |
ushort2 B_tile_src(N_tile_dimension, K_tile_dimension); | |
ushort2 A_tile_dst(K_tile_padded, M_tile_dimension); | |
ushort2 B_tile_dst(N_tile_dimension, K_tile_padded); | |
simdgroup_event events[2]; | |
events[0].async_copy(A_block, \(leadingBlockDimensionA), A_tile_dst, | |
A_src, \(leadingDimensionA), A_tile_src, A_trans); | |
events[1].async_copy(B_block, \(leadingBlockDimensionB), B_tile_dst, | |
B_src, \(leadingDimensionB), B_tile_src, B_trans); | |
simdgroup_event::wait(2, events); | |
} | |
threadgroup_barrier(mem_flags::mem_threadgroup); | |
ushort2 A_block_offset(morton_offset.x, offset_in_group.y); | |
ushort2 B_block_offset(offset_in_group.x, morton_offset.y); | |
auto A_block_src = simdgroup_matrix_storage<\(memoryNameA)>::apply_offset( | |
A_block, \(leadingBlockDimensionA), A_block_offset, A_trans); | |
auto B_block_src = simdgroup_matrix_storage<\(memoryNameB)>::apply_offset( | |
B_block, \(leadingBlockDimensionB), B_block_offset, B_trans); | |
simdgroup_matrix_storage<\(registerNameA)> A_sram[\(registerM / 8) * (K_group / 8)]; | |
simdgroup_matrix_storage<\(registerNameB)> B_sram[(K_group / 8) * \(registerN / 8)]; | |
#pragma clang loop unroll(full) | |
for (ushort k = 0; k < K_remainder_padded; k += 8) { | |
multiply_accumulate(A_block_src, B_block_src, | |
A_sram, B_sram, C_sram, k); | |
} | |
// Will there be any iterations after this one? | |
if (k + K_group < K) { | |
// If so, we haven't reached the edge of either input matrix yet. | |
#pragma clang loop unroll(full) | |
for (ushort k = K_remainder_padded; k < K_group; k += 8) { | |
multiply_accumulate(A_block_src, B_block_src, | |
A_sram, B_sram, C_sram, k); | |
} | |
threadgroup_barrier(mem_flags::mem_threadgroup); | |
} | |
} | |
""" | |
} | |
// Add the cleanup portion where the accumulator is stored. | |
do { | |
var storeFunctionC: String | |
if memoryPrecisions.C == .BF16, | |
registerPrecisions.C == .FP32 { | |
storeFunctionC = "store_bfloat" | |
} else { | |
storeFunctionC = "store" | |
} | |
var condition: String | |
if preferAsyncStore { | |
condition = "false" | |
} else { | |
condition = "(M >= M_group) && (N >= N_group)" | |
} | |
source += """ | |
if (\(condition)) { | |
// Fast path for matrices that qualify. | |
uint2 C_offset(N_offset + offset_in_group.x, | |
M_offset + offset_in_group.y); | |
auto C_dst = simdgroup_matrix_storage<\(memoryNameC)>::apply_offset( | |
C, N, C_offset); | |
// Write the accumulator to device memory. | |
#pragma clang loop unroll(full) | |
for (ushort m = 0; m < \(registerM); m += 8) { | |
#pragma clang loop unroll(full) | |
for (ushort n = 0; n < \(registerN); n += 8) { | |
ushort2 origin(n, m); | |
auto C = get_sram(C_sram, \(registerN), origin); | |
C->\(storeFunctionC)(C_dst, N, origin); | |
} | |
} | |
} else { | |
// Slow path for when memory must be handled more carefully. | |
auto C_block = (threadgroup \(memoryNameC)*)(threadgroup_block); | |
auto C_block_dst = simdgroup_matrix_storage<\(memoryNameC)>::apply_offset( | |
C_block, N_group, offset_in_group); | |
threadgroup_barrier(mem_flags::mem_threadgroup); | |
// Write the accumulator to threadgroup memory. | |
#pragma clang loop unroll(full) | |
for (ushort m = 0; m < \(registerM); m += 8) { | |
#pragma clang loop unroll(full) | |
for (ushort n = 0; n < \(registerN); n += 8) { | |
ushort2 origin(n, m); | |
auto C = get_sram(C_sram, \(registerN), origin); | |
C->\(storeFunctionC)(C_block_dst, N_group, origin); | |
} | |
} | |
threadgroup_barrier(mem_flags::mem_threadgroup); | |
// Launch the async copy from threadgroup to device memory. | |
if (sidx == 0) { | |
uint2 C_offset(gid.x * N_group, gid.y * M_group); | |
ushort2 C_tile(min(uint(N_group), N - C_offset.x), | |
min(uint(M_group), M - C_offset.y)); | |
auto C_dst = simdgroup_matrix_storage<\(memoryNameC)>::apply_offset( | |
C, N, C_offset); | |
// If we shift successfully, the garbage zone moves from the bottom right | |
// to the top left. | |
if ((M_shift != 0) || (N_shift != 0)) { | |
ushort2 C_block_shift(0, 0); | |
if ((M_shift != 0) && (C_offset.y >= M_edge)) { | |
C_block_shift.y = M_shift; | |
} | |
if ((N_shift != 0) && (C_offset.x >= N_edge)) { | |
C_block_shift.x = N_shift; | |
} | |
C_block = simdgroup_matrix_storage<\(memoryNameC)>::apply_offset( | |
C_block, N_group, C_block_shift); | |
} | |
simdgroup_event event; | |
event.async_copy(C_dst, N, C_tile, C_block, N_group, C_tile); | |
} | |
} | |
""" | |
} | |
// Add the final closing brace of the Metal function. | |
source += "}" + "\n" | |
// Compile the shader source. | |
library = try! device.makeLibrary(source: source, options: nil) | |
} | |
} | |
// MARK: - Header Sources | |
/// Create the source code for the 'metal\_simdgroup\_event' header. | |
/// | |
/// I may have found the hardware bug with async copies on M1. If you shoot | |
/// off an async copy, you need to read from its contents later in the | |
/// shader. Otherwise, something inside the hardware (like a
/// DispatchSemaphore) will be waiting indefinitely to be notified. The bug | |
/// is a bit flaky, and only shows up for certain problem configurations. The | |
/// side effects are catastrophic; the GPU might freeze up until the computer | |
/// reboots. | |
/// | |
/// Workaround: if an async copy from device -> threadgroup is launched, | |
/// guarantee that both: | |
/// - The threadgroup will enter another `threadgroup_barrier` before the end of | |
/// the kernel. | |
/// - The results of the async copy will be read from. This means at least one | |
/// thread must dereference a pointer within the region of threadgroup memory. | |
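///
/// A sketch of the minimum pattern that satisfies the workaround (Metal;
/// illustrative only, with the `dst_block`/`src` names and the elided
/// arguments being assumptions):
///
/// ```
/// if (sidx == 0) {
///   simdgroup_event event;
///   event.async_copy(dst_block, /* ... */, src, /* ... */);
///   simdgroup_event::wait(1, &event);
/// }
/// threadgroup_barrier(mem_flags::mem_threadgroup);
/// // At least one thread dereferences the copied region.
/// float value = *dst_block;
/// ```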
func createMetalSimdgroupEvent() -> String { | |
// Return the source string. | |
return """ | |
// -*- Metal -*- | |
//===-- metal_simdgroup_event ---------------------------------------------===// | |
// Copyright (c) 2024 Philip Turner. See MIT LICENSE | |
//===----------------------------------------------------------------------===// | |
#ifndef __METAL_SIMDGROUP_EVENT | |
#define __METAL_SIMDGROUP_EVENT | |
// Invoking the generation of LLVM bitcode for async copies. | |
// | |
// %struct._simdgroup_event_t = type opaque | |
// | |
struct _simdgroup_event_t; | |
// Invoking the generation of LLVM bitcode for async copies. | |
// | |
// ; Function Attrs: argmemonly convergent nounwind | |
// declare %struct._simdgroup_event_t* | |
// @air.simdgroup_async_copy_2d.p3i8.p1i8( | |
// i64, i64, | |
// i8 addrspace(3)* nocapture writeonly, i64, i64, <2 x i64>, | |
// i8 addrspace(1)* nocapture readonly, i64, i64, <2 x i64>, | |
// <2 x i64>, i32) | |
// local_unnamed_addr #4 | |
// | |
thread _simdgroup_event_t* | |
__metal_simdgroup_async_copy_2d( | |
ulong, ulong, | |
threadgroup void *, ulong, ulong, ulong2, | |
const device void *, ulong, ulong, ulong2, | |
long2, int) | |
__asm("air.simdgroup_async_copy_2d.p3i8.p1i8"); | |
// Invoking the generation of LLVM bitcode for async copies. | |
// | |
// ; Function Attrs: argmemonly convergent nounwind | |
// declare %struct._simdgroup_event_t* | |
// @air.simdgroup_async_copy_2d.p1i8.p3i8( | |
// i64, i64, | |
// i8 addrspace(1)* nocapture writeonly, i64, i64, <2 x i64>, | |
// i8 addrspace(3)* nocapture readonly, i64, i64, <2 x i64>, | |
// <2 x i64>, i32) | |
// local_unnamed_addr #4 | |
// | |
thread _simdgroup_event_t* | |
__metal_simdgroup_async_copy_2d( | |
ulong, ulong, | |
device void *, ulong, ulong, ulong2, | |
const threadgroup void *, ulong, ulong, ulong2, | |
long2, int) | |
__asm("air.simdgroup_async_copy_2d.p1i8.p3i8"); | |
// Invoking the generation of LLVM bitcode for async copies. | |
// | |
// ; Function Attrs: convergent nounwind | |
// declare void | |
// @air.wait_simdgroup_events(i32, %struct._simdgroup_event_t** nocapture) | |
// local_unnamed_addr #3 | |
// | |
void __metal_wait_simdgroup_events( | |
int, thread _simdgroup_event_t**) | |
__asm("air.wait_simdgroup_events"); | |
#pragma METAL internals : enable | |
namespace metal | |
{ | |
enum class simdgroup_async_copy_clamp_mode { | |
clamp_to_zero = 0, | |
clamp_to_edge = 1 | |
}; | |
struct simdgroup_event { | |
METAL_FUNC simdgroup_event() thread {} | |
template <typename T> | |
METAL_FUNC void async_copy( | |
// Description of the destination. | |
threadgroup T *dst, | |
ushort dst_elements_per_row, | |
ushort2 dst_tile_dimensions, | |
// Description of the source. | |
const device T *src, | |
uint src_elements_per_row, | |
ushort2 src_tile_dimensions, | |
// Other arguments. | |
bool transpose_matrix = false, | |
simdgroup_async_copy_clamp_mode clamp_mode = | |
simdgroup_async_copy_clamp_mode::clamp_to_zero | |
) thread { | |
if (transpose_matrix) { | |
src_tile_dimensions = src_tile_dimensions.yx; | |
dst_tile_dimensions = dst_tile_dimensions.yx; | |
} | |
event = __metal_simdgroup_async_copy_2d( | |
// Description of the data type. | |
sizeof(T), | |
alignof(T), | |
// Description of the destination. | |
reinterpret_cast<threadgroup void *>(dst), | |
ushort(dst_elements_per_row), | |
1, | |
ulong2(dst_tile_dimensions), | |
// Description of the source. | |
reinterpret_cast<const device void *>(src), | |
uint(src_elements_per_row), | |
1, | |
ulong2(src_tile_dimensions), | |
// Other arguments. | |
long2(0), | |
static_cast<int>(clamp_mode)); | |
} | |
template <typename T> | |
METAL_FUNC void async_copy( | |
// Description of the destination. | |
device T *dst, | |
uint dst_elements_per_row, | |
ushort2 dst_tile_dimensions, | |
// Description of the source. | |
const threadgroup T *src, | |
ushort src_elements_per_row, | |
ushort2 src_tile_dimensions, | |
// Other arguments. | |
bool transpose_matrix = false | |
) thread { | |
if (transpose_matrix) { | |
src_tile_dimensions = src_tile_dimensions.yx; | |
dst_tile_dimensions = dst_tile_dimensions.yx; | |
} | |
event = __metal_simdgroup_async_copy_2d( | |
// Description of the data type. | |
sizeof(T), | |
alignof(T), | |
// Description of the destination. | |
reinterpret_cast<device void *>(dst), | |
uint(dst_elements_per_row), | |
1, | |
ulong2(dst_tile_dimensions), | |
// Description of the source. | |
reinterpret_cast<const threadgroup void *>(src), | |
ushort(src_elements_per_row), | |
1, | |
ulong2(src_tile_dimensions), | |
// Other arguments. | |
long2(0), | |
0); | |
} | |
METAL_FUNC static void wait(int count, thread simdgroup_event *events) { | |
__metal_wait_simdgroup_events( | |
count, reinterpret_cast<thread _simdgroup_event_t**>(events)); | |
} | |
private: | |
// Invoking the generation of LLVM bitcode for async copies. | |
// | |
// %"struct.metal::simdgroup_event" = type { %struct._simdgroup_event_t* } | |
// | |
thread _simdgroup_event_t* event; | |
}; | |
} // namespace metal | |
#pragma METAL internals : disable | |
#endif // __METAL_SIMDGROUP_EVENT | |
""" | |
} | |
/// Create the source code for the 'metal\_simdgroup\_matrix\_storage' header. | |
func createMetalSimdgroupMatrixStorage() -> String { | |
// How this header spawning code was designed. | |
// | |
// Find the patterns between the load/store functions: | |
// - device has 'uint' elements_per_row | |
// - threadgroup has 'ushort' elements_per_row | |
// - both have 'ushort2' matrix_origin | |
// | |
// The origin is 'ushort2' because the 32-bit part of the address should have | |
// been applied previously during 'apply_offset'. The 16-bit part should be | |
// hard-coded into the assembly when the GEMM loop is unrolled. | |
// | |
// Transpose path: | |
// - load: reads two values; should split each one onto a separate line. | |
// - overwrites the value of *thread_elements() with a new vec<T, 2> | |
// - store: the two instructions are on two separate lines. | |
// - fetches from lane 0 or 1 of thread_elements()[0] | |
// - adds 0 or 1 to the hard-coded matrix_origin.x | |
// | |
// Address generation: | |
// - casts some intermediate address fragments to 'ulong' for 'device' | |
// - keeps all address fragments in 'ushort' for 'threadgroup' | |
enum AddressSpace { | |
case device | |
case threadgroup | |
var keyword: String { | |
switch self { | |
case .device: return "device" | |
case .threadgroup: return "threadgroup" | |
} | |
} | |
var offsetType: String { | |
switch self { | |
case .device: return "uint" | |
case .threadgroup: return "ushort" | |
} | |
} | |
} | |
enum Action { | |
case load | |
case store | |
} | |
struct MemoryAccessDescriptor { | |
var action: Action? | |
var addressSpace: AddressSpace? | |
var decodingBF16: Bool? | |
var indentationSpaceCount: Int = .zero | |
} | |
func createMemoryAccess( | |
descriptor: MemoryAccessDescriptor | |
) -> String { | |
guard let action = descriptor.action, | |
let addressSpace = descriptor.addressSpace, | |
let decodingBF16 = descriptor.decodingBF16 else { | |
fatalError("Descriptor was incomplete.") | |
} | |
let indentation = String( | |
repeating: " ", count: descriptor.indentationSpaceCount) | |
// Determine the arguments. | |
var arguments: [String] = [] | |
func addPointerArgument(dataType: String) { | |
if action == .load { | |
arguments.append("const \(addressSpace.keyword) \(dataType) *src") | |
} else { | |
arguments.append("\(addressSpace.keyword) \(dataType) *dst") | |
} | |
} | |
if decodingBF16 { | |
addPointerArgument(dataType: "bfloat") | |
} else { | |
addPointerArgument(dataType: "U") | |
} | |
arguments.append("\(addressSpace.offsetType) elements_per_row") | |
arguments.append("ushort2 matrix_origin") | |
arguments.append("bool transpose_matrix = false") | |
// Create the warning comment. | |
var output: String = "" | |
if decodingBF16 { | |
output += "\(indentation)// WARNING: 'T' must be 'float'.\n" | |
} else { | |
output += "\(indentation)template <typename U>\n" | |
} | |
// Create the function signature. | |
output += "\(indentation)METAL_FUNC void" | |
if action == .load { | |
output += " load" | |
} else { | |
output += " store" | |
} | |
if decodingBF16 { | |
output += "_bfloat" | |
} | |
output += "(" | |
for argumentID in arguments.indices { | |
let argument = arguments[argumentID] | |
output += argument | |
if argumentID < arguments.count - 1 { | |
output += ", " | |
} | |
} | |
output += ") {\n" | |
func createAddress(transposed: Bool, offset: Int) -> String { | |
let lineY = "\(addressSpace.offsetType)(matrix_origin.y)" | |
var lineX = "matrix_origin.x + \(offset)" | |
lineX = "\(addressSpace.offsetType)(\(lineX))" | |
if transposed { | |
return "\(lineX) * elements_per_row + \(lineY)" | |
} else { | |
return "\(lineY) * elements_per_row + \(lineX)" | |
} | |
} | |
func createTwoPartAccess(transposed: Bool) -> [String] { | |
// Generate the addresses. | |
var lines: [String] = [] | |
for laneID in 0..<2 { | |
lines.append( | |
"\(addressSpace.offsetType) address\(laneID) = " + | |
createAddress(transposed: transposed, offset: laneID)) | |
} | |
if action == .load { | |
if decodingBF16 { | |
lines.append("bfloat memoryForm0 = src[address0]") | |
lines.append("bfloat memoryForm1 = src[address1]") | |
} else { | |
lines.append("U memoryForm0 = src[address0]") | |
lines.append("U memoryForm1 = src[address1]") | |
} | |
} | |
if action == .load { | |
if decodingBF16 { | |
// Separate the loading logic from the decoding logic for clarity. | |
lines.append( | |
"") | |
// BF16 decoding logic. | |
lines.append( | |
"bfloat4 registerForm = *(thread bfloat4*)(thread_elements())") | |
lines.append( | |
"registerForm[1] = memoryForm0") | |
lines.append( | |
"registerForm[3] = memoryForm1") | |
lines.append( | |
"((thread bfloat4*)thread_elements())[0] = registerForm") | |
} else { | |
// Perform a type cast natively supported by the hardware. | |
lines.append("((thread T*)thread_elements())[0] = T(memoryForm0)") | |
lines.append("((thread T*)thread_elements())[1] = T(memoryForm1)") | |
} | |
} else { | |
if decodingBF16 { | |
// BF16 encoding logic. | |
lines.append( | |
"bfloat4 registerForm = *(thread bfloat4*)(thread_elements())") | |
lines.append( | |
"registerForm[2] = registerForm[1]") | |
} else { | |
// Type casts supported natively by the hardware. | |
lines.append("T registerForm0 = ((thread T*)thread_elements())[0]") | |
lines.append("T registerForm1 = ((thread T*)thread_elements())[1]") | |
} | |
} | |
if action == .store { | |
if decodingBF16 { | |
lines.append("dst[address0] = registerForm[2]") | |
lines.append("dst[address1] = registerForm[3]") | |
} else { | |
lines.append("dst[address0] = U(registerForm0)") | |
lines.append("dst[address1] = U(registerForm1)") | |
} | |
} | |
return lines | |
} | |
func createOnePartAccess() -> [String] { | |
var lines: [String] = [] | |
do { | |
let address = createAddress(transposed: false, offset: 0) | |
lines.append("auto combinedAddress = \(address)") | |
} | |
if action == .load { | |
if decodingBF16 { | |
lines.append( | |
"bfloat2 memoryForm = " + | |
"*(const \(addressSpace.keyword) packed_bfloat2*)(src + combinedAddress)") | |
// Separate the loading logic from the decoding logic for clarity. | |
lines.append( | |
"") | |
// BF16 decoding logic. | |
lines.append( | |
"bfloat4 registerForm = *(thread bfloat4*)(thread_elements())") | |
lines.append( | |
"((thread float*)®isterForm)[1] = *(thread float*)(&memoryForm)") | |
lines.append( | |
"((thread bfloat*)®isterForm)[1] = memoryForm[0]") | |
lines.append( | |
"((thread bfloat4*)thread_elements())[0] = registerForm") | |
} else { | |
lines.append( | |
"vec<U, 2> memoryForm = " + | |
"*(const \(addressSpace.keyword) vec<U, 2>*)(src + combinedAddress)") | |
lines.append( | |
"*(thread_elements()) = vec<T, 2>(memoryForm)") | |
} | |
} else { | |
if decodingBF16 { | |
// BF16 encoding logic. | |
lines.append( | |
"bfloat4 registerForm = *(thread bfloat4*)(thread_elements())") | |
lines.append( | |
"registerForm[2] = registerForm[1]") | |
lines.append( | |
"float memoryForm = ((thread float*)®isterForm)[1]") | |
lines.append( | |
"*(\(addressSpace.keyword) float*)(dst + combinedAddress) = " + | |
"memoryForm") | |
} else { | |
lines.append( | |
"vec<T, 2> registerForm = *(thread_elements())") | |
lines.append( | |
"*(\(addressSpace.keyword) vec<U, 2>*)(dst + combinedAddress) = " + | |
"vec<U, 2>(registerForm)") | |
} | |
} | |
return lines | |
} | |
func addBlockContents(_ block: [String]) -> [String] { | |
block.map { | |
if $0.allSatisfy(\.isWhitespace) { | |
return " " | |
} else { | |
return " \($0);" | |
} | |
} | |
} | |
// Determine the lines of the 'if' block. | |
var body: [String] = [] | |
body.append("if (transpose_matrix) {") | |
body += addBlockContents(createTwoPartAccess(transposed: true)) | |
// Determine the lines of the 'else' block. | |
if decodingBF16 { | |
var blockContents: [String] | |
if action == .load { | |
blockContents = createOnePartAccess() | |
} else { | |
blockContents = createTwoPartAccess(transposed: false) | |
} | |
body.append("} else {") | |
body += addBlockContents(blockContents) | |
body.append("}") | |
} else { | |
body.append("} else if (elements_per_row % 2 != 0) {") | |
body += addBlockContents(createTwoPartAccess(transposed: false)) | |
body.append("} else {") | |
body += addBlockContents(createOnePartAccess()) | |
body.append("}") | |
} | |
// Create the function body. | |
for line in body { | |
output += "\(indentation) \(line)\n" | |
} | |
output += "\(indentation)}\n" | |
return output | |
} | |
// Add the first section of the shader. | |
var output: String = "" | |
output += """ | |
// -*- Metal -*- | |
//===-- metal_simdgroup_matrix_storage ------------------------------------===// | |
// Copyright (c) 2024 Philip Turner. See MIT LICENSE | |
//===----------------------------------------------------------------------===// | |
#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE | |
#define __METAL_SIMDGROUP_MATRIX_STORAGE | |
#pragma METAL internals : enable | |
namespace metal | |
{ | |
template <typename T> | |
struct simdgroup_matrix_storage { | |
typedef vec<T, 64> storage_type; | |
storage_type t; | |
METAL_FUNC thread vec<T, 2>* thread_elements() thread { | |
return reinterpret_cast<thread vec<T, 2>*>(&t); | |
} | |
METAL_FUNC simdgroup_matrix_storage() thread = default; | |
METAL_FUNC simdgroup_matrix_storage(vec<T, 2> thread_elements) thread { | |
*(this->thread_elements()) = thread_elements; | |
} | |
METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) { | |
if (transpose_matrix) { | |
return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y; | |
} else { | |
return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x; | |
} | |
} | |
METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { | |
if (transpose_matrix) { | |
return src + matrix_origin.x * elements_per_row + matrix_origin.y; | |
} else { | |
return src + matrix_origin.y * elements_per_row + matrix_origin.x; | |
} | |
} | |
""" | |
var desc = MemoryAccessDescriptor() | |
desc.indentationSpaceCount = 4 | |
for action in [Action.load, .store] { | |
for addressSpace in [AddressSpace.device, .threadgroup] { | |
for decodingBF16 in [false, true] { | |
desc.action = action | |
desc.addressSpace = addressSpace | |
desc.decodingBF16 = decodingBF16 | |
output += createMemoryAccess(descriptor: desc) | |
output += "\n" | |
} | |
} | |
} | |
// Add the last section of the header. | |
output += """ | |
template <typename U, typename V> | |
METAL_FUNC void multiply(simdgroup_matrix_storage<U> a, simdgroup_matrix_storage<V> b, bool accumulate = true) { | |
if (!accumulate) { | |
*(thread_elements()) = vec<T, 2>(0); | |
} | |
t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type()); | |
} | |
}; | |
} // namespace metal | |
#pragma METAL internals : disable | |
#endif // __METAL_SIMDGROUP_MATRIX_STORAGE | |
""" | |
return output | |
} | |
// MARK: - Implementation of Configuration Selection Logic | |
/// A description of a dense matrix-matrix multiplication. | |
struct GEMMDescriptor { | |
/// The number of equally sized multiplications that run in parallel. | |
/// Batching is out of scope for the reference implementation. However, there | |
/// should be a guide for clients that wish to modify the shader, in ways | |
/// that increase the compute workload. For example, by batching the | |
/// multiplication of (sub)matrices located at arbitrary pointers in memory | |
/// (with potentially nonuniform stride or noncontiguous padding). | |
var batchDimension: Int = 1 | |
/// The dimensions of the input and output matrices. | |
/// - Parameter M: Number of output rows.
/// - Parameter N: Number of output columns.
/// - Parameter K: Number of loop iterations for the dot products. | |
/// | |
/// For all practical purposes, one can assume matrix dimensions are 32-bit. | |
/// I use this quite often in other code. The pointers themselves are 64-bit, | |
/// but the offsets between different elements are 32-bit. With 4-byte words, | |
/// this scheme could access up to 16 GB of memory - larger than any array | |
/// in any reasonable application. Handling larger allocations likely | |
/// requires consideration of more failure points than just integer | |
/// overflows. | |
var matrixDimensions: (M: UInt32, N: UInt32, K: UInt32)? | |
var memoryPrecisions: ( | |
A: GEMMOperandPrecision, B: GEMMOperandPrecision, C: GEMMOperandPrecision)? | |
var transposeState: (A: Bool, B: Bool)? | |
} | |
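// A minimal sketch of describing a multiplication and deriving the kernel
// descriptor from it. The 1489x1489x1489 problem size mirrors the benchmarks
// above; the values are illustrative.
//
//   var gemmDesc = GEMMDescriptor()
//   gemmDesc.matrixDimensions = (M: 1489, N: 1489, K: 1489)
//   gemmDesc.memoryPrecisions = (A: .FP32, B: .FP32, C: .FP32)
//   gemmDesc.transposeState = (A: false, B: false)
//   let kernelDesc = GEMMKernelDescriptor(descriptor: gemmDesc)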
struct GEMMKey: Equatable, Hashable { | |
var batchDimension: Int | |
var matrixDimensions: SIMD3<UInt32> | |
var memoryPrecisions: SIMD3<UInt16> | |
var transposeState: SIMD2<UInt8> | |
init(copying source: GEMMDescriptor) { | |
batchDimension = source.batchDimension | |
matrixDimensions = Self.createMatrixDimensions(source.matrixDimensions) | |
memoryPrecisions = GEMMKernelKey.createPrecisions(source.memoryPrecisions) | |
transposeState = GEMMKernelKey.createTransposeState(source.transposeState) | |
} | |
@_transparent // performance in -Ounchecked | |
static func createMatrixDimensions( | |
_ input: (UInt32, UInt32, UInt32)? | |
) -> SIMD3<UInt32> { | |
if let input { | |
return SIMD3(input.0, input.1, input.2) | |
} else { | |
return SIMD3(repeating: .max) | |
} | |
} | |
} | |
extension GEMMDescriptor: Hashable, Equatable { | |
static func == (lhs: GEMMDescriptor, rhs: GEMMDescriptor) -> Bool { | |
let lhsKey = GEMMKey(copying: lhs) | |
let rhsKey = GEMMKey(copying: rhs) | |
return lhsKey == rhsKey | |
} | |
func hash(into hasher: inout Hasher) { | |
let key = GEMMKey(copying: self) | |
hasher.combine(key) | |
} | |
} | |
extension GEMMKernelDescriptor { | |
/// Initialize the kernel descriptor using another descriptor, which just | |
/// specifies the problem size. Then, forget the information about problem | |
/// size. It will not be needed until the very far future, when the user | |
/// retrieves a `MTLLibrary` from the cache and sets some Metal function | |
/// constants. | |
/// | |
/// One might initialize a `GEMMKernelDescriptor` this way whenever an | |
/// arbitrary matrix multiplication is requested. The generated descriptor | |
/// itself could be a key in the KV cache. With this shader cache design, you | |
/// must minimize the latency of actions like `MTLDevice` materialization and | |
/// core count queries. | |
/// | |
/// Acceptable latency: no more than 1 μs per invocation. | |
init(descriptor: GEMMDescriptor) { | |
guard let matrixDimensions = descriptor.matrixDimensions, | |
let memoryPrecisions = descriptor.memoryPrecisions, | |
let transposeState = descriptor.transposeState else { | |
fatalError("Descriptor was incomplete.") | |
} | |
// Select the only GPU on an Apple silicon system. | |
// | |
// NOTE: To avoid potentially costly API calls, you may wish to cache the | |
// MTLDevice object or enter a previously created one. The core count | |
// could also be cached on macOS. | |
// | |
// Typical latency to initiate a Metal device, provided the function has | |
// been called numerous times prior: | |
// - macOS 14 | |
// - Swift debug mode, Metal API validation on: ≥33 μs | |
// - Swift release mode, Metal API validation off: ≥38 μs | |
// - iOS 17 | |
// - Swift debug mode, Metal API validation on: ≥0 μs | |
// - Swift release mode, Metal API validation off: ≥0 μs | |
let mtlDevice = MTLCreateSystemDefaultDevice()! | |
// Trim the device name to something easier to process. | |
// | |
// M1 Max: Apple M1 Max -> M1 | |
// M4: Apple M4 GPU -> M4 | |
func createDeviceName() -> String { | |
let deviceName = mtlDevice.name | |
var splits = deviceName.split(separator: " ").map(String.init) | |
splits.removeAll(where: { $0.starts(with: "Apple") }) | |
splits.removeAll(where: { $0.starts(with: "GPU") }) | |
// Iterate over the space-separated words. | |
var matchingSplitIDs: [UInt32] = [] | |
for splitID in splits.indices { | |
// Screen out obvious non-candidates. | |
let split = splits[splitID] | |
guard split.starts(with: "A") || split.starts(with: "M") else { | |
continue | |
} | |
// Extract the second character. | |
guard split.count > 1 else { | |
continue | |
} | |
let secondCharacterInt8 = split.utf8CString[1] | |
let secondCharacterUInt32 = UInt32(secondCharacterInt8) | |
let secondCharacterUnicode = Unicode.Scalar(secondCharacterUInt32)! | |
let secondCharacter = Character(secondCharacterUnicode) | |
// If the second character is numeric, the candidate passes. | |
if secondCharacter.isWholeNumber { | |
matchingSplitIDs.append(UInt32(splitID)) | |
} | |
} | |
guard matchingSplitIDs.count == 1 else { | |
fatalError("Failed to locate device name.") | |
} | |
let splitID = matchingSplitIDs[0] | |
return splits[Int(splitID)] | |
} | |
let deviceName = createDeviceName() | |
// Find the core count. | |
#if os(macOS) | |
// Typical latency to query IORegistry, provided the function has been | |
// called numerous times prior: | |
// - macOS 14 | |
// - Swift debug mode, Metal API validation on: ≥9 μs | |
// - Swift release mode, Metal API validation off: ≥9 μs | |
let coreCount = findCoreCount() | |
#elseif os(iOS) | |
var coreCount: Int | |
if deviceName.starts(with: "A") { | |
if mtlDevice.supportsFamily(.apple9) { | |
coreCount = 6 | |
} else { | |
coreCount = 5 | |
} | |
} else { | |
coreCount = 10 | |
} | |
#endif | |
// Select the register precisions. | |
var registerPrecisionA = memoryPrecisions.A | |
var registerPrecisionB = memoryPrecisions.B | |
var registerPrecisionC = GEMMOperandPrecision.FP32 | |
if memoryPrecisions.A == .FP16, | |
memoryPrecisions.B == .FP16, | |
memoryPrecisions.C == .FP16 { | |
// If FP16 is causing accuracy issues, you can change this to FP32. Note | |
// that doing so cuts out a very important part of the performance | |
// spectrum. It is only FP16xFP16->FP16 that reaches peak performance. | |
// This statement applies to both the M1 and M3 architectures. | |
// | |
// FP16xFP16 into FP16 accumulator triggers this instruction: | |
// https://github.com/dougallj/applegpu/blob/aeb81519159246d70c56d3f77adb4bc9cca7aa0d/applegpu.py#L3232-L3244 | |
// | |
// FP16xFP16/BF16xBF16 into FP32 accumulator triggers this instruction: | |
// https://github.com/dougallj/applegpu/blob/aeb81519159246d70c56d3f77adb4bc9cca7aa0d/applegpu.py#L3195-L3207 | |
// | |
// No other input/output register types map to a native instruction. | |
// | |
// I would recommend changing the accumulator precision on a case-by-case
// (operation-by-operation) basis. Provide some mechanism in the high-level
// API to control this low-level feature, without harming execution latency
// and without imposing technical debt on the high-level API (a hypothetical
// sketch follows this branch). It should definitely NOT be a global flag
// that forces all matrices to change from FP16 to FP32.
registerPrecisionC = GEMMOperandPrecision.FP16 | |
} | |
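// A hypothetical shape for that mechanism (neither 'accumulatorPrecision'
// nor this usage exists in the current code; it is only a sketch):
//
//   var gemmDesc = GEMMDescriptor()
//   gemmDesc.memoryPrecisions = (.FP16, .FP16, .FP16)
//   // Opt this one operation into an FP32 accumulator, e.g. for a
//   // numerically sensitive layer, without touching any other matmul.
//   gemmDesc.accumulatorPrecision = .FP32
//
// The override would be consulted right here, in place of the blanket FP16
// branch above, at the cost of one optional check per initialization.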
if !mtlDevice.supportsFamily(.apple9) { | |
if memoryPrecisions.A == .BF16 { | |
registerPrecisionA = .FP32 | |
} | |
if memoryPrecisions.B == .BF16 { | |
registerPrecisionB = .FP32 | |
} | |
} | |
// Set the properties of the 'GEMMKernelDescriptor' object. | |
self.memoryPrecisions = memoryPrecisions | |
if mtlDevice.supportsFamily(.apple9) { | |
self.preferAsyncLoad = false | |
} else { | |
self.preferAsyncLoad = true | |
} | |
self.registerPrecisions = ( | |
registerPrecisionA, | |
registerPrecisionB, | |
registerPrecisionC) | |
if !mtlDevice.supportsFamily(.apple9) { | |
self.splits = (2, 2) | |
} else { | |
self.splits = (1, 1) | |
} | |
self.transposeState = transposeState | |
// Set the properties that deal with block size. | |
setBlockDimensions( | |
mtlDevice: mtlDevice, | |
coreCount: coreCount, | |
matrixDimensions: matrixDimensions, | |
batchDimension: descriptor.batchDimension) | |
} | |
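// Example of building a kernel descriptor from scratch. This mirrors the
// usage in 'runApplication' at the end of this file:
//
//   var gemmDesc = GEMMDescriptor()
//   gemmDesc.matrixDimensions = (M: 1024, N: 1024, K: 1024)
//   gemmDesc.memoryPrecisions = (.FP16, .FP16, .FP16)
//   gemmDesc.transposeState = (false, false)
//   let kernelDesc = GEMMKernelDescriptor(descriptor: gemmDesc)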
// Implementation of the block size selection heuristic. | |
// | |
// This function initializes the 'blockDimensions' and | |
// 'paddedBlockDimensions' properties. | |
private mutating func setBlockDimensions( | |
mtlDevice: MTLDevice, | |
coreCount: Int, | |
matrixDimensions: (M: UInt32, N: UInt32, K: UInt32), | |
batchDimension: Int | |
) { | |
guard let memoryPrecisions, | |
let transposeState else { | |
fatalError("Some properties were not set.") | |
} | |
guard !mtlDevice.supportsFamily(.apple9) else { | |
self.blockDimensions = (32, 32, 8) | |
return | |
} | |
// Find the actual number of threadgroups, with a large block size. | |
func ceilDivide(_ target: UInt32, _ granularity: UInt16) -> UInt32 { | |
(target + UInt32(granularity) - 1) / UInt32(granularity) | |
} | |
var actualGroups: Int = 1 | |
actualGroups *= Int(ceilDivide(matrixDimensions.M, 48)) | |
actualGroups *= Int(ceilDivide(matrixDimensions.N, 48)) | |
actualGroups *= Int(batchDimension) | |
// Does the kernel use 48x48x24xFP32 (9 KB) or 48x48x32xFP16/BF16 (6 KB)? | |
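// For reference: two operand blocks of 48 × 24 FP32 scalars occupy
// 2 × 48 × 24 × 4 bytes = 9216 bytes ≈ 9 KB, while two blocks of
// 48 × 32 half-precision scalars occupy 2 × 48 × 32 × 2 bytes = 6144 bytes
// = 6 KB, matching the figures above.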
func requiresLargeAllocation(_ precision: GEMMOperandPrecision) -> Bool { | |
switch precision { | |
case .FP32: return true | |
case .FP16: return false | |
case .BF16: return false | |
} | |
} | |
var useLargeAllocation = false | |
if requiresLargeAllocation(memoryPrecisions.A) { | |
useLargeAllocation = true | |
} | |
if requiresLargeAllocation(memoryPrecisions.B) { | |
useLargeAllocation = true | |
} | |
if requiresLargeAllocation(memoryPrecisions.C) { | |
useLargeAllocation = true | |
} | |
// Branch on whether the allocation is large / target occupancy is low. | |
if useLargeAllocation { | |
let idealGroups = coreCount * 6 | |
if actualGroups <= idealGroups { | |
self.blockDimensions = (32, 32, 32) | |
} else { | |
self.blockDimensions = (48, 48, 24) | |
// This is verified to be optimal for: | |
// - (memA, memB, memC) = (FP32, FP32, FP32) | |
// - (memA, memB, memC) = (FP16, FP16, FP32) | |
// - (memA, memB, memC) = (FP16, FP32, FP32) | |
// - (memA, memB, memC) = (FP16, FP32, FP16) | |
switch transposeState { | |
case (false, false): | |
self.paddedBlockDimensions = ((48, 24), (24, 48), (48, 48)) | |
case (false, true): | |
let paddedBK = (memoryPrecisions.B == .FP32) ? UInt16(28) : 24 | |
self.paddedBlockDimensions = ((48, 24), (paddedBK, 48), (48, 48)) | |
case (true, false): | |
let paddedAM = (memoryPrecisions.A == .FP32) ? UInt16(52) : 56 | |
self.paddedBlockDimensions = ((paddedAM, 24), (24, 48), (48, 48)) | |
case (true, true): | |
let paddedAM = (memoryPrecisions.A == .FP32) ? UInt16(52) : 56 | |
self.paddedBlockDimensions = ((paddedAM, 24), (24, 48), (48, 48)) | |
} | |
} | |
} else { | |
let idealGroups = coreCount * 9 | |
if actualGroups <= idealGroups { | |
blockDimensions = (32, 32, 32) | |
} else { | |
blockDimensions = (48, 48, 32) | |
} | |
} | |
} | |
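// Worked example of the heuristic (M1 Max, 32 GPU cores, one FP32 GEMM with
// M = N = K = 1536): the allocation is large, idealGroups = 32 × 6 = 192,
// and actualGroups = ceil(1536 / 48)^2 = 32^2 = 1024. Since 1024 > 192, the
// kernel gets 48x48x24 blocks. A small 256 × 256 problem (36 groups) would
// instead fall back to 32x32x32 to keep more cores busy.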
} | |
#if os(macOS) | |
/// Finds the core count on macOS devices, using IORegistry. | |
/// | |
/// Source: [AppleGPUInfo](https://github.com/philipturner/applegpuinfo) | |
/// | |
/// This code was generated by GPT-4 a few days after launch (early 2023). | |
/// Since then, it has undergone extensive human review and real-world testing. | |
/// It proved that proto-AGI could be a practically useful tool, in this case | |
/// assisting with code creation. | |
func findCoreCount() -> Int { | |
// Create a matching dictionary with "AGXAccelerator" class name | |
let matchingDict = IOServiceMatching("AGXAccelerator") | |
// Get an iterator for matching services | |
var iterator: io_iterator_t = 0 | |
do { | |
let io_registry_error = | |
IOServiceGetMatchingServices( | |
kIOMainPortDefault, matchingDict, &iterator) | |
guard io_registry_error == 0 else { | |
fatalError( | |
"Encountered IORegistry error code \(io_registry_error)") | |
} | |
} | |
// Get the first (and only) GPU entry from the iterator | |
let gpuEntry = IOIteratorNext(iterator) | |
// Check if the entry is valid | |
if gpuEntry == MACH_PORT_NULL { | |
fatalError( | |
"Error getting GPU entry at \(#file):\(#line - 5)") | |
} | |
// Release the iterator | |
IOObjectRelease(iterator) | |
// Get the "gpu-core-count" property from gpuEntry | |
let key = "gpu-core-count" | |
let options: IOOptionBits = 0 // No options needed | |
let gpuCoreCount = IORegistryEntrySearchCFProperty( | |
gpuEntry, kIOServicePlane, key as CFString, nil, options) | |
// Check if the property is valid | |
if gpuCoreCount == nil { | |
fatalError( | |
"Error getting gpu-core-count property at \(#file):\(#line - 6)") | |
} | |
// Cast the property to CFNumberRef | |
let gpuCoreCountNumber = gpuCoreCount as! CFNumber | |
// Check if the number type is sInt64 | |
let type = CFNumberGetType(gpuCoreCountNumber) | |
if type != .sInt64Type { | |
fatalError( | |
"Error: gpu-core-count is not sInt64 at \(#file):\(#line - 3)") | |
} | |
// Get the value of the number as Int64 | |
var value: Int64 = 0 | |
let result = CFNumberGetValue(gpuCoreCountNumber, type, &value) | |
// Check for errors | |
if result == false { | |
fatalError( | |
" Error getting value of gpu-core-count at \(#file):\(#line - 5)") | |
} | |
return Int(value) | |
} | |
#endif | |
/// A reference implementation of shader caching. | |
/// | |
/// One good design for a shader caching mechanism: | |
/// - Two key-value caches. | |
/// - The first caches `MTLLibrary` objects. | |
/// - Large latency | |
/// - Small number of combinatorial possibilities, likely to be shared by | |
/// matrices with a different size. | |
/// - Don't bother with serializing Metal binary archives to disk. You are | |
/// already utilizing the system-wide Metal shader cache. | |
/// - The second caches `MTLComputePipelineState` objects. | |
/// - Instantiations of the `MTLLibrary` with different function constants. | |
/// - Less latency than compiling from source, but still non-negligible. You | |
/// can't spawn a new PSO during every call to a matrix multiplication. | |
extension GEMMKernel { | |
/// WARNING: Not thread safe. But will the DSL interpreter even use | |
/// multithreading? | |
static var libraryCache: [ | |
GEMMKernelDescriptor: GEMMKernel] = [:] | |
/// WARNING: Not thread safe. But will the DSL interpreter even use | |
/// multithreading? | |
static var pipelineCache: [ | |
GEMMDescriptor: (GEMMKernel, MTLComputePipelineState)] = [:] | |
} | |
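// If a caller does end up multithreading, one minimal fix (a sketch, not
// part of this file, assuming Foundation is available for NSLock) is to
// serialize access to both dictionaries behind a single lock:
//
//   extension GEMMKernel {
//     static let cacheLock = NSLock()
//   }
//
//   // At the top of retrieveGEMMKernel:
//   GEMMKernel.cacheLock.lock()
//   defer { GEMMKernel.cacheLock.unlock() }
//
// An uncontended lock costs nanoseconds, which is negligible next to the
// latencies discussed above.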
/// Implementation of the logic for choosing between 'device' and | |
/// 'threadgroup' store. | |
func retrieveGEMMKernel( | |
descriptor gemmDesc: GEMMDescriptor | |
) -> (GEMMKernel, MTLComputePipelineState) { | |
// Perform the early return before anything with high latency. | |
if let value = GEMMKernel.pipelineCache[gemmDesc] { | |
return value | |
} | |
func createKernel(descriptor: GEMMKernelDescriptor) -> GEMMKernel { | |
guard descriptor.preferAsyncStore != nil else { | |
fatalError("Prefer async store was not set.") | |
} | |
if let previous = GEMMKernel.libraryCache[descriptor] { | |
return previous | |
} else { | |
return GEMMKernel(descriptor: descriptor) | |
} | |
} | |
// Create an MTLDevice object, a function call with very high latency.
let device = MTLCreateSystemDefaultDevice()! | |
func createPipeline(library: MTLLibrary) -> MTLComputePipelineState { | |
// Set the function constants. | |
let constants = MTLFunctionConstantValues() | |
var M = gemmDesc.matrixDimensions!.M | |
var N = gemmDesc.matrixDimensions!.N | |
var K = gemmDesc.matrixDimensions!.K | |
constants.setConstantValue(&M, type: .uint, index: 0) | |
constants.setConstantValue(&N, type: .uint, index: 1) | |
constants.setConstantValue(&K, type: .uint, index: 2) | |
let function = try! library.makeFunction( | |
name: "gemm", constantValues: constants) | |
let pipeline = try! device.makeComputePipelineState(function: function) | |
return pipeline | |
} | |
var kernelDesc = GEMMKernelDescriptor(descriptor: gemmDesc) | |
kernelDesc.device = device | |
if device.supportsFamily(.apple9) { | |
kernelDesc.preferAsyncStore = false | |
} else { | |
guard let blockDimensions = kernelDesc.blockDimensions else { | |
fatalError("Block dimensions were not set.") | |
} | |
if blockDimensions == (48, 48, 32) { | |
kernelDesc.preferAsyncStore = nil | |
} else { | |
kernelDesc.preferAsyncStore = true | |
} | |
} | |
var output: (GEMMKernel, MTLComputePipelineState) | |
if kernelDesc.preferAsyncStore != nil { | |
let kernel = createKernel(descriptor: kernelDesc) | |
let pipeline = createPipeline(library: kernel.library) | |
output = (kernel, pipeline) | |
GEMMKernel.libraryCache[kernelDesc] = kernel | |
} else { | |
var candidates: [ | |
(kernelDesc: GEMMKernelDescriptor, | |
kernel: GEMMKernel, | |
pipeline: MTLComputePipelineState) | |
] = [] | |
for candidateID in 0..<4 { | |
var blockDimensions: (M: UInt16, N: UInt16, K: UInt16) | |
var preferAsyncStore: Bool | |
switch candidateID { | |
case 0: | |
blockDimensions = (48, 48, 32) | |
preferAsyncStore = false | |
case 1: | |
blockDimensions = (48, 48, 40) | |
preferAsyncStore = false | |
case 2: | |
blockDimensions = (48, 48, 32) | |
preferAsyncStore = true | |
case 3: | |
blockDimensions = (48, 48, 40) | |
preferAsyncStore = true | |
default: | |
fatalError("This should never happen.") | |
} | |
// Set the data that's unique to this variant. | |
var newKernelDesc = kernelDesc | |
newKernelDesc.blockDimensions = blockDimensions | |
newKernelDesc.preferAsyncStore = preferAsyncStore | |
let kernel = createKernel(descriptor: newKernelDesc) | |
let pipeline = createPipeline(library: kernel.library) | |
candidates.append((newKernelDesc, kernel, pipeline)) | |
GEMMKernel.libraryCache[newKernelDesc] = kernel | |
} | |
// Find the maximum occupancy. | |
var maximumOccupancy: Int = -1 | |
for candidate in candidates { | |
let occupancy = candidate.pipeline.maxTotalThreadsPerThreadgroup | |
maximumOccupancy = max(maximumOccupancy, occupancy) | |
} | |
candidates.removeAll(where: { | |
$0.pipeline.maxTotalThreadsPerThreadgroup != maximumOccupancy | |
}) | |
// Choose the highest-performing candidate. | |
let candidate = candidates.last! | |
kernelDesc = candidate.kernelDesc | |
output = (candidate.kernel, candidate.pipeline) | |
} | |
// Save the output to the cache. | |
GEMMKernel.pipelineCache[gemmDesc] = output | |
return output | |
} | |
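// Example of how the two-level cache plays out in practice:
//
//   let first = retrieveGEMMKernel(descriptor: gemmDesc)   // compiles, slow
//   let second = retrieveGEMMKernel(descriptor: gemmDesc)  // dictionary hit
//
// A descriptor that differs only in matrix size will miss the pipeline
// cache, but will often hit the library cache (block sizes and precisions
// repeat across many problem sizes), paying only for a new
// MTLComputePipelineState with different function constants.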
// MARK: - Profiling | |
/// A continuous (integration) test of both correctness and performance. This | |
/// test completes with low latency (<1 second) for rapid feedback during
/// iterative design. | |
/// | |
/// Returns: | |
/// - lane 0: maximum achieved performance in GFLOPS | |
/// - lane 1: occupancy in threads/core | |
func profileProblemSize( | |
descriptor: GEMMDescriptor | |
) -> SIMD2<Int> { | |
guard let matrixDimensions = descriptor.matrixDimensions, | |
matrixDimensions.M == matrixDimensions.N, | |
matrixDimensions.M == matrixDimensions.K else { | |
fatalError("Matrix dimensions were invalid.") | |
} | |
let problemSize = Int(matrixDimensions.M) | |
// Allocate FP32 memory for the operands. | |
var A = [Float](repeating: .zero, count: problemSize * problemSize) | |
var B = [Float](repeating: .zero, count: problemSize * problemSize) | |
var C = [Float](repeating: .zero, count: problemSize * problemSize) | |
// Initialize A as the 2nd-order periodic Laplacian. | |
for diagonalID in 0..<problemSize { | |
let diagonalAddress = diagonalID * problemSize + diagonalID | |
A[diagonalAddress] = -2 | |
let leftColumnID = (diagonalID + problemSize - 1) % problemSize | |
let leftSubDiagonalAddress = diagonalID * problemSize + leftColumnID | |
A[leftSubDiagonalAddress] = 1 | |
let rightColumnID = (diagonalID + problemSize + 1) % problemSize | |
let rightSubDiagonalAddress = diagonalID * problemSize + rightColumnID | |
A[rightSubDiagonalAddress] = 1 | |
} | |
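// For example, with problemSize = 4 the loop above produces (row-major):
//
//   [ -2   1   0   1 ]
//   [  1  -2   1   0 ]
//   [  0   1  -2   1 ]
//   [  1   0   1  -2 ]
//
// The wraparound 1's in the corners come from the modulo arithmetic.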
// Initialize B to random numbers. | |
for rowID in 0..<problemSize { | |
for columnID in 0..<problemSize { | |
let address = rowID * problemSize + columnID | |
let entry = Float.random(in: 0..<1) | |
B[address] = entry | |
} | |
} | |
// Since the Laplacian is symmetric, we swap roles of the matrices to test | |
// transposition of the left-hand side. | |
// | |
// Note that the test cannot cover correctness of A and B transposition | |
// simultaneously. Instead, test the correctness in isolation | |
// (AB, AB^T, A^T B). Performance can be tested in all four permutations | |
// (AB, AB^T, A^T B, A^T B^T). | |
if descriptor.transposeState!.A { | |
swap(&A, &B) | |
} | |
// Initialize the context. | |
let device = MTLCreateSystemDefaultDevice()! | |
let commandQueue = device.makeCommandQueue()! | |
let context = (device: device, commandQueue: commandQueue) | |
func createBuffer( | |
_ originalData: [Float], | |
_ precision: GEMMOperandPrecision | |
) -> MTLBuffer { | |
// Add random numbers to expose out-of-bounds accesses. | |
var augmentedData = originalData | |
for _ in 0..<originalData.count { | |
let randomNumber = Float.random(in: -2...2) | |
augmentedData.append(randomNumber) | |
} | |
// Allocate enough memory to store everything in Float32. | |
let bufferSize = augmentedData.count * 4 | |
let buffer = context.device.makeBuffer(length: bufferSize)! | |
// Copy the data into the buffer. | |
switch precision { | |
case .FP32: | |
let pointer = buffer.contents().assumingMemoryBound(to: Float.self) | |
for i in augmentedData.indices { | |
pointer[i] = augmentedData[i] | |
} | |
case .FP16: | |
let pointer = buffer.contents().assumingMemoryBound(to: Float16.self) | |
for i in augmentedData.indices { | |
pointer[i] = Float16(augmentedData[i]) | |
} | |
case .BF16: | |
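// BF16 is the upper half of the FP32 bit pattern. For example, Float(1.5)
// has bit pattern 0x3FC00000, and its BF16 encoding is the truncated upper
// 16 bits, 0x3FC0. The SIMD2 bit-cast below extracts that upper half
// (lane 1 on a little-endian CPU).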
let pointer = buffer.contents().assumingMemoryBound(to: UInt16.self) | |
for i in augmentedData.indices { | |
let value32 = augmentedData[i].bitPattern | |
let value16 = unsafeBitCast(value32, to: SIMD2<UInt16>.self)[1] | |
pointer[i] = value16 | |
} | |
} | |
return buffer | |
} | |
// Multiply A with B. | |
var maxGFLOPS: Int = .zero | |
var occupancy: Int = .zero | |
do { | |
// Generate the kernel. | |
let (kernel, pipeline) = retrieveGEMMKernel(descriptor: descriptor) | |
occupancy = pipeline.maxTotalThreadsPerThreadgroup | |
// Create the buffers. | |
let bufferA = createBuffer(A, descriptor.memoryPrecisions!.A) | |
let bufferB = createBuffer(B, descriptor.memoryPrecisions!.B) | |
let bufferC = createBuffer(C, descriptor.memoryPrecisions!.C) | |
// Profile the latency of matrix multiplication. | |
for _ in 0..<15 { | |
let duplicatedCommandCount: Int = 20 | |
// Encode the GPU command. | |
let commandBuffer = context.commandQueue.makeCommandBuffer()! | |
let encoder = commandBuffer.makeComputeCommandEncoder()! | |
encoder.setComputePipelineState(pipeline) | |
encoder.setThreadgroupMemoryLength( | |
Int(kernel.threadgroupMemoryAllocation), index: 0) | |
encoder.setBuffer(bufferA, offset: 0, index: 0) | |
encoder.setBuffer(bufferB, offset: 0, index: 1) | |
encoder.setBuffer(bufferC, offset: 0, index: 2) | |
for _ in 0..<duplicatedCommandCount { | |
func ceilDivide(_ target: Int, _ granularity: UInt16) -> Int { | |
(target + Int(granularity) - 1) / Int(granularity) | |
} | |
let gridSize = MTLSize( | |
width: ceilDivide(problemSize, kernel.blockDimensions.N), | |
height: ceilDivide(problemSize, kernel.blockDimensions.M), | |
depth: 1) | |
let groupSize = MTLSize( | |
width: Int(kernel.threadgroupSize), | |
height: 1, | |
depth: 1) | |
encoder.dispatchThreadgroups( | |
gridSize, threadsPerThreadgroup: groupSize) | |
} | |
encoder.endEncoding() | |
commandBuffer.commit() | |
commandBuffer.waitUntilCompleted() | |
// Determine the time taken. | |
let start = commandBuffer.gpuStartTime | |
let end = commandBuffer.gpuEndTime | |
let latency = end - start | |
// Determine the amount of work done. | |
var operations = 2 * problemSize * problemSize * problemSize | |
operations = operations * duplicatedCommandCount | |
let gflops = Int(Double(operations) / Double(latency) / 1e9) | |
// Report the results. | |
// let latencyMicroseconds = Int(latency / 1e-6) | |
// print(latencyMicroseconds, "μs", gflops, "GFLOPS") | |
maxGFLOPS = max(maxGFLOPS, gflops) | |
} | |
// Copy the results to C. | |
do { | |
let precision = descriptor.memoryPrecisions!.C | |
let raw = bufferC.contents() | |
for rowID in 0..<problemSize { | |
for columnID in 0..<problemSize { | |
let address = rowID * problemSize + columnID | |
var entry32: Float | |
switch precision { | |
case .FP32: | |
let casted = raw.assumingMemoryBound(to: Float.self) | |
entry32 = casted[address] | |
case .FP16: | |
let casted = raw.assumingMemoryBound(to: Float16.self) | |
let entry16 = casted[address] | |
entry32 = Float(entry16) | |
case .BF16: | |
let casted = raw.assumingMemoryBound(to: UInt16.self) | |
let entry16 = casted[address] | |
let entry16x2 = SIMD2<UInt16>(.zero, entry16) | |
entry32 = unsafeBitCast(entry16x2, to: Float.self) | |
} | |
C[address] = entry32 | |
} | |
} | |
} | |
} | |
// Choose an error threshold. | |
func createErrorThreshold(precision: GEMMOperandPrecision) -> Float { | |
switch precision { | |
case .FP32: return 1e-5 | |
case .FP16: return 5e-3 | |
case .BF16: return 5e-2 | |
} | |
} | |
var errorThreshold: Float = 0 | |
do { | |
let memoryPrecisions = descriptor.memoryPrecisions! | |
let thresholdA = createErrorThreshold(precision: memoryPrecisions.A) | |
let thresholdB = createErrorThreshold(precision: memoryPrecisions.B) | |
let thresholdC = createErrorThreshold(precision: memoryPrecisions.C) | |
errorThreshold = max(errorThreshold, thresholdA) | |
errorThreshold = max(errorThreshold, thresholdB) | |
errorThreshold = max(errorThreshold, thresholdC) | |
} | |
// Check the results. | |
var errorCount: Int = .zero | |
for m in 0..<problemSize { | |
for n in 0..<problemSize { | |
// Find the source row IDs. | |
let leftRowID = (m + problemSize - 1) % problemSize | |
let centerRowID = m | |
let rightRowID = (m + problemSize + 1) % problemSize | |
// Find the source scalars. | |
var leftSource: Float | |
var centerSource: Float | |
var rightSource: Float | |
if descriptor.transposeState!.A { | |
leftSource = A[leftRowID * problemSize + n] | |
centerSource = A[centerRowID * problemSize + n] | |
rightSource = A[rightRowID * problemSize + n] | |
} else if descriptor.transposeState!.B { | |
leftSource = B[n * problemSize + leftRowID] | |
centerSource = B[n * problemSize + centerRowID] | |
rightSource = B[n * problemSize + rightRowID] | |
} else { | |
leftSource = B[leftRowID * problemSize + n] | |
centerSource = B[centerRowID * problemSize + n] | |
rightSource = B[rightRowID * problemSize + n] | |
} | |
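// Because A is the 2nd-order periodic Laplacian, each row of the product
// is a three-point stencil over the other operand's rows:
//   (A B)[m][n] = B[m - 1][n] - 2 * B[m][n] + B[m + 1][n]
// (row indices wrap modulo problemSize). The three source scalars gathered
// above reconstruct exactly that stencil, regardless of which operand was
// transposed.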
// Find the expected result. | |
let expected = leftSource - 2 * centerSource + rightSource | |
// Find the actual result. | |
var actual: Float | |
if descriptor.transposeState!.A { | |
actual = C[n * problemSize + m] | |
} else { | |
actual = C[m * problemSize + n] | |
} | |
// Report whether it is correct. | |
let error = (expected - actual).magnitude | |
if error > errorThreshold { | |
if errorCount < 10 { | |
print("error: \(error) / ~1.000") | |
errorCount += 1 | |
} | |
} | |
} | |
} | |
return SIMD2(maxGFLOPS, occupancy) | |
} | |
/// The workspace for scripting everything during kernel design. Invoked by | |
/// SwiftUI in the iOS app for M4 testing. | |
func runApplication() { | |
struct TestDescriptor { | |
var precision: GEMMOperandPrecision? | |
var problemSize: Int? | |
var transposeState: (Bool, Bool)? | |
} | |
func runTest(descriptor: TestDescriptor) { | |
guard let precision = descriptor.precision, | |
let problemSize = descriptor.problemSize, | |
let transposeState = descriptor.transposeState else { | |
fatalError("Descriptor was incomplete.") | |
} | |
// Set up the kernel. | |
var gemmDesc = GEMMDescriptor() | |
let n = UInt32(problemSize) | |
gemmDesc.matrixDimensions = (M: n, N: n, K: n) | |
gemmDesc.memoryPrecisions = (precision, precision, precision) | |
gemmDesc.transposeState = descriptor.transposeState | |
// Test the kernel. | |
let statistic = profileProblemSize(descriptor: gemmDesc) | |
// Report the results. | |
do { | |
var repr = "\(problemSize)" | |
while repr.count < 4 { | |
repr = " " + repr | |
} | |
print("problemSize = \(repr)", terminator: " | ") | |
} | |
if transposeState.0 { | |
print("A^T", terminator: " ") | |
} else { | |
print("A ", terminator: " ") | |
} | |
if transposeState.1 { | |
print("B^T", terminator: " | ") | |
} else { | |
print("B ", terminator: " | ") | |
} | |
for laneID in [Int(1), Int(0)] { | |
var repr = "\(statistic[laneID])" | |
while repr.count < 4 { | |
repr = " " + repr | |
} | |
// Log the number to the console. | |
if laneID == 0 { | |
print(repr, terminator: " GFLOPS") | |
} else { | |
print(repr, terminator: " threads/core | ") | |
} | |
} | |
print("") | |
} | |
// Correctness tests. | |
do { | |
let problemSizes: [Int] = [ | |
7, 8, 9, 10, | |
15, 16, 17, 18, | |
23, 24, 25, | |
31, 32, 33, | |
47, 48, 49, | |
63, 64, 65, | |
103, 104, 112, | |
126, 127, 128, 129, | |
130, 131, | |
135, 136, 137, | |
143, 144, 145, | |
151, 152, 153, | |
] | |
let transposeStates: [(Bool, Bool)] = [ | |
(false, false), | |
(false, true), | |
(true, false), | |
] | |
print() | |
print("Correctness tests:") | |
for problemSize in problemSizes { | |
for transposeState in transposeStates { | |
var testDescriptor = TestDescriptor() | |
testDescriptor.precision = .FP32 | |
testDescriptor.problemSize = problemSize | |
testDescriptor.transposeState = transposeState | |
runTest(descriptor: testDescriptor) | |
} | |
} | |
} | |
// My workspace. Edit this code to run actual tests. | |
do { | |
let transposeStates: [(Bool, Bool)] = [ | |
(false, false), | |
(false, true), | |
(true, false), | |
(true, true), | |
] | |
// Working on investigating BF16 performance with large matrices. | |
print() | |
print("Performance tests:") | |
for problemSize in 1488...1489 { | |
for transposeState in transposeStates { | |
var testDescriptor = TestDescriptor() | |
testDescriptor.precision = .BF16 | |
testDescriptor.problemSize = problemSize | |
testDescriptor.transposeState = transposeState | |
runTest(descriptor: testDescriptor) | |
} | |
} | |
} | |
} |