//
// main.swift
// UnifiedGEMMKernel
//
// Created by Philip Turner on 5/29/24.
//
import Metal
#if os(macOS)
import IOKit
#endif
// Single shader source that supports every hardware architecture, problem
// size, and precision.
//
// ========================================================================== //
// Introduction
// ========================================================================== //
//
// There should not be a tradeoff between:
// - Reaching theoretical maximum ALU performance (within a factor of 1.01)
// - Reaching theoretical minimum compilation latency (provided you have an
// engine for caching shader variants)
// - Having short, legible, and portable source code
// - Supporting every GPU generation and data type
//
// The existing solutions out there compromise on one of these. For example,
// MPS and Mojo are closed source, meaning the code is not legible. Other ML
// frameworks are ergonomic and easy to use, but sacrifice performance.
// Performance is most often neglected on older chips (M1, M2) or edge cases
// (when matrix dimensions are not divisible by block size).
//
// Recently, I found a way to access SIMD async copy instructions from the JIT
// compiler*. This removes the need to use Xcode 14.2 in the shader compilation
// process. Instead of encoding all arguments in function constants, some of
// them can be encoded directly into a JIT-compiled shader source. This
// freedom allows for simpler shader code and a simpler interface for running
// the kernel. For example, the client no longer needs to allocate threadgroup
// memory at runtime.
//
// > *This and some additional context are organized at:
// > https://gist.github.com/philipturner/0dd518be6705544778474c4e9f8fd68d
//
// The code should be robust against worst-case situations. Imagine a PyTorch
// workflow where someone explores a hyperparameter space. They try changing
// the size of an MLP from 128 neurons, to 129 neurons, 130, etc. Each call
// into a matrix multiplication has a different problem size. If a new kernel
// was cached/compiled from scratch for each call, the latency would become a
// bottleneck. Yet, knowing the problem size beforehand is critical to reaching
// maximum performance (and outperforming MPS). The code should meet the
// standards of something that replaces MPS across thousands of end-user
// workflows.
//
// The first attempt will not be perfect. I will learn things and try again
// with new implementations written from scratch, continuing to distill,
// simplify, and harden the code. I will compare against alternatives
// (MPS, MLX) with a data-driven, evidence-based approach.
//
// ========================================================================== //
// Methods
// ========================================================================== //
//
// The following design specifications were drafted for the source file.
//
// Anything with a fixed number of options, where supporting every option at
// runtime would require adding control flow blocks to the shader:
// - Device architecture
// - Blocking setup, threadgroup memory allocation
// - Operand precisions, accumulator precision
// Convention: Injected into shader source.
//
// Remaining parts that must be known ahead of time for reasonable performance:
// - Remainder of the integer division: problem size / block size
// - Transpose state of operands
// Convention: Choice between injection into shader source, or function
// constants.
//
// Parts where inlining a constant into assembly would maximize performance:
// - Actual problem size
// Convention: Choice between injection into shader source, function constants,
// or runtime data structure.
//
// Draft a kernel that does all of these things. It doesn't need optimal
// performance; it just needs to implement them correctly.
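//
// To make the conventions concrete, the host code can splice fixed options
// directly into a JIT-compiled source string, and reserve Metal function
// constants for values that change per dispatch. The sketch below is
// illustrative only; the constant names, indices, and kernel skeleton are
// placeholders, not the ones generated later in this file.
//
// ```
// let device = MTLCreateSystemDefaultDevice()!
//
// // Options with few variants (block size, precisions) are interpolated into
// // the source. Note that the threadgroup allocation is baked into the
// // shader, so the client never calls setThreadgroupMemoryLength.
// let blockM = 48
// let blockK = 24
// let source = """
// #include <metal_stdlib>
// using namespace metal;
//
// constant uint problem_size [[function_constant(0)]];
//
// kernel void gemm(/* buffer bindings omitted */) {
//   threadgroup float blockA[\(blockM * blockK)];
//   // ...
// }
// """
// let library = try! device.makeLibrary(source: source, options: nil)
//
// // Values like the problem size go into function constants, so the cached
// // MTLLibrary can be reused and only the pipeline is re-specialized.
// let constants = MTLFunctionConstantValues()
// var problemSize: UInt32 = 1536
// constants.setConstantValue(&problemSize, type: .uint, index: 0)
// let function = try! library.makeFunction(
//   name: "gemm", constantValues: constants)
// let pipeline = try! device.makeComputePipelineState(function: function)
// ```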
//
// ========================================================================== //
// Lab Notes
// ========================================================================== //
//
// First issue:
// - The optimized BF16 decoding function from a previous experiment does not
// support transposed representations in memory.
// - It also didn't support loading/storing from threadgroup memory.
// - The previous experiment didn't implement all the components of the BF16 ->
// FP32 decoding optimization for M1.
//
// Second issue:
// - simdgroup_matrix_storage.load/store doesn't support unaligned inputs.
// - Previous design delegated that to the SIMD async copy unit in hardware.
// - Likely need a better alternative on M3/M4, where async copy is emulated
// and hence very slow.
//
// Solution:
// - Three code paths for loading/storing from memory, until I prove the Apple
// GPU 'device_load' instruction natively supports alignments smaller than
// the vector size (e.g. alignment is the scalar size).
// - Transposed (2 instructions)
// - Untransposed, not aligned to multiple of 2 (2 instructions)
// - Untransposed, aligned to multiple of 2 (1 instruction)
// - Use code generation to spawn the header with compact Swift code, as
// sketched below.
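//
// As a sketch of that code generation, a small Swift function can emit one of
// the three load bodies based on the operand's layout. The MSL fragments here
// are schematic placeholders, not the exact header this file generates.
//
// ```
// func loadBody(transposed: Bool, alignedToTwo: Bool) -> String {
//   if transposed {
//     // Transposed: gather two scalars separated by the leading dimension
//     // (2 instructions).
//     return "dst = vec<T, 2>(src[0], src[leading_dim]);"
//   } else if !alignedToTwo {
//     // Untransposed, odd alignment: two adjacent scalar loads
//     // (2 instructions).
//     return "dst = vec<T, 2>(src[0], src[1]);"
//   } else {
//     // Untransposed, even alignment: one vectorized load (1 instruction).
//     return "dst = *reinterpret_cast<const device vec<T, 2>*>(src);"
//   }
// }
// ```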
//
// Third issue:
// - While trying to get rid of async copies, I kept finding edge cases. For
// example, when the matrix dimension is an odd number, RAM accesses go out
// of bounds. Fixing these requires either incurring immense overhead at
// matrix edges, or baking if/else statements into the source code.
//
// I found something interesting. If I avoid async copies for most of the inner
// loop iterations, I can get decent performance on M4. This is a slightly
// modified version of the M1 kernel, which divides the "k" accumulator into
// two sections. The first section reads the inputs directly from device memory.
// The last section reads from threadgroup memory.
//
// This modified kernel frequently causes IOCommandBuffer errors on M1. I need
// to understand why that is happening.
//
// ========================================================================== //
// Tuning Performance on M4
// ========================================================================== //
//
// Avoiding regressions on M1 Max:
//
// The kernel must achieve 8100 GFLOPS @ 1535x1535, 48x48x24.
// The kernel must achieve 8150 GFLOPS @ 1536x1536, 48x48x24.
// The kernel must achieve 7530 GFLOPS @ 1537x1537, 48x48x24.
//
// Reference statistics for M1 Max:
//
// - problemSize = 256 | 913 GFLOPS (32x32x8)
// - problemSize = 384 | 2931 GFLOPS (32x32x8)
// - problemSize = 512 | 5342 GFLOPS (32x32x8)
// - problemSize = 640 | 5463 GFLOPS (32x32x8) 6440 GFLOPS (async copy)
// - problemSize = 768 | 6160 GFLOPS (48x48x8) 7017 GFLOPS (async copy)
// - problemSize = 896 | 6643 GFLOPS (48x48x8) 7136 GFLOPS (async copy)
// - problemSize = 1024 | 7596 GFLOPS (48x48x8) 6966 GFLOPS (async copy)
// - problemSize = 1152 | 7676 GFLOPS (48x48x8) 8144 GFLOPS (async copy)
// - problemSize = 1280 | 7712 GFLOPS (48x48x8) 7813 GFLOPS (async copy)
// - problemSize = 1408 | 7747 GFLOPS (48x48x8)
// - problemSize = 1536 | 8392 GFLOPS (48x48x8)
//
// Performance target on M4:
//
// - problemSize = 256 | 1195 GFLOPS (32x32x8) 590 GFLOPS (MPS)
// - problemSize = 384 | 1729 GFLOPS (32x32x8) 1105 GFLOPS (MPS)
// - problemSize = 512 | 2549 GFLOPS (32x32x8) 2051 GFLOPS (MPS)
// - problemSize = 640 | 2983 GFLOPS (32x32x8) 3028 GFLOPS (MPS)
// - problemSize = 768 | 3036 GFLOPS (32x32x8) 3087 GFLOPS (MPS)
// - problemSize = 896 | 3044 GFLOPS (32x32x8) 3086 GFLOPS (MPS)
// - problemSize = 1024 | 3074 GFLOPS (32x32x8) 3125 GFLOPS (MPS)
// - problemSize = 1152 | 3123 GFLOPS (32x32x8) 3152 GFLOPS (MPS)
// - problemSize = 1280 | 3134 GFLOPS (32x32x8) 3134 GFLOPS (MPS)
// - problemSize = 1408 | 3167 GFLOPS (32x32x8) 3150 GFLOPS (MPS)
// - problemSize = 1536 | 3174 GFLOPS (32x32x8) 3129 GFLOPS (MPS)
//
// Performance deterioration for odd problem sizes:
// M1 Max (32x32, async copy), M4 (32x32, no async copy)
//
// - problemSize = 254 | 1888 GFLOPS 950 GFLOPS
// - problemSize = 255 | 1950 GFLOPS 971 GFLOPS
// - problemSize = 256 | 2087 GFLOPS 1210 GFLOPS
// - problemSize = 257 | 1744 GFLOPS 907 GFLOPS
// - problemSize = 258 | 1754 GFLOPS 921 GFLOPS
//
// - problemSize = 510 | 5296 GFLOPS 2614 GFLOPS
// - problemSize = 511 | 5266 GFLOPS 2624 GFLOPS
// - problemSize = 512 | 5390 GFLOPS 2765 GFLOPS
// - problemSize = 513 | 5180 GFLOPS 2365 GFLOPS
// - problemSize = 514 | 5208 GFLOPS 2377 GFLOPS
//
// - problemSize = 1022 | 5989 GFLOPS 3054 GFLOPS
// - problemSize = 1023 | 5905 GFLOPS 3059 GFLOPS
// - problemSize = 1024 | 7164 GFLOPS 3051 GFLOPS
// - problemSize = 1025 | 5618 GFLOPS 2880 GFLOPS
// - problemSize = 1026 | 5770 GFLOPS 2905 GFLOPS
//
// Overall scaling for aligned problem sizes:
// M4 (32x32, no async copy) vs M4 (MPS)
//
// - problemSize = 256 | 1205 GFLOPS vs 590 GFLOPS (MPS)
// - problemSize = 384 | 1711 GFLOPS vs 1105 GFLOPS (MPS)
// - problemSize = 512 | 2747 GFLOPS vs 2051 GFLOPS (MPS)
// - problemSize = 640 | 2939 GFLOPS vs 3028 GFLOPS (MPS)
// - problemSize = 768 | 3010 GFLOPS vs 3087 GFLOPS (MPS)
// - problemSize = 896 | 3024 GFLOPS vs 3086 GFLOPS (MPS)
// - problemSize = 1024 | 3040 GFLOPS vs 3125 GFLOPS (MPS)
// - problemSize = 1152 | 3101 GFLOPS vs 3152 GFLOPS (MPS)
// - problemSize = 1280 | 3101 GFLOPS vs 3134 GFLOPS (MPS)
// - problemSize = 1408 | 3130 GFLOPS vs 3150 GFLOPS (MPS)
// - problemSize = 1536 | 3140 GFLOPS vs 3129 GFLOPS (MPS)
//
// The above results were gathered without some potential optimizations: for
// example, fusing device_load instructions for consecutive matrix elements
// when the matrix dimension is not divisible by 2, and eliding the async
// copies when writing the C matrix to memory on M4.
//
// ========================================================================== //
// Tuning Precisions
// ========================================================================== //
//
// ## M1 Max
//
// Configuration:
// - maximum of 32x32x32 and 48x48x24/32
// - inputs are not transposed
// - table cells report GFLOPS at problem sizes 512 through 1536
//
// memA | memB | memC | regA | regB | regC | 512 | 768 | 1024 | 1280 | 1536 |
// ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
// FP16 | FP32 | FP32 | FP16 | FP32 | FP32 | 6754 | 7274 | 7604 | 7891 | 8300 |
// FP16 | FP32 | FP32 | FP32 | FP32 | FP32 | 6794 | 7244 | 7578 | 7864 | 8299 |
// FP32 | FP16 | FP32 | FP32 | FP16 | FP32 | 6651 | 7224 | 7630 | 7920 | 8322 |
// FP32 | FP16 | FP32 | FP32 | FP32 | FP32 | 6560 | 7231 | 7632 | 7924 | 8318 |
// BF16 | FP32 | FP32 | BF16 | FP32 | FP32 | 6118 | 6624 | 6912 | 7381 | 7590 |
// BF16 | FP32 | FP32 | FP32 | FP32 | FP32 | 6398 | 7159 | 7031 | 7632 | 8232 |
// FP32 | BF16 | FP32 | FP32 | BF16 | FP32 | 5556 | 6779 | 6529 | 7297 | 7680 |
// FP32 | BF16 | FP32 | FP32 | FP32 | FP32 | 5223 | 7149 | 6898 | 7753 | 8112 |
//
// FP16 | FP16 | FP32 | FP16 | FP16 | FP32 | 6796 | 7879 | 8331 | 8396 | 8492 |
// FP16 | FP16 | FP32 | FP32 | FP32 | FP32 | 6727 | 7848 | 8306 | 8396 | 8490 |
// FP16 | BF16 | FP32 | FP16 | BF16 | FP32 | 6156 | 6935 | 7221 | 7535 | 7841 |
// FP16 | BF16 | FP32 | FP16 | FP32 | FP32 | 6304 | 7157 | 7617 | 7894 | 8142 |
// FP16 | BF16 | FP32 | FP32 | FP32 | FP32 | 6347 | 7130 | 7628 | 7904 | 8152 |
// BF16 | FP16 | FP32 | BF16 | FP16 | FP32 | 6246 | 7348 | 7722 | 7758 | 7878 |
// BF16 | FP16 | FP32 | FP32 | FP16 | FP32 | 6435 | 7328 | 7324 | 7865 | 8260 |
// BF16 | FP16 | FP32 | FP32 | FP32 | FP32 | 6352 | 7339 | 7312 | 7865 | 8259 |
// BF16 | BF16 | FP32 | BF16 | BF16 | FP32 | 5787 | 6598 | 6993 | 6967 | 7207 |
// BF16 | BF16 | FP32 | FP32 | FP32 | FP32 | 6075 | 6967 | 6955 | 7515 | 7888 |
//
// FP16 | FP16 | FP16 | FP16 | FP16 | FP16 | 7077 | 8535 | 8096 | 8660 | 9136 |
// FP16 | FP16 | FP16 | FP16 | FP16 | FP32 | 6946 | 8561 | 8322 | 8696 | 9103 |
// FP16 | FP16 | FP16 | FP16 | FP32 | FP16 | 6384 | 7742 | 7496 | 7879 | 8254 |
// FP16 | FP16 | FP16 | FP32 | FP16 | FP16 | 6350 | 7747 | 7476 | 7875 | 8263 |
// FP16 | FP16 | FP16 | FP32 | FP32 | FP32 | 7124 | 8505 | 8330 | 8702 | 9091 |
//
// BF16 | BF16 | BF16 | BF16 | BF16 | FP32 | 5861 | 7356 | 7084 | 7426 | 7720 |
// BF16 | BF16 | BF16 | BF16 | FP32 | FP32 | 5676 | 7805 | 7415 | 7724 | 8250 |
// BF16 | BF16 | BF16 | FP32 | BF16 | FP32 | 6243 | 6998 | 7031 | 7355 | 7724 |
// BF16 | BF16 | BF16 | FP32 | FP32 | FP32 | 6367 | 7210 | 7086 | 7544 | 7930 |
//
// FP32 | FP32 | FP16 | FP32 | FP32 | FP16 | 5738 | 6450 | 6130 | 6741 | 7312 |
// FP32 | FP32 | FP16 | FP32 | FP32 | FP32 | 5420 | 7084 | 7171 | 7739 | 8223 |
// FP32 | FP32 | BF16 | FP32 | FP32 | FP32 | 6452 | 7173 | 7200 | 7804 | 8243 |
// FP32 | FP32 | FP32 | FP32 | FP32 | FP32 | 5368 | 7074 | 7165 | 7740 | 8225 |
//
// FP16 | BF16 | FP16 | FP16 | FP32 | FP16 | 5908 | 6873 | 7076 | 7417 | 7745 |
// FP16 | BF16 | FP16 | FP16 | FP32 | FP32 | 6566 | 7598 | 7617 | 7993 | 8489 |
// FP16 | BF16 | FP16 | FP32 | FP32 | FP16 | 5896 | 6873 | 7070 | 7419 | 7739 |
// FP16 | BF16 | FP16 | FP32 | FP32 | FP32 | 5891 | 7602 | 7616 | 8011 | 8494 |
// FP16 | BF16 | BF16 | FP16 | FP32 | FP32 | 6650 | 7581 | 7572 | 7987 | 8462 |
// FP16 | BF16 | BF16 | FP32 | FP32 | FP32 | 6809 | 7534 | 7562 | 8027 | 8467 |
//
// BF16 | FP16 | FP16 | FP32 | FP16 | FP16 | 5937 | 6923 | 7108 | 7527 | 7930 |
// BF16 | FP16 | FP16 | FP32 | FP16 | FP32 | 6414 | 7554 | 7333 | 7981 | 8459 |
// BF16 | FP16 | FP16 | FP32 | FP32 | FP16 | 5976 | 6924 | 7083 | 7525 | 7929 |
// BF16 | FP16 | FP16 | FP32 | FP32 | FP32 | 5383 | 7559 | 7310 | 8008 | 8456 |
// BF16 | FP16 | BF16 | FP32 | FP16 | FP32 | 6575 | 7573 | 7512 | 8028 | 8469 |
// BF16 | FP16 | BF16 | FP32 | FP32 | FP32 | 6699 | 7530 | 7525 | 8000 | 8470 |
//
// BF16 | BF16 | FP16 | FP32 | FP32 | FP16 | 5756 | 6721 | 6780 | 7220 | 7583 |
// BF16 | BF16 | FP16 | FP32 | FP32 | FP32 | 6424 | 7206 | 6954 | 7547 | 7923 |
//
// Optimal register precisions for each memory precision.
//
// Truth Table:
//
// memA | memB | memC | regA | regB | regC |
// ---- | ---- | ---- | ---- | ---- | ---- |
// FP16 | FP16 | FP16 | FP16 | FP16 | FP16 |
// FP16 | FP16 | BF16 | FP16 | FP16 | FP32 |
// FP16 | FP16 | FP32 | FP16 | FP16 | FP32 |
// FP16 | BF16 | FP16 | FP32 | FP32 | FP32 |
// FP16 | BF16 | BF16 | FP32 | FP32 | FP32 |
// FP16 | BF16 | FP32 | FP32 | FP32 | FP32 |
// FP16 | FP32 | FP16 | FP32 | FP32 | FP32 |
// FP16 | FP32 | BF16 | FP32 | FP32 | FP32 |
// FP16 | FP32 | FP32 | FP32 | FP32 | FP32 |
//
// BF16 | FP16 | FP16 | FP32 | FP32 | FP32 |
// BF16 | FP16 | BF16 | FP32 | FP32 | FP32 |
// BF16 | FP16 | FP32 | FP32 | FP32 | FP32 |
// BF16 | BF16 | FP16 | FP32 | FP32 | FP32 |
// BF16 | BF16 | BF16 | FP32 | FP32 | FP32 |
// BF16 | BF16 | FP32 | FP32 | FP32 | FP32 |
// BF16 | FP32 | FP16 | FP32 | FP32 | FP32 |
// BF16 | FP32 | BF16 | FP32 | FP32 | FP32 |
// BF16 | FP32 | FP32 | FP32 | FP32 | FP32 |
//
// FP32 | FP16 | FP16 | FP32 | FP32 | FP32 |
// FP32 | FP16 | BF16 | FP32 | FP32 | FP32 |
// FP32 | FP16 | FP32 | FP32 | FP32 | FP32 |
// FP32 | BF16 | FP16 | FP32 | FP32 | FP32 |
// FP32 | BF16 | BF16 | FP32 | FP32 | FP32 |
// FP32 | BF16 | FP32 | FP32 | FP32 | FP32 |
// FP32 | FP32 | FP16 | FP32 | FP32 | FP32 |
// FP32 | FP32 | BF16 | FP32 | FP32 | FP32 |
// FP32 | FP32 | FP32 | FP32 | FP32 | FP32 |
//
// Optimized form of the logic:
//
// If memA and memB are FP16,
//   regA is FP16
//   regB is FP16
// else
//   regA is FP32
//   regB is FP32
// If memA, memB, and memC are FP16,
//   regC is FP16
// else
//   regC is FP32
//
// ## M4
//
// Configuration:
// - 32x32x8
// - inputs are not transposed
// - table cells report GFLOPS at problem sizes 512 through 1536
//
// memA | memB | memC | regA | regB | regC | 512 | 768 | 1024 | 1280 | 1536 |
// ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
// FP16 | FP32 | FP32 | FP16 | FP32 | FP32 | 2864 | 3093 | 3128 | 3201 | 3237 |
// FP16 | FP32 | FP32 | FP32 | FP32 | FP32 | 2823 | 3093 | 3128 | 3197 | 3236 |
// FP32 | FP16 | FP32 | FP32 | FP16 | FP32 | 2871 | 3136 | 3180 | 3238 | 3276 |
// FP32 | FP16 | FP32 | FP32 | FP32 | FP32 | 2853 | 3133 | 3180 | 3239 | 3275 |
// BF16 | FP32 | FP32 | BF16 | FP32 | FP32 | 2832 | 3089 | 3122 | 3199 | 3237 |
// BF16 | FP32 | FP32 | FP32 | FP32 | FP32 | 2735 | 2981 | 2999 | 3067 | 3102 |
// FP32 | BF16 | FP32 | FP32 | BF16 | FP32 | 2852 | 3134 | 3189 | 3241 | 3276 |
// FP32 | BF16 | FP32 | FP32 | FP32 | FP32 | 2742 | 3007 | 3053 | 3104 | 3139 |
//
// FP16 | FP16 | FP32 | FP16 | FP16 | FP32 | 2985 | 3285 | 3342 | 3417 | 3458 |
// FP16 | FP16 | FP32 | FP32 | FP32 | FP32 | 2987 | 3285 | 3348 | 3416 | 3458 |
// FP16 | BF16 | FP32 | FP16 | BF16 | FP32 | 3017 | 3290 | 3344 | 3417 | 3460 |
// FP16 | BF16 | FP32 | FP16 | FP32 | FP32 | 2861 | 3130 | 3179 | 3244 | 3282 |
// FP16 | BF16 | FP32 | FP32 | BF16 | FP32 | 2988 | 3281 | 3349 | 3419 | 3459 |
// FP16 | BF16 | FP32 | FP32 | FP32 | FP32 | 2887 | 3127 | 3177 | 3244 | 3281 |
// BF16 | FP16 | FP32 | BF16 | FP16 | FP32 | 2990 | 3284 | 3339 | 3418 | 3458 |
// BF16 | FP16 | FP32 | BF16 | FP32 | FP32 | 2712 | 3277 | 3332 | 3413 | 3459 |
// BF16 | FP16 | FP32 | FP32 | FP16 | FP32 | 2836 | 3108 | 3168 | 3221 | 3258 |
// BF16 | FP16 | FP32 | FP32 | FP32 | FP32 | 2855 | 3108 | 3162 | 3220 | 3257 |
// BF16 | BF16 | FP32 | BF16 | BF16 | FP32 | 2978 | 3285 | 3347 | 3417 | 3458 |
// BF16 | BF16 | FP32 | BF16 | FP32 | FP32 | 2868 | 3119 | 3182 | 3245 | 3281 |
// BF16 | BF16 | FP32 | FP32 | BF16 | FP32 | 2850 | 3114 | 3164 | 3218 | 3257 |
// BF16 | BF16 | FP32 | FP32 | FP32 | FP32 | 2707 | 2943 | 2991 | 3038 | 3070 |
//
// FP16 | FP16 | FP16 | FP16 | FP16 | FP16 | 3069 | 3334 | 3410 | 3471 | 3512 |
// FP16 | FP16 | FP16 | FP16 | FP16 | FP32 | 3025 | 3264 | 3318 | 3394 | 3430 |
// FP16 | FP16 | FP16 | FP16 | FP32 | FP16 | 2708 | 2932 | 2990 | 3039 | 3072 |
// FP16 | FP16 | FP16 | FP32 | FP16 | FP16 | 2714 | 2928 | 2987 | 3041 | 3074 |
// FP16 | FP16 | FP16 | FP32 | FP32 | FP32 | 3016 | 3260 | 3315 | 3392 | 3432 |
//
// BF16 | BF16 | BF16 | BF16 | BF16 | FP32 | 3025 | 3280 | 3324 | 3409 | 3445 |
// BF16 | BF16 | BF16 | BF16 | FP32 | FP32 | 2876 | 3111 | 3164 | 3233 | 3268 |
// BF16 | BF16 | BF16 | FP32 | BF16 | FP32 | 2857 | 3093 | 3138 | 3206 | 3241 |
// BF16 | BF16 | BF16 | FP32 | FP32 | FP32 | 2707 | 2883 | 2939 | 2982 | 3014 |
//
// FP32 | FP32 | FP16 | FP32 | FP32 | FP16 | 2535 | 2791 | 2769 | 2828 | 2861 |
// FP32 | FP32 | FP16 | FP32 | FP32 | FP32 | 2781 | 2999 | 3022 | 3077 | 3115 |
// FP32 | FP32 | BF16 | FP32 | FP32 | FP32 | 2809 | 3052 | 3090 | 3141 | 3178 |
// FP32 | FP32 | FP32 | FP32 | FP32 | FP32 | 2772 | 3017 | 3044 | 3101 | 3141 |
//
// FP16 | BF16 | FP16 | FP16 | BF16 | FP16 | 2702 | 2937 | 2984 | 3040 | 3072 |
// FP16 | BF16 | FP16 | FP16 | BF16 | FP32 | 3010 | 3258 | 3317 | 3392 | 3431 |
// FP16 | BF16 | FP16 | FP32 | BF16 | FP16 | 2712 | 2938 | 2990 | 3040 | 3073 |
// FP16 | BF16 | FP16 | FP32 | BF16 | FP32 | 2997 | 3260 | 3314 | 3394 | 3431 |
// FP16 | BF16 | BF16 | FP16 | BF16 | FP32 | 3028 | 3275 | 3324 | 3403 | 3447 |
// FP16 | BF16 | BF16 | FP32 | BF16 | FP32 | 3008 | 3274 | 3322 | 3407 | 3443 |
//
// BF16 | FP16 | FP16 | BF16 | FP16 | FP16 | 2705 | 2938 | 2987 | 3042 | 3072 |
// BF16 | FP16 | FP16 | BF16 | FP16 | FP32 | 3011 | 3255 | 3310 | 3392 | 3433 |
// BF16 | FP16 | FP16 | BF16 | FP32 | FP16 | 2698 | 2939 | 2988 | 3039 | 3073 |
// BF16 | FP16 | FP16 | BF16 | FP32 | FP32 | 3015 | 3261 | 3313 | 3391 | 3431 |
// BF16 | FP16 | BF16 | BF16 | FP16 | FP32 | 3027 | 3276 | 3323 | 3406 | 3445 |
// BF16 | FP16 | BF16 | BF16 | FP32 | FP32 | 3040 | 3275 | 3327 | 3407 | 3445 |
//
// BF16 | BF16 | FP16 | BF16 | BF16 | FP16 | 2709 | 2938 | 2987 | 3038 | 3073 |
// BF16 | BF16 | FP16 | BF16 | BF16 | FP32 | 3005 | 3260 | 3320 | 3394 | 3433 |
//
// Optimal register precisions for each memory precision.
//
// Truth Table:
//
// memA | memB | memC | regA | regB | regC |
// ---- | ---- | ---- | ---- | ---- | ---- |
// FP16 | FP16 | FP16 | FP16 | FP16 | FP16 |
// FP16 | FP16 | BF16 | FP16 | FP16 | FP32 |
// FP16 | FP16 | FP32 | FP16 | FP16 | FP32 |
// FP16 | BF16 | FP16 | FP16 | BF16 | FP32 |
// FP16 | BF16 | BF16 | FP16 | BF16 | FP32 |
// FP16 | BF16 | FP32 | FP16 | BF16 | FP32 |
// FP16 | FP32 | FP16 | FP16 | FP32 | FP32 |
// FP16 | FP32 | BF16 | FP16 | FP32 | FP32 |
// FP16 | FP32 | FP32 | FP16 | FP32 | FP32 |
//
// BF16 | FP16 | FP16 | BF16 | FP16 | FP32 |
// BF16 | FP16 | BF16 | BF16 | FP16 | FP32 |
// BF16 | FP16 | FP32 | BF16 | FP16 | FP32 |
// BF16 | BF16 | FP16 | BF16 | BF16 | FP32 |
// BF16 | BF16 | BF16 | BF16 | BF16 | FP32 |
// BF16 | BF16 | FP32 | BF16 | BF16 | FP32 |
// BF16 | FP32 | FP16 | BF16 | FP32 | FP32 |
// BF16 | FP32 | BF16 | BF16 | FP32 | FP32 |
// BF16 | FP32 | FP32 | BF16 | FP32 | FP32 |
//
// FP32 | FP16 | FP16 | FP32 | FP16 | FP32 |
// FP32 | FP16 | BF16 | FP32 | FP16 | FP32 |
// FP32 | FP16 | FP32 | FP32 | FP16 | FP32 |
// FP32 | BF16 | FP16 | FP32 | BF16 | FP32 |
// FP32 | BF16 | BF16 | FP32 | BF16 | FP32 |
// FP32 | BF16 | FP32 | FP32 | BF16 | FP32 |
// FP32 | FP32 | FP16 | FP32 | FP32 | FP32 |
// FP32 | FP32 | BF16 | FP32 | FP32 | FP32 |
// FP32 | FP32 | FP32 | FP32 | FP32 | FP32 |
//
// Optimized form of the logic:
//
// regA is identical to memA
// regB is identical to memB
// If memA, memB, and memC are FP16,
//   regC is FP16
// else
//   regC is FP32
//
// ## Conclusion
//
// Use a common set of logic for both architectures:
//
// ```
// regA is identical to memA
// regB is identical to memB
// If memA, memB, and memC are FP16,
//   regC is FP16
// else
//   regC is FP32
//
// If earlier than M3
//   If memA is BF16,
//     regA is FP32
//   If memB is BF16,
//     regB is FP32
// ```
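//
// Transcribed into Swift over the GEMMOperandPrecision enum defined later in
// this file (the `isAppleSiliconM3OrLater` flag stands in for whichever
// architecture check the caller performs):
//
// ```
// func chooseRegisterPrecisions(
//   memA: GEMMOperandPrecision,
//   memB: GEMMOperandPrecision,
//   memC: GEMMOperandPrecision,
//   isAppleSiliconM3OrLater: Bool
// ) -> (GEMMOperandPrecision, GEMMOperandPrecision, GEMMOperandPrecision) {
//   var regA = memA
//   var regB = memB
//   var regC = GEMMOperandPrecision.FP32
//   if memA == .FP16, memB == .FP16, memC == .FP16 {
//     regC = .FP16
//   }
//   if !isAppleSiliconM3OrLater {
//     // Apple7/Apple8 decode BF16 operands into FP32 registers.
//     if memA == .BF16 { regA = .FP32 }
//     if memB == .BF16 { regB = .FP32 }
//   }
//   return (regA, regB, regC)
// }
// ```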
//
// ========================================================================== //
// Tuning Transposes
// ========================================================================== //
//
// I am considering a simplification to the code, which merges two control
// blocks in simdgroup_matrix_storage.load.
//
// M1 Max, FP32xFP32->FP32
//
// Original code:
// problemSize = 1023 | A B | 6941 GFLOPS
// problemSize = 1023 | A B^T | 6781 GFLOPS
// problemSize = 1023 | A^T B | 7082 GFLOPS
// problemSize = 1023 | A^T B^T | 6763 GFLOPS
// problemSize = 1024 | A B | 6995 GFLOPS
// problemSize = 1024 | A B^T | 6901 GFLOPS
// problemSize = 1024 | A^T B | 7173 GFLOPS
// problemSize = 1024 | A^T B^T | 6926 GFLOPS
// problemSize = 1025 | A B | 7035 GFLOPS
// problemSize = 1025 | A B^T | 6699 GFLOPS
// problemSize = 1025 | A^T B | 7269 GFLOPS
// problemSize = 1025 | A^T B^T | 6927 GFLOPS
//
// With packed_vec for threadgroup loads:
// problemSize = 1023 | A B | 6925 GFLOPS
// problemSize = 1023 | A B^T | 6786 GFLOPS
// problemSize = 1023 | A^T B | 6977 GFLOPS
// problemSize = 1023 | A^T B^T | 6770 GFLOPS
// problemSize = 1024 | A B | 6979 GFLOPS
// problemSize = 1024 | A B^T | 6894 GFLOPS
// problemSize = 1024 | A^T B | 7109 GFLOPS
// problemSize = 1024 | A^T B^T | 6931 GFLOPS
// problemSize = 1025 | A B | 7005 GFLOPS
// problemSize = 1025 | A B^T | 6707 GFLOPS
// problemSize = 1025 | A^T B | 7225 GFLOPS
// problemSize = 1025 | A^T B^T | 6943 GFLOPS
//
// Significant decrease (-39 GFLOPS) for 1023.
// Significant decrease (-29 GFLOPS) for 1024.
// Significant decrease (-22 GFLOPS) for 1025.
//
// Conclusion:
//
// Originally, I thought the code change was safe. I had gathered evidence, and
// saw no detectable performance delta after the change. However, I had
// forgotten to toggle a few source lines between commented out and active. The
// corrected test showed a major regression when packed_vec<T, 2> was used for
// threadgroup memory accesses with power-of-two row sizes.
//
// For odd-numbered row sizes (and by extension, odd-numbered problem sizes in
// device memory), there is no change. Still, I am not going to switch any more
// lines of code to packed vectors. The evidence is too murky, and the only
// benefit of the changes would be making the code more legible.
//
// This is a good example of the thought process when writing high-performance
// code. You need reliable methods to gather data (without statistical noise)
// and compare whether each tiny change increases or decreases performance.
// Optimizations that harm performance or provide no net gain don't get
// included. Assume there is always a regression hiding in the combinatorial
// explosion of problem configurations. Find that regression, or prove that it
// doesn't exist.
//
// ========================================================================== //
// Tuning Async Copies on M3+
// ========================================================================== //
//
// Tasks:
// - In the device loading iterations, some garbage values are being read
// from out-of-bounds addresses. Fix this without harming performance.
// - Does the fix also remove the IOCommandBuffer errors on M1?
// - Try to further minimize the overhead of async copies on M4.
//
// I removed the alpha and beta constants, which simplifies the shader. Some of
// the fixes to the problems above would require rearchitecting the code that
// handles non-zero beta. Use cases such as accumulating into an existing
// accumulator would likely require more flexibility than a scalar argument in
// a pre-written kernel (e.g. one would write their own fused activation shader
// from scratch). Alpha/beta are technical debt from the decades-old BLAS
// interface and out of scope for this reference implementation.
//
// Next, I am changing how the code handles matrix edges.
//
// Correctness tests:
// - Precision: FP32xFP32->FP32
// - Problem size: (8, 16, 24, 32, 48, 64, 104, 128)(±1)
// - Transpose state: AB, AB^T, A^T B
//
// Performance tests:
// - Avoid any regressions from the following statistics.
// - Executed the program 3 times and took the maximum value of every row of
// the tables below. This approach minimizes statistical noise. The maximum
// performance achievable by two similar programs is more amenable to
// rigorous analysis than the average performance, especially when the
// performance delta of some changes is so small that it would be drowned
// out by statistical noise.
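//
// A minimal in-process version of that measurement policy (the `runTrial`
// closure is a placeholder for dispatching and timing one GEMM):
//
// ```
// func bestGFLOPS(problemSize: Int, trials: Int = 3,
//                 runTrial: () -> Double) -> Double {
//   var best: Double = 0
//   for _ in 0..<trials {
//     let seconds = runTrial() // wall-clock latency of one dispatch
//     let n = Double(problemSize)
//     let operations = 2 * n * n * n
//     best = max(best, operations / seconds / 1e9)
//   }
//   return best
// }
// ```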
//
// M1 Max:
// FP32 | problemSize = 1535 | A B | 8170 GFLOPS
// FP32 | problemSize = 1535 | A B^T | 7876 GFLOPS
// FP32 | problemSize = 1536 | A B | 8229 GFLOPS
// FP32 | problemSize = 1536 | A B^T | 8085 GFLOPS
// FP32 | problemSize = 1537 | A B | 7607 GFLOPS
// FP32 | problemSize = 1537 | A B^T | 7237 GFLOPS
// FP16 | problemSize = 1535 | A B | 9064 GFLOPS
// FP16 | problemSize = 1535 | A B^T | 8978 GFLOPS
// FP16 | problemSize = 1536 | A B | 9163 GFLOPS
// FP16 | problemSize = 1536 | A B^T | 9063 GFLOPS
// FP16 | problemSize = 1537 | A B | 8558 GFLOPS
// FP16 | problemSize = 1537 | A B^T | 8555 GFLOPS
// BF16 | problemSize = 1535 | A B | 8261 GFLOPS
// BF16 | problemSize = 1535 | A B^T | 8283 GFLOPS
// BF16 | problemSize = 1536 | A B | 8338 GFLOPS
// BF16 | problemSize = 1536 | A B^T | 8394 GFLOPS
// BF16 | problemSize = 1537 | A B | 7730 GFLOPS
// BF16 | problemSize = 1537 | A B^T | 7893 GFLOPS
//
// M4:
// FP32 | problemSize = 1023 | A B | 3030 GFLOPS
// FP32 | problemSize = 1023 | A B^T | 2783 GFLOPS
// FP32 | problemSize = 1024 | A B | 3042 GFLOPS
// FP32 | problemSize = 1024 | A B^T | 3048 GFLOPS
// FP32 | problemSize = 1025 | A B | 2859 GFLOPS
// FP32 | problemSize = 1025 | A B^T | 2650 GFLOPS
// FP16 | problemSize = 1023 | A B | 3352 GFLOPS
// FP16 | problemSize = 1023 | A B^T | 3288 GFLOPS
// FP16 | problemSize = 1024 | A B | 3407 GFLOPS
// FP16 | problemSize = 1024 | A B^T | 3397 GFLOPS
// FP16 | problemSize = 1025 | A B | 3182 GFLOPS
// FP16 | problemSize = 1025 | A B^T | 3117 GFLOPS
// BF16 | problemSize = 1023 | A B | 3319 GFLOPS
// BF16 | problemSize = 1023 | A B^T | 3266 GFLOPS
// BF16 | problemSize = 1024 | A B | 3324 GFLOPS
// BF16 | problemSize = 1024 | A B^T | 3317 GFLOPS
// BF16 | problemSize = 1025 | A B | 3151 GFLOPS
// BF16 | problemSize = 1025 | A B^T | 3088 GFLOPS
//
// I got correctness working with the change to how edges are processed.
// Next, I need to get performance working.
//
// M1 Max:
// FP32 | problemSize = 1535 | A B | 8195 GFLOPS
// FP32 | problemSize = 1535 | A B^T | 7830 GFLOPS
// FP32 | problemSize = 1536 | A B | 8231 GFLOPS
// FP32 | problemSize = 1536 | A B^T | 8090 GFLOPS
// FP32 | problemSize = 1537 | A B | 7559 GFLOPS
// FP32 | problemSize = 1537 | A B^T | 7148 GFLOPS
// FP16 | problemSize = 1535 | A B | 9087 GFLOPS
// FP16 | problemSize = 1535 | A B^T | 8993 GFLOPS
// FP16 | problemSize = 1536 | A B | 9154 GFLOPS
// FP16 | problemSize = 1536 | A B^T | 9071 GFLOPS
// FP16 | problemSize = 1537 | A B | 8547 GFLOPS
// FP16 | problemSize = 1537 | A B^T | 8464 GFLOPS
// BF16 | problemSize = 1535 | A B | 8209 GFLOPS
// BF16 | problemSize = 1535 | A B^T | 8237 GFLOPS
// BF16 | problemSize = 1536 | A B | 8343 GFLOPS
// BF16 | problemSize = 1536 | A B^T | 8395 GFLOPS
// BF16 | problemSize = 1537 | A B | 7704 GFLOPS
// BF16 | problemSize = 1537 | A B^T | 7748 GFLOPS
//
// -14 GFLOPS for 1535.
// +2 GFLOPS for 1536.
// -68 GFLOPS for 1537.
// A larger performance drop for 1537 is expected: almost every element in the
// edge block is redundantly computed and written.
//
// M4:
// FP32 | problemSize = 1023 | A B | 3068 GFLOPS
// FP32 | problemSize = 1023 | A B^T | 2800 GFLOPS
// FP32 | problemSize = 1024 | A B | 3071 GFLOPS
// FP32 | problemSize = 1024 | A B^T | 3040 GFLOPS
// FP32 | problemSize = 1025 | A B | 2887 GFLOPS
// FP32 | problemSize = 1025 | A B^T | 2647 GFLOPS
// FP16 | problemSize = 1023 | A B | 3398 GFLOPS
// FP16 | problemSize = 1023 | A B^T | 3323 GFLOPS
// FP16 | problemSize = 1024 | A B | 3421 GFLOPS
// FP16 | problemSize = 1024 | A B^T | 3402 GFLOPS
// FP16 | problemSize = 1025 | A B | 3220 GFLOPS
// FP16 | problemSize = 1025 | A B^T | 3158 GFLOPS
// BF16 | problemSize = 1023 | A B | 3329 GFLOPS
// BF16 | problemSize = 1023 | A B^T | 3276 GFLOPS
// BF16 | problemSize = 1024 | A B | 3317 GFLOPS
// BF16 | problemSize = 1024 | A B^T | 3327 GFLOPS
// BF16 | problemSize = 1025 | A B | 3154 GFLOPS
// BF16 | problemSize = 1025 | A B^T | 3099 GFLOPS
//
// +26 GFLOPS for 1023.
// +8 GFLOPS for 1024.
// +20 GFLOPS for 1025.
//
// If every write goes through threadgroup memory (async copy) instead of
// device memory on M1, here is the alternative perf delta.
//
// FP32 | problemSize = 1535 | A B | 8226 GFLOPS
// FP32 | problemSize = 1535 | A B^T | 7862 GFLOPS
// FP32 | problemSize = 1536 | A B | 8251 GFLOPS
// FP32 | problemSize = 1536 | A B^T | 8124 GFLOPS
// FP32 | problemSize = 1537 | A B | 7615 GFLOPS
// FP32 | problemSize = 1537 | A B^T | 7202 GFLOPS
// FP16 | problemSize = 1535 | A B | 9011 GFLOPS
// FP16 | problemSize = 1535 | A B^T | 9012 GFLOPS
// FP16 | problemSize = 1536 | A B | 9162 GFLOPS
// FP16 | problemSize = 1536 | A B^T | 9049 GFLOPS
// FP16 | problemSize = 1537 | A B | 8575 GFLOPS
// FP16 | problemSize = 1537 | A B^T | 8539 GFLOPS
// BF16 | problemSize = 1535 | A B | 7980 GFLOPS
// BF16 | problemSize = 1535 | A B^T | 7954 GFLOPS
// BF16 | problemSize = 1536 | A B | 8352 GFLOPS
// BF16 | problemSize = 1536 | A B^T | 8399 GFLOPS
// BF16 | problemSize = 1537 | A B | 7716 GFLOPS
// BF16 | problemSize = 1537 | A B^T | 7906 GFLOPS
//
// -98 GFLOPS for 1535.
// +11 GFLOPS for 1536.
// -5 GFLOPS for 1537.
//
// The evidence reveals this:
// - write directly to device memory:
// -14 GFLOPS for 1535
// +2 GFLOPS for 1536
// - write indirectly via threadgroup memory:
// -5 GFLOPS for 1537
//
// Focus on three optimizations:
// - Allowing extraneous writes to device or threadgroup memory to be elided.
// - Minimizing the overhead of pointer addressing arithmetic, restoring the
// performance of the kernel before this change.
// - Avoiding bank conflicts when writing the accumulator on M1 (if
// threadgroup memory is a common execution path).
//
// Eliding unnecessary writes, and going through device memory:
//
// M1 Max
// FP32 | problemSize = 1535 | A B | 8199 GFLOPS
// FP32 | problemSize = 1535 | A B^T | 7828 GFLOPS
// FP32 | problemSize = 1536 | A B | 8239 GFLOPS
// FP32 | problemSize = 1536 | A B^T | 8100 GFLOPS
// FP32 | problemSize = 1537 | A B | 7535 GFLOPS
// FP32 | problemSize = 1537 | A B^T | 7145 GFLOPS
// FP16 | problemSize = 1535 | A B | 9092 GFLOPS
// FP16 | problemSize = 1535 | A B^T | 8995 GFLOPS
// FP16 | problemSize = 1536 | A B | 9175 GFLOPS
// FP16 | problemSize = 1536 | A B^T | 9077 GFLOPS
// FP16 | problemSize = 1537 | A B | 8508 GFLOPS
// FP16 | problemSize = 1537 | A B^T | 8488 GFLOPS
// BF16 | problemSize = 1535 | A B | 8224 GFLOPS
// BF16 | problemSize = 1535 | A B^T | 8245 GFLOPS
// BF16 | problemSize = 1536 | A B | 8349 GFLOPS
// BF16 | problemSize = 1536 | A B^T | 8402 GFLOPS
// BF16 | problemSize = 1537 | A B | 7708 GFLOPS
// BF16 | problemSize = 1537 | A B^T | 7866 GFLOPS
//
// -8 GFLOPS for 1535.
// +12 GFLOPS for 1536.
// -60 GFLOPS for 1537.
//
// M4
// FP32 | problemSize = 1023 | A B | 3064 GFLOPS
// FP32 | problemSize = 1023 | A B^T | 2783 GFLOPS
// FP32 | problemSize = 1024 | A B | 3064 GFLOPS
// FP32 | problemSize = 1024 | A B^T | 3039 GFLOPS
// FP32 | problemSize = 1025 | A B | 2886 GFLOPS
// FP32 | problemSize = 1025 | A B^T | 2647 GFLOPS
// FP16 | problemSize = 1023 | A B | 3405 GFLOPS
// FP16 | problemSize = 1023 | A B^T | 3333 GFLOPS
// FP16 | problemSize = 1024 | A B | 3417 GFLOPS
// FP16 | problemSize = 1024 | A B^T | 3400 GFLOPS
// FP16 | problemSize = 1025 | A B | 3217 GFLOPS
// FP16 | problemSize = 1025 | A B^T | 3163 GFLOPS
// BF16 | problemSize = 1023 | A B | 3331 GFLOPS
// BF16 | problemSize = 1023 | A B^T | 3277 GFLOPS
// BF16 | problemSize = 1024 | A B | 3318 GFLOPS
// BF16 | problemSize = 1024 | A B^T | 3325 GFLOPS
// BF16 | problemSize = 1025 | A B | 3154 GFLOPS
// BF16 | problemSize = 1025 | A B^T | 3097 GFLOPS
//
// +26 GFLOPS for 1023.
// +5 GFLOPS for 1024.
// +20 GFLOPS for 1025.
//
// Eliding unnecessary writes, and going through threadgroup memory:
//
// M1 Max
// FP32 | problemSize = 1535 | A B | 8221 GFLOPS
// FP32 | problemSize = 1535 | A B^T | 7861 GFLOPS
// FP32 | problemSize = 1536 | A B | 8249 GFLOPS
// FP32 | problemSize = 1536 | A B^T | 8126 GFLOPS
// FP32 | problemSize = 1537 | A B | 7611 GFLOPS
// FP32 | problemSize = 1537 | A B^T | 7203 GFLOPS
// FP16 | problemSize = 1535 | A B | 8995 GFLOPS
// FP16 | problemSize = 1535 | A B^T | 9012 GFLOPS
// FP16 | problemSize = 1536 | A B | 9190 GFLOPS
// FP16 | problemSize = 1536 | A B^T | 9046 GFLOPS
// FP16 | problemSize = 1537 | A B | 8574 GFLOPS
// FP16 | problemSize = 1537 | A B^T | 8512 GFLOPS
// BF16 | problemSize = 1535 | A B | 7988 GFLOPS
// BF16 | problemSize = 1535 | A B^T | 7962 GFLOPS
// BF16 | problemSize = 1536 | A B | 8358 GFLOPS
// BF16 | problemSize = 1536 | A B^T | 8405 GFLOPS
// BF16 | problemSize = 1537 | A B | 7709 GFLOPS
// BF16 | problemSize = 1537 | A B^T | 7872 GFLOPS
//
// -99 GFLOPS for 1535.
// +17 GFLOPS for 1536.
// -17 GFLOPS for 1537.
//
// M4
// FP32 | problemSize = 1023 | A B | 3065 GFLOPS
// FP32 | problemSize = 1023 | A B^T | 2795 GFLOPS
// FP32 | problemSize = 1024 | A B | 3065 GFLOPS
// FP32 | problemSize = 1024 | A B^T | 3030 GFLOPS
// FP32 | problemSize = 1025 | A B | 2874 GFLOPS
// FP32 | problemSize = 1025 | A B^T | 2648 GFLOPS
// FP16 | problemSize = 1023 | A B | 3375 GFLOPS
// FP16 | problemSize = 1023 | A B^T | 3313 GFLOPS
// FP16 | problemSize = 1024 | A B | 3403 GFLOPS
// FP16 | problemSize = 1024 | A B^T | 3384 GFLOPS
// FP16 | problemSize = 1025 | A B | 3192 GFLOPS
// FP16 | problemSize = 1025 | A B^T | 3139 GFLOPS
// BF16 | problemSize = 1023 | A B | 3334 GFLOPS
// BF16 | problemSize = 1023 | A B^T | 3256 GFLOPS
// BF16 | problemSize = 1024 | A B | 3344 GFLOPS
// BF16 | problemSize = 1024 | A B^T | 3308 GFLOPS
// BF16 | problemSize = 1025 | A B | 3156 GFLOPS
// BF16 | problemSize = 1025 | A B^T | 3093 GFLOPS
//
// +17 GFLOPS for 1023.
// -0 GFLOPS for 1024.
// +9 GFLOPS for 1025.
//
// I want to test whether any performance changes are simply due to changes
// to the pointer addressing code. I will modify the code so that no memory
// writes are elided. Then, benchmark the following configurations:
// - M1 Max
// - 1535, to device memory
// - 1536, to device memory
// - 1537, to threadgroup memory
// - M4
// - 1023, to device memory
// - 1024, to device memory
// - 1025, to device memory
//
// M1 Max
// FP32 | problemSize = 1535 | A B | 8192 GFLOPS
// FP32 | problemSize = 1535 | A B^T | 7834 GFLOPS
// FP32 | problemSize = 1536 | A B | 8245 GFLOPS
// FP32 | problemSize = 1536 | A B^T | 8097 GFLOPS
// FP32 | problemSize = 1537 | A B | 7618 GFLOPS
// FP32 | problemSize = 1537 | A B^T | 7206 GFLOPS
// FP16 | problemSize = 1535 | A B | 9079 GFLOPS
// FP16 | problemSize = 1535 | A B^T | 9007 GFLOPS
// FP16 | problemSize = 1536 | A B | 9183 GFLOPS
// FP16 | problemSize = 1536 | A B^T | 9078 GFLOPS
// FP16 | problemSize = 1537 | A B | 8581 GFLOPS
// FP16 | problemSize = 1537 | A B^T | 8541 GFLOPS
// BF16 | problemSize = 1535 | A B | 8219 GFLOPS
// BF16 | problemSize = 1535 | A B^T | 8247 GFLOPS
// BF16 | problemSize = 1536 | A B | 8351 GFLOPS
// BF16 | problemSize = 1536 | A B^T | 8399 GFLOPS
// BF16 | problemSize = 1537 | A B | 7723 GFLOPS
// BF16 | problemSize = 1537 | A B^T | 7908 GFLOPS
//
// -9 GFLOPS for 1535.
// +14 GFLOPS for 1536.
// -1 GFLOPS for 1537.
//
// M4
// FP32 | problemSize = 1023 | A B | 3067 GFLOPS
// FP32 | problemSize = 1023 | A B^T | 2780 GFLOPS
// FP32 | problemSize = 1024 | A B | 3072 GFLOPS
// FP32 | problemSize = 1024 | A B^T | 3038 GFLOPS
// FP32 | problemSize = 1025 | A B | 2888 GFLOPS
// FP32 | problemSize = 1025 | A B^T | 2644 GFLOPS
// FP16 | problemSize = 1023 | A B | 3394 GFLOPS
// FP16 | problemSize = 1023 | A B^T | 3325 GFLOPS
// FP16 | problemSize = 1024 | A B | 3418 GFLOPS
// FP16 | problemSize = 1024 | A B^T | 3400 GFLOPS
// FP16 | problemSize = 1025 | A B | 3220 GFLOPS
// FP16 | problemSize = 1025 | A B^T | 3158 GFLOPS
// BF16 | problemSize = 1023 | A B | 3330 GFLOPS
// BF16 | problemSize = 1023 | A B^T | 3275 GFLOPS
// BF16 | problemSize = 1024 | A B | 3329 GFLOPS
// BF16 | problemSize = 1024 | A B^T | 3325 GFLOPS
// BF16 | problemSize = 1025 | A B | 3153 GFLOPS
// BF16 | problemSize = 1025 | A B^T | 3098 GFLOPS
//
// +22 GFLOPS for 1023.
// +8 GFLOPS for 1024.
// +19 GFLOPS for 1025.
//
// The write elision optimization has failed, so I need to revert it. I did
// succeed at shifting the matrix origin in-bounds, allowing M4 to avoid async
// stores for almost every matrix in practical use cases. The new code has
// minimal changes to M1 performance and non-negligible improvements to M4
// performance. Averaging the ALU utilization for all precisions:
//
// M1 Max (max 10617 GFLOPS)
// 1535: 79.5% -> 79.4%
// 1536: 80.5% -> 80.6%
// 1537: 74.7% -> 74.7%
//
// M4 (max ~3580 GFLOPS)
// 1023: 88.6% -> 89.3%
// 1024: 90.9% -> 91.2%
// 1025: 84.0% -> 84.5%
//
// The M1 architecture changes only by rounding error (0.1%). The M3
// architecture increases by anywhere from +0.3% to +0.7%. This evidence seems
// high-quality enough to justify the recent refactoring of the shader.
//
// There is one slight complication. To avoid a regression on M1, I need to
// switch between device and threadgroup writing based on the result of
// (problem size / block size). I'm going to investigate whether there are
// bank conflicts on M1. If there are, and they can be removed, the
// 'threadgroup' path could always be the fastest, and I would not need to
// engineer logic for switching between shader variants.
//
// ========================================================================== //
// Tuning BF16 on M1-M2
// ========================================================================== //
//
// I gathered some evidence and noticed strange patterns in which option is
// faster for M1 / BF16: device or threadgroup stores?
//
// https://gist.github.com/philipturner/1c157bf87702420e51c6deb68e70d078
//
// Here are some raw chat messages from Discord:
//
// > Today I found out a lot of strange things about my shader, that may be the
// > reason it's faster than MPS.
// >
// > The K dimension of the block (accumulator) is being unrolled.
// >
// > But all the registers for K are being allocated. Normally a SIMD allocates
// > 32 x 32 for accumulator, 32 x 8 for A, 32 x 8 for B. This is the way MPS
// > does it, and how any other Apple GPU (e.g. M3) does it. How the very old
// > Tinygrad reference implementation did it. Also the fastest when SIMD async
// > copy is not used.
// >
// > MFA has 24 x 24 for accumulator, 24 x 24 for A, and 24 x 24 for B. So at
// > first glance, almost 3 times the register pressure if the accumulator only
// > was stored. That means occupancy is terribly low. Revealed itself for BF16
// > where although threadgroup memory is BF16, registers are FP32.
// >
// > I tried fixing it, but every alternative lowered performance. Apparently
// > the biggest bottleneck is either addressing overhead or latency of waiting
// > on threadgroup -> thread loads (somehow the GPU needs a heck ton of
// > latency hiding and wants to load ~24-40 inner loop iterations of data up
// > front).
// >
// > MPS would not have found this strange performance maximum for a number of
// > reasons. They don't have good support for matrices indivisible by block
// > size. 32 x 32 + 32 x 16 + 32 x 16 would cause severe enough register
// > pressure to be nonviable. Only 24 x 24 + 24 x 24 + 24 x 24 can work with
// > this. And it needs to be paired with async copy, because the small M/N in
// > registers requires high bandwidth. Only possible when reading from
// > threadgroup memory (higher bandwidth than device memory).
// >
// > For FP16 the maximum is 24 x 24 + 24 x 32 + 24 x 32 (registers are FP16)
// > and BF16 is strangely 24 x 24 + 24 x 32 + 24 x 32 when the registers are
// > FP32. Extremely low occupancy for this performance maximum. That's why the
// > heuristics for kernel selection logic will be tricky.
// >
// > I think I can design the logic for M1 + BF16 in a day.
//
// That is exactly what I will do to complete the shader. Compile a large
// volume of performance data, create truth tables, and solve the inverse
// problem of predicting the fastest configuration.
//
// Data sheet:
//
// https://gist.github.com/philipturner/45781f4515145106fc0d4e598dd5f13b
//
// Insights from the data:
//
// Always allocate an accumulator's worth of threadgroup memory, even if it
// will never be used. This speeds up M3 for some reason, except for a -50 to
// -100 GFLOPS regression when all operands are FP16 and the matrix dimension
// is indivisible by 8. Most importantly, the choice will simplify the code.
// This is an example of weighing the pros and cons of an optimization.
//
// > Gathering and interpreting data is (or should be) science, but deciding
// > how to act on that data is engineering. A general rule: choose simpler
// > logic that will extrapolate beyond the training data. Favor design choices
// > that are robust against edge cases, even if that sacrifices performance on
// > benchmarks.
//
// Next, I found the primary reason for inconsistent performance with M1 +
// BF16. I had not finished optimizing the FP32 -> BF16 encoding when storing
// data to memory. The fastest path is the one that uses two separate
// store instructions. There is no evidence this change should be applied to
// precisions besides BF16. I will modify the codegen for 'store_bfloat' in
// 'simdgroup_matrix_storage'.
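//
// As a rough sketch of what that codegen might emit: BF16 on M1 is the high 16
// bits of an FP32 value, so the store path can bit-cast each FP32 register and
// write the upper halves with two separate scalar stores. The MSL below is
// schematic, not the exact body of 'store_bfloat'.
//
// ```
// func storeBF16Body() -> String {
//   """
//   // 'src' holds a vec<float, 2> in registers; 'dst' points to bfloat
//   // storage, viewed as ushort.
//   dst[0] = as_type<ushort2>(src[0])[1]; // high 16 bits of element 0
//   dst[1] = as_type<ushort2>(src[1])[1]; // high 16 bits of element 1
//   """
// }
// ```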
//
// Finally, I had to decide whether to store the accumulator with an async
// copy. The decision logic:
//
// ```
// if the matrix dimensions M and/or N are smaller than the corresponding
//    block size
//   use async copy
// else if M3+
//   do not use async copy
// else
//   continue to next statement
//
// if the matrix is small (32 x 32 x 32 block)
//   decouple the concern of optimizing small matrices from the concern of
//   async stores
//
//   use async copy by default
// else if any operand is FP32
//   decouple the concern of eliminating bank conflicts from the concern of
//   async stores
//
//   use async copy by default
// else
//   continue to next statement
//
// assert that block size is (M = 48, N = 48, K = 32)
// compile four kernel variants:
// - store directly to device, K = 32
// - store directly to device, K = 40
// - store indirectly to threadgroup, K = 32
// - store indirectly to threadgroup, K = 40
// create an MTLComputePipelineState for each variant
//
// sort the pipelines by occupancy
// if two pipelines have a different occupancy
//   the higher occupancy wins
// else
//   choose the last pipeline (they're ordered by increasing performance)
// ```
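//
// The 'sort the pipelines by occupancy' step can lean on the occupancy proxy
// Metal reports after compilation. A hedged sketch, assuming the caller has
// already built one MTLFunction per variant, ordered by increasing expected
// performance:
//
// ```
// func selectPipeline(
//   device: MTLDevice,
//   variants: [MTLFunction] // ordered by increasing expected performance
// ) throws -> MTLComputePipelineState {
//   let pipelines = try variants.map {
//     try device.makeComputePipelineState(function: $0)
//   }
//   // maxTotalThreadsPerThreadgroup serves as a stand-in for occupancy here.
//   let bestOccupancy = pipelines.map {
//     $0.maxTotalThreadsPerThreadgroup
//   }.max()!
//   // Among variants that tie on occupancy, prefer the last (fastest) one.
//   return pipelines.last(where: {
//     $0.maxTotalThreadsPerThreadgroup == bestOccupancy
//   })!
// }
// ```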
//
// Data for validating correct implementation of the logic:
//
// M1 Max, BF16
//
// K = 32
// problemSize = 1488 | A B | 896 threads/core | 8360 GFLOPS
// problemSize = 1488 | A B^T | 1024 threads/core | 8680 GFLOPS
// problemSize = 1488 | A^T B | 1024 threads/core | 8790 GFLOPS
// problemSize = 1488 | A^T B^T | 1024 threads/core | 9036 GFLOPS
// problemSize = 1489 | A B | 768 threads/core | 7975 GFLOPS
// problemSize = 1489 | A B^T | 832 threads/core | 8086 GFLOPS
// problemSize = 1489 | A^T B | 768 threads/core | 8177 GFLOPS
// problemSize = 1489 | A^T B^T | 832 threads/core | 8454 GFLOPS
//
// K = 40
// problemSize = 1488 | A B | 640 threads/core | 8183 GFLOPS
// problemSize = 1488 | A B^T | 704 threads/core | 8421 GFLOPS
// problemSize = 1488 | A^T B | 768 threads/core | 8654 GFLOPS
// problemSize = 1488 | A^T B^T | 768 threads/core | 8890 GFLOPS
// problemSize = 1489 | A B | 768 threads/core | 8018 GFLOPS
// problemSize = 1489 | A B^T | 832 threads/core | 8382 GFLOPS
// problemSize = 1489 | A^T B | 768 threads/core | 8320 GFLOPS
// problemSize = 1489 | A^T B^T | 832 threads/core | 8637 GFLOPS
//
// Correctness was confirmed:
//
// problemSize = 1488 | A B | 896 threads/core | 8358 GFLOPS
// problemSize = 1488 | A B^T | 1024 threads/core | 8674 GFLOPS
// problemSize = 1488 | A^T B | 1024 threads/core | 8794 GFLOPS
// problemSize = 1488 | A^T B^T | 1024 threads/core | 9024 GFLOPS
// problemSize = 1489 | A B | 768 threads/core | 8028 GFLOPS
// problemSize = 1489 | A B^T | 832 threads/core | 8374 GFLOPS
// problemSize = 1489 | A^T B | 832 threads/core | 8358 GFLOPS
// problemSize = 1489 | A^T B^T | 832 threads/core | 8645 GFLOPS
//
// This concludes the development of the unified GEMM kernel. The only further
// changes to this GitHub gist will be bug fixes and typo fixes.
//
// ========================================================================== //
// ========================================================================== //
// ========================================================================== //
// ## Addendum
//
// I fixed the issues with the shader caching mechanism. It now has a nearly
// 100% hit rate for the library cache. Interestingly, MTLLibrary creation and
// MTLComputePipelineState specialization each take roughly 30 ms, so the
// end-to-end latency only decreased from 60 ms to 30 ms in my tests.
//
// Expected performance:
//
// problemSize = 511 | A B | 1024 threads/core | 6706 GFLOPS
// problemSize = 511 | A B^T | 896 threads/core | 5631 GFLOPS
// problemSize = 511 | A^T B | 896 threads/core | 5824 GFLOPS
// problemSize = 511 | A^T B^T | 1024 threads/core | 6991 GFLOPS
// problemSize = 512 | A B | 1024 threads/core | 6842 GFLOPS
// problemSize = 512 | A B^T | 1024 threads/core | 6938 GFLOPS
// problemSize = 512 | A^T B | 896 threads/core | 5933 GFLOPS
// problemSize = 512 | A^T B^T | 1024 threads/core | 7208 GFLOPS
//
// problemSize = 1488 | A B | 896 threads/core | 8375 GFLOPS
// problemSize = 1488 | A B^T | 1024 threads/core | 8685 GFLOPS
// problemSize = 1488 | A^T B | 1024 threads/core | 8804 GFLOPS
// problemSize = 1488 | A^T B^T | 1024 threads/core | 9039 GFLOPS
// problemSize = 1489 | A B | 768 threads/core | 8049 GFLOPS
// problemSize = 1489 | A B^T | 832 threads/core | 8387 GFLOPS
// problemSize = 1489 | A^T B | 832 threads/core | 8381 GFLOPS
// problemSize = 1489 | A^T B^T | 832 threads/core | 8652 GFLOPS
//
// Performance after the change:
//
// problemSize = 511 | A B | 1024 threads/core | 6784 GFLOPS
// problemSize = 511 | A B^T | 896 threads/core | 5651 GFLOPS
// problemSize = 511 | A^T B | 896 threads/core | 5818 GFLOPS
// problemSize = 511 | A^T B^T | 1024 threads/core | 6997 GFLOPS
// problemSize = 512 | A B | 1024 threads/core | 6820 GFLOPS
// problemSize = 512 | A B^T | 1024 threads/core | 6915 GFLOPS
// problemSize = 512 | A^T B | 896 threads/core | 5966 GFLOPS
// problemSize = 512 | A^T B^T | 1024 threads/core | 7168 GFLOPS
//
// problemSize = 1488 | A B | 896 threads/core | 8369 GFLOPS
// problemSize = 1488 | A B^T | 1024 threads/core | 8678 GFLOPS
// problemSize = 1488 | A^T B | 1024 threads/core | 8808 GFLOPS
// problemSize = 1488 | A^T B^T | 1024 threads/core | 9040 GFLOPS
// problemSize = 1489 | A B | 768 threads/core | 8048 GFLOPS
// problemSize = 1489 | A B^T | 832 threads/core | 8391 GFLOPS
// problemSize = 1489 | A^T B | 832 threads/core | 8381 GFLOPS
// problemSize = 1489 | A^T B^T | 832 threads/core | 8648 GFLOPS
// MARK: - GEMM Kernel
/// An enumeration of the precisions supported by the kernel.
///
/// If you wish to support quantized precisions, copy/translate the source code
/// and integrate a modified version into your app. Something similar to a Swift
/// `enum` (e.g. C++ `enum class`) could enumerate the quantization formats
/// used by application code. An example set could be:
/// - FP32
/// - FP16
/// - BF16
/// - signed 8-bit integer
/// - s1ezm7
/// - FP8
/// - palettized
///
/// If you support non-floating-point formats, you have the responsibility of
/// authoring correct and performant GPU code for them. A general rule of thumb
/// is to keep the data compressed in `device` or `threadgroup` memory, and to
/// transform it into a floating point type while loading into registers. Keep
/// the accumulator in floating point until the output needs to be written.
/// If the output is quantized, it will be compressed when writing back to
/// `device` memory (or `threadgroup` before the async copy in edge cases).
///
/// For example, the reference implementation treats BF16 like a quantized
/// integer type on Apple7 and Apple8 GPUs. It is decompressed to FP32 in
/// registers.
enum GEMMOperandPrecision: UInt16 {
case FP32 = 0
case FP16 = 1
case BF16 = 2
}
/// A configuration for a GEMM kernel.
///
/// The information in this data structure is enough to uniquely identify the
/// kernel. It can be used as a key in a key-value cache.
///
/// ## Usage
///
/// The code for generating the GEMM kernel does not include any assumptions
/// about performance. It should only be responsible for correctly generating
/// a shader source, provided a configuration. The user is responsible for
/// choosing that configuration.
struct GEMMKernelDescriptor {
/// Required. The number of matrix elements spanned by each threadgroup.
/// - Parameter M: Number of output columns spanned.
/// - Parameter N: Number of output rows spanned.
/// - Parameter K: Number of loop iterations unrolled.
///
/// Optimal values:
/// - Apple7 and Apple8: 48x48x24
/// - Apple9 and later: 32x32x8
///
/// To reach optimal performance on Apple7 and Apple8, the recommended default
/// value needs to be modified conditionally. When all three operands have
/// 16-bit memory precisions, change `K` to 32. When the matrix is too small
/// to saturate all of the GPU cores, change all dimensions to 32x32x32. Even
/// smaller blocks can be exploited in low-occupancy cases, but 32x32 and
/// 48x48 are sufficient for general use.
///
/// For simplicity or an out-of-the-box performance test, one can assume
/// occupancy is always high. But to match the performance of MPS, one must
/// optimize for small problem sizes on large GPUs.
///
/// ## Choosing Block Size by Precision
///
/// Legend:
/// - memA: precision for left input matrix, in memory
/// - memB: precision for right input matrix, in memory
/// - memC: precision for output matrix, in memory
/// - regA: precision for left input matrix, in registers
/// - regB: precision for right input matrix, in registers
/// - regC: precision for output matrix, in registers
/// - M1: optimal block size on Apple7 and Apple8
/// - M3: optimal block size on Apple9 and later
///
/// memA | memB | memC | regA | regB | regC | M1 | M3 |
/// ---- | ---- | ---- | ---- | ---- | ---- | -------- | ------- |
/// FP16 | FP16 | FP16 | any | any | any | 48x48x32 | 32x32x8 |
/// BF16 | BF16 | BF16 | any | any | any | 48x48x32 | 32x32x8 |
/// FP16 | FP16 | FP32 | any | any | any | 48x48x24 | 32x32x8 |
/// BF16 | BF16 | FP32 | any | any | any | 48x48x24 | 32x32x8 |
/// FP16 | FP32 | FP16 | any | any | any | 48x48x24 | 32x32x8 |
/// BF16 | FP32 | BF16 | any | any | any | 48x48x24 | 32x32x8 |
/// FP32 | FP32 | FP32 | any | any | any | 48x48x24 | 32x32x8 |
///
/// ## Detecting Low-Occupancy Cases
///
/// To determine whether the matrix saturates the GPU, divide the output
/// matrix's dimensions by 48x48. Round up to the nearest integer. Then,
/// multiply the number of row blocks by the number of column blocks. The
/// result is the number of threadgroups dispatched. For example, a C matrix
/// with dimensions 768x768 would dispatch 256 threadgroups. If you are
/// batching multiple matrix multiplications into one shader call, multiply
/// the number of threadgroups by the batch count.
///
/// Next, calculate the target occupancy. Start by finding the GPU core count.
/// This can be accomplished in many ways; there is a heavily tested reference
/// implementation [here](https://github.com/philipturner/applegpuinfo). On
/// macOS, you can query the core count through IORegistry. On iOS, go with a
/// conservative (meaning more likely to overestimate) estimate of 5 cores on
/// A14 - A16, 10 cores on M1 - M2.
///
/// When one of the operands is 32-bit, the target occupancy is 6 threadgroups
/// per core. When all three operands are 16-bit, the target increases to 9
/// per core. Multiply the number of cores by the number of threadgroups per
/// core. If the total GPU occupancy is greater than or equal to the number of
/// matrix blocks, use the smaller blocking scheme.
///
/// For example, the following decision tree would be used on an M1 Max
/// (32 cores).
///
/// ```
/// is device Apple9 or later?
///   yes: use block size 32x32x8
///   no: continue decision tree [selected decision]
///   unsure: use block size 48x48x24-32
///
/// compute number of matrix blocks
///   768x768 / 48x48 = 16.0 x 16.0
///   round floating point (16.0 x 16.0)
///   to next greatest integer (16 x 16)
///   16 x 16 x (batch size of 1) = 256 threadgroups
///
/// compute target occupancies with 48x48 scheme
///   32 x 6 = 192 [selected when A, B, or C is FP32]
///   32 x 9 = 288 [selected when every matrix is FP16/BF16]
///
/// prefer 32x32 when 48x48 has low occupancy
///   if 256 ≤ 192
///     choose small block size (32x32x32xFP32)
///   else
///     choose large block size (48x48x24xFP32) [selected]
///   if 256 ≤ 288
///     choose small block size (32x32x32xFP16) [selected]
///   else
///     choose large block size (48x48x32xFP16)
/// ```
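///
/// The same decision, sketched in Swift (the core count and batch count are
/// supplied by the caller; this helper is illustrative, not part of the
/// kernel itself):
///
/// ```
/// func useSmallBlocks(
///   rows: Int, columns: Int, batchCount: Int,
///   coreCount: Int, allOperands16Bit: Bool
/// ) -> Bool {
///   let blocksM = (rows + 47) / 48
///   let blocksN = (columns + 47) / 48
///   let threadgroups = blocksM * blocksN * batchCount
///   let targetOccupancy = coreCount * (allOperands16Bit ? 9 : 6)
///   return threadgroups <= targetOccupancy
/// }
/// ```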
var blockDimensions: (M: UInt16, N: UInt16, K: UInt16)?
var memoryPrecisions: (
A: GEMMOperandPrecision, B: GEMMOperandPrecision, C: GEMMOperandPrecision)?
/// The device to create the kernel on.
var device: MTLDevice?
/// Optional. The layout of elements in threadgroup memory.
///
/// If not specified, the default value matches the actual block dimensions.
///
/// This property can be used to avoid bank conflicts. For example, if one
/// operand has 16 FP32 elements per row, there is a good chance of increased
/// bank conflicts on M1. One may pad that threadgroup memory allocation to
/// 20 FP32 elements per row.
///
/// Note that the assignment of M/N/K to row dimensions varies based on which
/// operand is discussed, and what its transpose state is.
var paddedBlockDimensions: (
A: (M: UInt16, K: UInt16),
B: (K: UInt16, N: UInt16),
C: (M: UInt16, N: UInt16))?
/// Required. Whether async copies will improve performance during the
/// matrix multiplication loop.
///
/// The default value is `true`. Async copies improve performance on Apple7
/// and Apple8, but harm performance on Apple9 and later. However, they are
/// essential for correctness when reading from the edges of unaligned
/// matrices. Setting the value to `false` means skipping async copies when
/// doing so will not change the final result.
var preferAsyncLoad: Bool = true
/// Required. Whether async copies will improve performance when storing the
/// accumulator to main memory.
///
/// There is no default value that will reliably yield consistent performance.
var preferAsyncStore: Bool?
/// Set the register precision based on the GPU architecture, and your choice
/// for memory precision. The following set of logic statements should provide
/// optimal performance for all permutations of operand precisions.
///
/// ```
/// regA is identical to memA
/// regB is identical to memB
/// If memA, memB, and memC are FP16,
/// regC is FP16
/// else
/// regC is FP32
///
/// If earlier than M3
/// If memA is BF16,
/// regA is FP32
/// If memB is BF16,
/// regB is FP32
/// ```
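///
/// A minimal Swift sketch of the same logic (here 'device' stands for any
/// MTLDevice, and memA/memB/memC are the memory precisions the client chose;
/// the same selection is implemented in 'GEMMKernelDescriptor.init' later in
/// this file):
///
/// ```
/// var regA = memA
/// var regB = memB
/// var regC = GEMMOperandPrecision.FP32
/// if memA == .FP16, memB == .FP16, memC == .FP16 {
///   regC = .FP16
/// }
/// if !device.supportsFamily(.apple9) {
///   if memA == .BF16 { regA = .FP32 }
///   if memB == .BF16 { regB = .FP32 }
/// }
/// ```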
var registerPrecisions: (
A: GEMMOperandPrecision, B: GEMMOperandPrecision, C: GEMMOperandPrecision)?
/// Required. The array of SIMDs to divide the threadgroup into.
///
/// Optimal values:
/// - Apple7 and Apple8: 2x2
/// - Apple9 and later: 1x1
var splits: (M: UInt16, N: UInt16)?
/// Required. Whether each of the inputs deviates from row-major order.
var transposeState: (A: Bool, B: Bool)?
}
struct GEMMKernelKey: Equatable, Hashable {
var blockDimensions: SIMD3<UInt16>
var memoryPrecisions: SIMD3<UInt16>
var paddedBlockDimensions: SIMD8<UInt16>
var preferAsyncLoad: UInt8
var preferAsyncStore: UInt8
var registerPrecisions: SIMD3<UInt16>
var splits: SIMD2<UInt16>
var transposeState: SIMD2<UInt8>
init(copying source: GEMMKernelDescriptor) {
blockDimensions = Self.createBlockDimensions(source.blockDimensions)
memoryPrecisions = Self.createPrecisions(source.memoryPrecisions)
paddedBlockDimensions = SIMD8(repeating: .max)
if let (A, B, C) = source.paddedBlockDimensions {
paddedBlockDimensions[0] = A.0
paddedBlockDimensions[1] = A.1
paddedBlockDimensions[2] = B.0
paddedBlockDimensions[3] = B.1
paddedBlockDimensions[4] = C.0
paddedBlockDimensions[5] = C.1
}
preferAsyncLoad = Self.createBoolean(source.preferAsyncLoad)
preferAsyncStore = Self.createBoolean(source.preferAsyncStore)
registerPrecisions = Self.createPrecisions(source.registerPrecisions)
splits = SIMD2(repeating: .max)
if let (M, N) = source.splits {
splits[0] = M
splits[1] = N
}
transposeState = Self.createTransposeState(source.transposeState)
}
@_transparent // performance in -Ounchecked
static func createBlockDimensions(
_ input: (UInt16, UInt16, UInt16)?
) -> SIMD3<UInt16> {
if let input {
return SIMD3(input.0, input.1, input.2)
} else {
return SIMD3(repeating: .max)
}
}
@_transparent // performance in -Ounchecked
static func createBoolean(
_ input: Bool?
) -> UInt8 {
if let input {
return input ? 1 : 0
} else {
return UInt8.max
}
}
@_transparent // performance in -Ounchecked
static func createPrecisions(
_ input: (
GEMMOperandPrecision, GEMMOperandPrecision, GEMMOperandPrecision)?
) -> SIMD3<UInt16> {
if let input {
return SIMD3(input.0.rawValue, input.1.rawValue, input.2.rawValue)
} else {
return SIMD3(repeating: .max)
}
}
@_transparent // performance in -Ounchecked
static func createTransposeState(
_ input: (Bool, Bool)?
) -> SIMD2<UInt8> {
if let input {
return SIMD2(input.0 ? 1 : 0,
input.1 ? 1 : 0)
} else {
return SIMD2(repeating: .max)
}
}
}
extension GEMMKernelDescriptor: Hashable, Equatable {
static func == (lhs: GEMMKernelDescriptor, rhs: GEMMKernelDescriptor) -> Bool {
let lhsKey = GEMMKernelKey(copying: lhs)
let rhsKey = GEMMKernelKey(copying: rhs)
return lhsKey == rhsKey
}
func hash(into hasher: inout Hasher) {
let key = GEMMKernelKey(copying: self)
hasher.combine(key)
}
}
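// The Hashable conformance above lets a kernel descriptor serve directly as a
// dictionary key. A minimal usage sketch (the 'libraryCache' dictionary is
// declared near the end of this file, and 'gemmDesc' stands for any
// GEMMDescriptor instance):
//
//   let kernelDesc = GEMMKernelDescriptor(descriptor: gemmDesc)
//   if let cached = GEMMKernel.libraryCache[kernelDesc] {
//     // Reuse the previously compiled MTLLibrary.
//     _ = cached.library
//   }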
struct GEMMKernel {
var library: MTLLibrary
var source: String = ""
// A copy of the block dimensions from the descriptor.
var blockDimensions: (M: UInt16, N: UInt16, K: UInt16)
// The number of bytes of threadgroup memory to allocate at runtime.
// Allocating threadgroup memory after compiling the kernel results in higher
// performance.
var threadgroupMemoryAllocation: UInt16
// The number of threads per group.
var threadgroupSize: UInt16
init(descriptor: GEMMKernelDescriptor) {
guard let blockDimensions = descriptor.blockDimensions,
let device = descriptor.device,
let memoryPrecisions = descriptor.memoryPrecisions,
let preferAsyncStore = descriptor.preferAsyncStore,
let registerPrecisions = descriptor.registerPrecisions,
let splits = descriptor.splits,
let transposeState = descriptor.transposeState else {
fatalError("Descriptor was incomplete: \(descriptor)")
}
self.blockDimensions = blockDimensions
self.threadgroupSize = 32 * splits.M * splits.N
// Validate the correctness of register precisions.
func checkOperandPair(
memory: GEMMOperandPrecision,
register: GEMMOperandPrecision
) -> Bool {
// Truth table:
//
// memory | register | valid |
// ------ | -------- | ----- |
// FP32 | FP32 | yes |
// FP32 | FP16 | no |
// FP32 | BF16 | no |
// FP16 | FP32 | yes |
// FP16 | FP16 | yes |
// FP16 | BF16 | no |
// BF16 | FP32 | yes |
// BF16 | FP16 | no |
// BF16 | BF16 | yes |
//
// Optimized form of the logic:
//
// If the register precision matches the memory precision,
// return true
// If the register precision equals FP32,
// return true
// Otherwise,
// return false
//
// The logic statements will change if you introduce custom quantized
// formats. The truth table will grow exponentially. You'll need to add
// more restrictions on accepted pairs to overcome the combinatorial
// explosion.
if register == memory {
return true
} else if register == .FP32 {
return true
} else {
return false
}
}
guard checkOperandPair(
memory: memoryPrecisions.A, register: registerPrecisions.A) else {
fatalError("Operand A had an invalid register precision.")
}
guard checkOperandPair(
memory: memoryPrecisions.B, register: registerPrecisions.B) else {
fatalError("Operand B had an invalid register precision.")
}
guard checkOperandPair(
memory: memoryPrecisions.C, register: registerPrecisions.C) else {
fatalError("Operand C had an invalid register precision.")
}
if registerPrecisions.C == .BF16 {
// BF16 has too few mantissa bits to be an accurate accumulator. In
// addition, switching from FP32 accumulator to BF16 accumulator slows
// down execution speed on both M1/M2 and M3+.
fatalError("BF16 cannot be used as the register precision for C.")
}
// Inject the contents of the headers.
source += """
\(createMetalSimdgroupEvent())
\(createMetalSimdgroupMatrixStorage())
using namespace metal;
"""
// Declare the size of M and N within a register allocation.
let registerM: UInt16 = blockDimensions.M / splits.M
let registerN: UInt16 = blockDimensions.N / splits.N
// Retrieve the "padded" block dimensions, otherwise compute analytically
// from the true block dimensions.
var paddedBlockDimensionsA: (M: UInt16, K: UInt16)
var paddedBlockDimensionsB: (K: UInt16, N: UInt16)
var paddedBlockDimensionsC: (M: UInt16, N: UInt16)
if let paddedBlockDimensions = descriptor.paddedBlockDimensions {
paddedBlockDimensionsA = paddedBlockDimensions.A
paddedBlockDimensionsB = paddedBlockDimensions.B
paddedBlockDimensionsC = paddedBlockDimensions.C
} else {
paddedBlockDimensionsA = (blockDimensions.M, blockDimensions.K)
paddedBlockDimensionsB = (blockDimensions.K, blockDimensions.N)
paddedBlockDimensionsC = (blockDimensions.M, blockDimensions.N)
}
// Determine the leading dimensions (for the matrices and their blocks) from
// the transpose state.
var leadingDimensionA: String
var leadingDimensionB: String
var leadingBlockDimensionA: UInt16
var leadingBlockDimensionB: UInt16
if transposeState.A {
leadingDimensionA = "M"
leadingBlockDimensionA = paddedBlockDimensionsA.M
} else {
leadingDimensionA = "K"
leadingBlockDimensionA = paddedBlockDimensionsA.K
}
if transposeState.B {
leadingDimensionB = "K"
leadingBlockDimensionB = paddedBlockDimensionsB.K
} else {
leadingDimensionB = "N"
leadingBlockDimensionB = paddedBlockDimensionsB.N
}
// Add the function constants.
do {
source += """
// Dimensions of each matrix.
// - Limitations to matrix size:
// - 2^32 in each dimension (M/N/K).
// - Extending to 2^64 may require changing 'uint' to 'ulong'. There is a
// good chance this will significantly degrade performance, and require
// changing the data type of several variables that process addresses. The
// client is responsible for ensuring correctness and performance with
// matrices spanning several billion elements in one direction.
// - The matrix dimensions must be known at compile time, via function
// constants. Dynamic matrix shapes are beyond the scope of this reference
// implementation. Dynamic shapes cause a non-negligible regression to
// shader execution speed. However, they could minimize a compilation
// latency bottleneck in some use cases.
// - Limitations to batch size:
// - Dictated by how the client modifies the code to implement batching.
// - Dynamic batch shapes would likely not harm performance much. For example,
// someone could enter an array of pointers/memory offsets to different
// matrices in the batch. Each slice of a 3D thread grid could read a
// different pointer from memory, and use that pointer as the A/B/C matrix.
// Another approach is to restrict the input format, so all matrices are
// stored contiguously in memory. Then, the memory offset could be computed
// analytically from matrix size and the Z dimension in a 3D thread grid.
//
// Another note:
// - The rows of the matrix must be contiguous in memory. Supporting strides
// that differ from the actual matrix dimensions should not be difficult, but
// it is out of scope for this reference kernel.
constant uint M [[function_constant(0)]];
constant uint N [[function_constant(1)]];
constant uint K [[function_constant(2)]];
// Whether each matrix is transposed.
constant bool A_trans = \(transposeState.A);
constant bool B_trans = \(transposeState.B);
// Define the memory layout of the matrix block.
constant ushort M_group = \(blockDimensions.M);
constant ushort N_group = \(blockDimensions.N);
constant ushort K_group = \(blockDimensions.K);
// Thresholds that mark the matrix edge.
constant uint M_edge = M - (M % M_group);
constant uint N_edge = N - (N % N_group);
// Find the number of elements in the final block. If the matrix
// dimensions are perfectly divisible by block dimensions, we don't want
// this value to be zero. The final block is a full block.
constant ushort M_remainder = (M % \(registerM) == 0)
? \(registerM) : M % \(registerM);
constant ushort N_remainder = (N % \(registerN) == 0)
? \(registerN) : N % \(registerN);
constant ushort K_remainder = (K % K_group == 0)
? K_group : K % K_group;
constant ushort K_remainder_padded = (K_remainder + 7) / 8 * 8;
// Shift the final block, so it doesn't access out-of-bounds memory.
constant ushort M_shift = (M < M_group) ? 0 : \(registerM) - M_remainder;
constant ushort N_shift = (N < N_group) ? 0 : \(registerN) - N_remainder;
"""
}
// Allocate threadgroup memory, using the 'memory precision'. This memory
// is allocated at runtime, either by the user (explicit API call) or by
// the driver (behind the scenes).
func createPrecisionSize(_ precision: GEMMOperandPrecision) -> UInt16 {
// NOTE: Exotic precisions like some LLaMA quantization formats and ezm8
// have the exponent deinterleaved from the mantissa. Such precisions
// would require careful consideration of the meaning of per-scalar
// memory footprint.
switch precision {
case .FP32: return 4
case .FP16: return 2
case .BF16: return 2
}
}
// Allocate thread memory, using the 'register precision'. This memory
// is allocated by embedding the precision into the assembly code.
func createPrecisionName(_ precision: GEMMOperandPrecision) -> String {
// Exotic precisions would not require any special handling here. Good
// practices dictate that you decode to floating point while filling
// up the registers. Therefore, the registers will always be floating
// point.
switch precision {
case .FP32: return "float"
case .FP16: return "half"
case .BF16: return "bfloat"
}
}
// Determine the names of the operands.
let memoryNameA = createPrecisionName(memoryPrecisions.A)
let memoryNameB = createPrecisionName(memoryPrecisions.B)
let memoryNameC = createPrecisionName(memoryPrecisions.C)
let registerNameA = createPrecisionName(registerPrecisions.A)
let registerNameB = createPrecisionName(registerPrecisions.B)
let registerNameC = createPrecisionName(registerPrecisions.C)
// Add the utility functions.
source += """
// The layout of threads within a SIMD matrix.
//
// 0 0 1 1 8 8 9 9
// 2 2 3 3 10 10 11 11
// 4 4 5 5 12 12 13 13
// 6 6 7 7 14 14 15 15
// 16 16 17 17 24 24 25 25
// 18 18 19 19 26 26 27 27
// 20 20 21 21 28 28 29 29
// 22 22 23 23 30 30 31 31
//
// This is Morton order, a method for coalescing data accesses. It is used
// in a variety of contexts, from ray tracing acceleration structures, to
// nodal-point Laplacians, to sorting large lattices of atoms.
//
// Source: https://patents.google.com/patent/US11256518B2
METAL_FUNC ushort2 morton_order(ushort thread_index_in_simdgroup) {
ushort lane_id = thread_index_in_simdgroup;
ushort quad_id = lane_id / 4;
constexpr ushort QUADRANT_SPAN_M = 4;
constexpr ushort THREADS_PER_QUADRANT = 8;
ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M;
ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2);
ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant;
ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4
ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2
ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant;
return ushort2(N_in_simd, M_in_simd);
}
// Indexes into an array of registers.
//
// Calls to this function are expected to be evaluated at compile time. The
// array indices transform into register offsets, which are embedded into the
// assembly code.
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* get_sram(
thread simdgroup_matrix_storage<T> *sram,
ushort sram_leading_dim,
ushort2 matrix_origin
) {
return sram + (matrix_origin.y / 8) * (sram_leading_dim / 8) + (matrix_origin.x / 8);
}
"""
struct MultiplyDescriptor {
var addressSpace: String?
var leadingDimensionA: String?
var leadingDimensionB: String?
var loadFunctionA: String?
var loadFunctionB: String?
}
func createMultiply(descriptor: MultiplyDescriptor) -> String {
guard let addressSpace = descriptor.addressSpace,
let leadingDimensionA = descriptor.leadingDimensionA,
let leadingDimensionB = descriptor.leadingDimensionB,
let loadFunctionA = descriptor.loadFunctionA,
let loadFunctionB = descriptor.loadFunctionB else {
fatalError("Descriptor was incomplete.")
}
return """
// One multiply-accumulate loop iteration, or 8 dot products.
METAL_FUNC void multiply_accumulate(
const \(addressSpace) \(memoryNameA) *A_src,
const \(addressSpace) \(memoryNameB) *B_src,
thread simdgroup_matrix_storage<\(registerNameA)> *A_sram,
thread simdgroup_matrix_storage<\(registerNameB)> *B_sram,
thread simdgroup_matrix_storage<\(registerNameC)> *C_sram,
ushort k
) {
#pragma clang loop unroll(full)
for (ushort m = 0; m < \(registerM); m += 8) {
ushort2 origin(0, m);
auto A = get_sram(A_sram, 8, origin);
A->\(loadFunctionA)(A_src, \(leadingDimensionA), ushort2(k, m), A_trans);
}
#pragma clang loop unroll(full)
for (ushort n = 0; n < \(registerN); n += 8) {
ushort2 origin(n, 0);
auto B = get_sram(B_sram, \(registerN), origin);
B->\(loadFunctionB)(B_src, \(leadingDimensionB), ushort2(n, k), B_trans);
}
#pragma clang loop unroll(full)
for (ushort m = 0; m < \(registerM); m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < \(registerN); n += 8) {
auto A = get_sram(A_sram, 8, ushort2(0, m));
auto B = get_sram(B_sram, \(registerN), ushort2(n, 0));
auto C = get_sram(C_sram, \(registerN), ushort2(n, m));
C->multiply(*A, *B);
}
}
}
"""
}
// Add the utility functions for the multiply-accumulate inner loop.
do {
var multiplyDesc = MultiplyDescriptor()
if memoryPrecisions.A == .BF16, registerPrecisions.A == .FP32 {
multiplyDesc.loadFunctionA = "load_bfloat"
} else {
multiplyDesc.loadFunctionA = "load"
}
if memoryPrecisions.B == .BF16, registerPrecisions.B == .FP32 {
multiplyDesc.loadFunctionB = "load_bfloat"
} else {
multiplyDesc.loadFunctionB = "load"
}
multiplyDesc.addressSpace = "device"
multiplyDesc.leadingDimensionA = leadingDimensionA
multiplyDesc.leadingDimensionB = leadingDimensionB
source += createMultiply(descriptor: multiplyDesc)
multiplyDesc.addressSpace = "threadgroup"
multiplyDesc.leadingDimensionA = "\(leadingBlockDimensionA)"
multiplyDesc.leadingDimensionB = "\(leadingBlockDimensionB)"
source += createMultiply(descriptor: multiplyDesc)
}
// Add the setup portion where the addresses are prepared.
do {
var blockBytesA = paddedBlockDimensionsA.M * paddedBlockDimensionsA.K
var blockBytesB = paddedBlockDimensionsB.K * paddedBlockDimensionsB.N
var blockBytesC = paddedBlockDimensionsC.M * paddedBlockDimensionsC.N
blockBytesA *= createPrecisionSize(memoryPrecisions.A)
blockBytesB *= createPrecisionSize(memoryPrecisions.B)
blockBytesC *= createPrecisionSize(memoryPrecisions.C)
threadgroupMemoryAllocation = max(blockBytesA + blockBytesB, blockBytesC)
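// Worked example with the default 48x48x24 FP32 blocking and no padding:
// A block: 48 * 24 * 4 bytes = 4608 bytes
// B block: 24 * 48 * 4 bytes = 4608 bytes
// C block: 48 * 48 * 4 bytes = 9216 bytes
// allocation = max(4608 + 4608, 9216) = 9216 bytes (9 KB)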
source += """
// Metal function arguments.
//
// A: the left-hand side matrix
// - dimensions: M x K
// K x M (transposed)
// - memory precision: memA
// - register precision: regA
//
// B: the right-hand side matrix
// - dimensions: K x N
// N x K (transposed)
// - memory precision: memB
// - register precision: regB
//
// C: the output matrix, alternatively the dot product accumulator
// - dimensions: M x N
// - memory precision: memC
// - register precision: regC
//
// threadgroup_block: the chunk of threadgroup memory allocated at runtime
// - ideally 10 KB or less
// - precision: void/8-bit integer to make the pointer arithmetic more legible
kernel void gemm(device \(memoryNameA) *A [[buffer(0)]],
device \(memoryNameB) *B [[buffer(1)]],
device \(memoryNameC) *C [[buffer(2)]],
threadgroup uchar *threadgroup_block [[threadgroup(0)]],
uint3 gid [[threadgroup_position_in_grid]],
ushort sidx [[simdgroup_index_in_threadgroup]],
ushort lane_id [[thread_index_in_simdgroup]])
{
auto A_block = (threadgroup \(memoryNameA)*)(threadgroup_block);
auto B_block = (threadgroup \(memoryNameB)*)(threadgroup_block + \(blockBytesA));
ushort2 sid(sidx % \(splits.N), sidx / \(splits.N));
ushort2 morton_offset = morton_order(lane_id);
// Return early if the SIMD is out of bounds.
//
// There could be some threadgroups where the matrix edge cuts straight
// through the middle of the block. SIMDs on the right or bottom of the
// dividing line must be stopped from causing out-of-bounds accesses. This is
// the reason for the early exit.
uint M_offset = gid.y * M_group;
uint N_offset = gid.x * N_group;
{
if (M_offset + sid.y * \(registerM) >= M ||
N_offset + sid.x * \(registerN) >= N) {
return;
}
}
ushort2 offset_in_group(sid.x * \(registerN) + morton_offset.x,
sid.y * \(registerM) + morton_offset.y);
// Shift the matrix block within bounds, if possible.
if ((M_shift != 0) && (gid.y * M_group >= M_edge)) {
M_offset -= M_shift;
}
if ((N_shift != 0) && (gid.x * N_group >= N_edge)) {
N_offset -= N_shift;
}
"""
}
// Add the setup of the accumulator.
do {
let arrayElementsC: UInt16 = (registerM / 8) * (registerN / 8)
source += """
simdgroup_matrix_storage<\(registerNameC)> C_sram[\(arrayElementsC)];
// Initialize the accumulator.
#pragma clang loop unroll(full)
for (ushort m = 0; m < \(registerM); m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < \(registerN); n += 8) {
ushort2 origin(n, m);
auto C = get_sram(C_sram, \(registerN), origin);
*C = simdgroup_matrix_storage<\(registerNameC)>(0);
}
}
"""
}
// Add the matrix multiplication iterations.
//
// Async copies are required for correct behavior in edge cases. We attempt
// to execute most iterations without async copy, and only the necessary
// ones with async copy.
do {
var asyncIterationsStart: String
if descriptor.preferAsyncLoad {
asyncIterationsStart = "0"
} else {
asyncIterationsStart = "(K - (K % K_group))"
}
let paddedCeilingK = "(K + K_remainder_padded - K_remainder)"
source += """
// Perform the iterations where async copy is avoided.
for (uint k = 0; k < \(asyncIterationsStart); k += 8) {
uint2 A_offset(k, M_offset);
uint2 B_offset(N_offset, k);
A_offset += uint2(morton_offset.x, offset_in_group.y);
B_offset += uint2(offset_in_group.x, morton_offset.y);
auto A_src = simdgroup_matrix_storage<\(memoryNameA)>::apply_offset(
A, \(leadingDimensionA), A_offset, A_trans);
auto B_src = simdgroup_matrix_storage<\(memoryNameB)>::apply_offset(
B, \(leadingDimensionB), B_offset, B_trans);
simdgroup_matrix_storage<\(registerNameA)> A_sram[\(registerM / 8) * (8 / 8)];
simdgroup_matrix_storage<\(registerNameB)> B_sram[(8 / 8) * \(registerN / 8)];
multiply_accumulate(A_src, B_src,
A_sram, B_sram, C_sram, 0);
}
// Perform the iterations where async copy is used.
for (uint k = \(asyncIterationsStart); k < K; k += K_group) {
// Launch an async copy from device to threadgroup memory.
if (sidx == 0) {
uint2 A_offset(k, M_offset);
uint2 B_offset(N_offset, k);
auto A_src = simdgroup_matrix_storage<\(memoryNameA)>::apply_offset(
A, \(leadingDimensionA), A_offset, A_trans);
auto B_src = simdgroup_matrix_storage<\(memoryNameB)>::apply_offset(
B, \(leadingDimensionB), B_offset, B_trans);
ushort M_tile_dimension = min(uint(M_group), M - M_offset);
ushort N_tile_dimension = min(uint(N_group), N - N_offset);
ushort K_tile_dimension = min(uint(K_group), K - k);
ushort K_tile_padded = min(uint(K_group), \(paddedCeilingK) - k);
ushort2 A_tile_src(K_tile_dimension, M_tile_dimension);
ushort2 B_tile_src(N_tile_dimension, K_tile_dimension);
ushort2 A_tile_dst(K_tile_padded, M_tile_dimension);
ushort2 B_tile_dst(N_tile_dimension, K_tile_padded);
simdgroup_event events[2];
events[0].async_copy(A_block, \(leadingBlockDimensionA), A_tile_dst,
A_src, \(leadingDimensionA), A_tile_src, A_trans);
events[1].async_copy(B_block, \(leadingBlockDimensionB), B_tile_dst,
B_src, \(leadingDimensionB), B_tile_src, B_trans);
simdgroup_event::wait(2, events);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
ushort2 A_block_offset(morton_offset.x, offset_in_group.y);
ushort2 B_block_offset(offset_in_group.x, morton_offset.y);
auto A_block_src = simdgroup_matrix_storage<\(memoryNameA)>::apply_offset(
A_block, \(leadingBlockDimensionA), A_block_offset, A_trans);
auto B_block_src = simdgroup_matrix_storage<\(memoryNameB)>::apply_offset(
B_block, \(leadingBlockDimensionB), B_block_offset, B_trans);
simdgroup_matrix_storage<\(registerNameA)> A_sram[\(registerM / 8) * (K_group / 8)];
simdgroup_matrix_storage<\(registerNameB)> B_sram[(K_group / 8) * \(registerN / 8)];
#pragma clang loop unroll(full)
for (ushort k = 0; k < K_remainder_padded; k += 8) {
multiply_accumulate(A_block_src, B_block_src,
A_sram, B_sram, C_sram, k);
}
// Will there be any iterations after this one?
if (k + K_group < K) {
// If so, we haven't reached the edge of either input matrix yet.
#pragma clang loop unroll(full)
for (ushort k = K_remainder_padded; k < K_group; k += 8) {
multiply_accumulate(A_block_src, B_block_src,
A_sram, B_sram, C_sram, k);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
"""
}
// Add the cleanup portion where the accumulator is stored.
do {
var storeFunctionC: String
if memoryPrecisions.C == .BF16,
registerPrecisions.C == .FP32 {
storeFunctionC = "store_bfloat"
} else {
storeFunctionC = "store"
}
var condition: String
if preferAsyncStore {
condition = "false"
} else {
condition = "(M >= M_group) && (N >= N_group)"
}
source += """
if (\(condition)) {
// Fast path for matrices that qualify.
uint2 C_offset(N_offset + offset_in_group.x,
M_offset + offset_in_group.y);
auto C_dst = simdgroup_matrix_storage<\(memoryNameC)>::apply_offset(
C, N, C_offset);
// Write the accumulator to device memory.
#pragma clang loop unroll(full)
for (ushort m = 0; m < \(registerM); m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < \(registerN); n += 8) {
ushort2 origin(n, m);
auto C = get_sram(C_sram, \(registerN), origin);
C->\(storeFunctionC)(C_dst, N, origin);
}
}
} else {
// Slow path for when memory must be handled more carefully.
auto C_block = (threadgroup \(memoryNameC)*)(threadgroup_block);
auto C_block_dst = simdgroup_matrix_storage<\(memoryNameC)>::apply_offset(
C_block, N_group, offset_in_group);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write the accumulator to threadgroup memory.
#pragma clang loop unroll(full)
for (ushort m = 0; m < \(registerM); m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < \(registerN); n += 8) {
ushort2 origin(n, m);
auto C = get_sram(C_sram, \(registerN), origin);
C->\(storeFunctionC)(C_block_dst, N_group, origin);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Launch the async copy from threadgroup to device memory.
if (sidx == 0) {
uint2 C_offset(gid.x * N_group, gid.y * M_group);
ushort2 C_tile(min(uint(N_group), N - C_offset.x),
min(uint(M_group), M - C_offset.y));
auto C_dst = simdgroup_matrix_storage<\(memoryNameC)>::apply_offset(
C, N, C_offset);
// If we shift successfully, the garbage zone moves from the bottom right
// to the top left.
if ((M_shift != 0) || (N_shift != 0)) {
ushort2 C_block_shift(0, 0);
if ((M_shift != 0) && (C_offset.y >= M_edge)) {
C_block_shift.y = M_shift;
}
if ((N_shift != 0) && (C_offset.x >= N_edge)) {
C_block_shift.x = N_shift;
}
C_block = simdgroup_matrix_storage<\(memoryNameC)>::apply_offset(
C_block, N_group, C_block_shift);
}
simdgroup_event event;
event.async_copy(C_dst, N, C_tile, C_block, N_group, C_tile);
}
}
"""
}
// Add the final closing brace of the Metal function.
source += "}" + "\n"
// Compile the shader source.
library = try! device.makeLibrary(source: source, options: nil)
}
}
// MARK: - Header Sources
/// Create the source code for the 'metal\_simdgroup\_event' header.
///
/// I may have found the hardware bug with async copies on M1. If you shoot
/// off an async copy, you need to read from its contents later in the
/// shader. Otherwise, something inside the hardware (like a
/// DispatchSemaphore) will be waiting indefinitely to be notified. The bug
/// is a bit flaky, and only shows up for certain problem configurations. The
/// side effects are catastrophic; the GPU might freeze up until the computer
/// reboots.
///
/// Workaround: if an async copy from device -> threadgroup is launched,
/// guarantee that both:
/// - The threadgroup will enter another `threadgroup_barrier` before the end of
/// the kernel.
/// - The results of the async copy will be read from. This means at least one
/// thread must dereference a pointer within the region of threadgroup memory.
func createMetalSimdgroupEvent() -> String {
// Return the source string.
return """
// -*- Metal -*-
//===-- metal_simdgroup_event ---------------------------------------------===//
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
//===----------------------------------------------------------------------===//
#ifndef __METAL_SIMDGROUP_EVENT
#define __METAL_SIMDGROUP_EVENT
// Invoking the generation of LLVM bitcode for async copies.
//
// %struct._simdgroup_event_t = type opaque
//
struct _simdgroup_event_t;
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: argmemonly convergent nounwind
// declare %struct._simdgroup_event_t*
// @air.simdgroup_async_copy_2d.p3i8.p1i8(
// i64, i64,
// i8 addrspace(3)* nocapture writeonly, i64, i64, <2 x i64>,
// i8 addrspace(1)* nocapture readonly, i64, i64, <2 x i64>,
// <2 x i64>, i32)
// local_unnamed_addr #4
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_2d(
ulong, ulong,
threadgroup void *, ulong, ulong, ulong2,
const device void *, ulong, ulong, ulong2,
long2, int)
__asm("air.simdgroup_async_copy_2d.p3i8.p1i8");
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: argmemonly convergent nounwind
// declare %struct._simdgroup_event_t*
// @air.simdgroup_async_copy_2d.p1i8.p3i8(
// i64, i64,
// i8 addrspace(1)* nocapture writeonly, i64, i64, <2 x i64>,
// i8 addrspace(3)* nocapture readonly, i64, i64, <2 x i64>,
// <2 x i64>, i32)
// local_unnamed_addr #4
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_2d(
ulong, ulong,
device void *, ulong, ulong, ulong2,
const threadgroup void *, ulong, ulong, ulong2,
long2, int)
__asm("air.simdgroup_async_copy_2d.p1i8.p3i8");
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: convergent nounwind
// declare void
// @air.wait_simdgroup_events(i32, %struct._simdgroup_event_t** nocapture)
// local_unnamed_addr #3
//
void __metal_wait_simdgroup_events(
int, thread _simdgroup_event_t**)
__asm("air.wait_simdgroup_events");
#pragma METAL internals : enable
namespace metal
{
enum class simdgroup_async_copy_clamp_mode {
clamp_to_zero = 0,
clamp_to_edge = 1
};
struct simdgroup_event {
METAL_FUNC simdgroup_event() thread {}
template <typename T>
METAL_FUNC void async_copy(
// Description of the destination.
threadgroup T *dst,
ushort dst_elements_per_row,
ushort2 dst_tile_dimensions,
// Description of the source.
const device T *src,
uint src_elements_per_row,
ushort2 src_tile_dimensions,
// Other arguments.
bool transpose_matrix = false,
simdgroup_async_copy_clamp_mode clamp_mode =
simdgroup_async_copy_clamp_mode::clamp_to_zero
) thread {
if (transpose_matrix) {
src_tile_dimensions = src_tile_dimensions.yx;
dst_tile_dimensions = dst_tile_dimensions.yx;
}
event = __metal_simdgroup_async_copy_2d(
// Description of the data type.
sizeof(T),
alignof(T),
// Description of the destination.
reinterpret_cast<threadgroup void *>(dst),
ushort(dst_elements_per_row),
1,
ulong2(dst_tile_dimensions),
// Description of the source.
reinterpret_cast<const device void *>(src),
uint(src_elements_per_row),
1,
ulong2(src_tile_dimensions),
// Other arguments.
long2(0),
static_cast<int>(clamp_mode));
}
template <typename T>
METAL_FUNC void async_copy(
// Description of the destination.
device T *dst,
uint dst_elements_per_row,
ushort2 dst_tile_dimensions,
// Description of the source.
const threadgroup T *src,
ushort src_elements_per_row,
ushort2 src_tile_dimensions,
// Other arguments.
bool transpose_matrix = false
) thread {
if (transpose_matrix) {
src_tile_dimensions = src_tile_dimensions.yx;
dst_tile_dimensions = dst_tile_dimensions.yx;
}
event = __metal_simdgroup_async_copy_2d(
// Description of the data type.
sizeof(T),
alignof(T),
// Description of the destination.
reinterpret_cast<device void *>(dst),
uint(dst_elements_per_row),
1,
ulong2(dst_tile_dimensions),
// Description of the source.
reinterpret_cast<const threadgroup void *>(src),
ushort(src_elements_per_row),
1,
ulong2(src_tile_dimensions),
// Other arguments.
long2(0),
0);
}
METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
__metal_wait_simdgroup_events(
count, reinterpret_cast<thread _simdgroup_event_t**>(events));
}
private:
// Invoking the generation of LLVM bitcode for async copies.
//
// %"struct.metal::simdgroup_event" = type { %struct._simdgroup_event_t* }
//
thread _simdgroup_event_t* event;
};
} // namespace metal
#pragma METAL internals : disable
#endif // __METAL_SIMDGROUP_EVENT
"""
}
/// Create the source code for the 'metal\_simdgroup\_matrix\_storage' header.
func createMetalSimdgroupMatrixStorage() -> String {
// How this header spawning code was designed.
//
// Find the patterns between the load/store functions:
// - device has 'uint' elements_per_row
// - threadgroup has 'ushort' elements_per_row
// - both have 'ushort2' matrix_origin
//
// The origin is 'ushort2' because the 32-bit part of the address should have
// been applied previously during 'apply_offset'. The 16-bit part should be
// hard-coded into the assembly when the GEMM loop is unrolled.
//
// Transpose path:
// - load: reads two values; should split each one onto a separate line.
// - overwrites the value of *thread_elements() with a new vec<T, 2>
// - store: the two instructions are on two separate lines.
// - fetches from lane 0 or 1 of thread_elements()[0]
// - adds 0 or 1 to the hard-coded matrix_origin.x
//
// Address generation:
// - casts some intermediate address fragments to 'ulong' for 'device'
// - keeps all address fragments in 'ushort' for 'threadgroup'
enum AddressSpace {
case device
case threadgroup
var keyword: String {
switch self {
case .device: return "device"
case .threadgroup: return "threadgroup"
}
}
var offsetType: String {
switch self {
case .device: return "uint"
case .threadgroup: return "ushort"
}
}
}
enum Action {
case load
case store
}
struct MemoryAccessDescriptor {
var action: Action?
var addressSpace: AddressSpace?
var decodingBF16: Bool?
var indentationSpaceCount: Int = .zero
}
func createMemoryAccess(
descriptor: MemoryAccessDescriptor
) -> String {
guard let action = descriptor.action,
let addressSpace = descriptor.addressSpace,
let decodingBF16 = descriptor.decodingBF16 else {
fatalError("Descriptor was incomplete.")
}
let indentation = String(
repeating: " ", count: descriptor.indentationSpaceCount)
// Determine the arguments.
var arguments: [String] = []
func addPointerArgument(dataType: String) {
if action == .load {
arguments.append("const \(addressSpace.keyword) \(dataType) *src")
} else {
arguments.append("\(addressSpace.keyword) \(dataType) *dst")
}
}
if decodingBF16 {
addPointerArgument(dataType: "bfloat")
} else {
addPointerArgument(dataType: "U")
}
arguments.append("\(addressSpace.offsetType) elements_per_row")
arguments.append("ushort2 matrix_origin")
arguments.append("bool transpose_matrix = false")
// Create the warning comment.
var output: String = ""
if decodingBF16 {
output += "\(indentation)// WARNING: 'T' must be 'float'.\n"
} else {
output += "\(indentation)template <typename U>\n"
}
// Create the function signature.
output += "\(indentation)METAL_FUNC void"
if action == .load {
output += " load"
} else {
output += " store"
}
if decodingBF16 {
output += "_bfloat"
}
output += "("
for argumentID in arguments.indices {
let argument = arguments[argumentID]
output += argument
if argumentID < arguments.count - 1 {
output += ", "
}
}
output += ") {\n"
func createAddress(transposed: Bool, offset: Int) -> String {
let lineY = "\(addressSpace.offsetType)(matrix_origin.y)"
var lineX = "matrix_origin.x + \(offset)"
lineX = "\(addressSpace.offsetType)(\(lineX))"
if transposed {
return "\(lineX) * elements_per_row + \(lineY)"
} else {
return "\(lineY) * elements_per_row + \(lineX)"
}
}
func createTwoPartAccess(transposed: Bool) -> [String] {
// Generate the addresses.
var lines: [String] = []
for laneID in 0..<2 {
lines.append(
"\(addressSpace.offsetType) address\(laneID) = " +
createAddress(transposed: transposed, offset: laneID))
}
if action == .load {
if decodingBF16 {
lines.append("bfloat memoryForm0 = src[address0]")
lines.append("bfloat memoryForm1 = src[address1]")
} else {
lines.append("U memoryForm0 = src[address0]")
lines.append("U memoryForm1 = src[address1]")
}
}
if action == .load {
if decodingBF16 {
// Separate the loading logic from the decoding logic for clarity.
lines.append(
"")
// BF16 decoding logic.
lines.append(
"bfloat4 registerForm = *(thread bfloat4*)(thread_elements())")
lines.append(
"registerForm[1] = memoryForm0")
lines.append(
"registerForm[3] = memoryForm1")
lines.append(
"((thread bfloat4*)thread_elements())[0] = registerForm")
} else {
// Perform a type cast natively supported by the hardware.
lines.append("((thread T*)thread_elements())[0] = T(memoryForm0)")
lines.append("((thread T*)thread_elements())[1] = T(memoryForm1)")
}
} else {
if decodingBF16 {
// BF16 encoding logic.
lines.append(
"bfloat4 registerForm = *(thread bfloat4*)(thread_elements())")
lines.append(
"registerForm[2] = registerForm[1]")
} else {
// Type casts supported natively by the hardware.
lines.append("T registerForm0 = ((thread T*)thread_elements())[0]")
lines.append("T registerForm1 = ((thread T*)thread_elements())[1]")
}
}
if action == .store {
if decodingBF16 {
lines.append("dst[address0] = registerForm[2]")
lines.append("dst[address1] = registerForm[3]")
} else {
lines.append("dst[address0] = U(registerForm0)")
lines.append("dst[address1] = U(registerForm1)")
}
}
return lines
}
func createOnePartAccess() -> [String] {
var lines: [String] = []
do {
let address = createAddress(transposed: false, offset: 0)
lines.append("auto combinedAddress = \(address)")
}
if action == .load {
if decodingBF16 {
lines.append(
"bfloat2 memoryForm = " +
"*(const \(addressSpace.keyword) packed_bfloat2*)(src + combinedAddress)")
// Separate the loading logic from the decoding logic for clarity.
lines.append(
"")
// BF16 decoding logic.
lines.append(
"bfloat4 registerForm = *(thread bfloat4*)(thread_elements())")
lines.append(
"((thread float*)&registerForm)[1] = *(thread float*)(&memoryForm)")
lines.append(
"((thread bfloat*)&registerForm)[1] = memoryForm[0]")
lines.append(
"((thread bfloat4*)thread_elements())[0] = registerForm")
} else {
lines.append(
"vec<U, 2> memoryForm = " +
"*(const \(addressSpace.keyword) vec<U, 2>*)(src + combinedAddress)")
lines.append(
"*(thread_elements()) = vec<T, 2>(memoryForm)")
}
} else {
if decodingBF16 {
// BF16 encoding logic.
lines.append(
"bfloat4 registerForm = *(thread bfloat4*)(thread_elements())")
lines.append(
"registerForm[2] = registerForm[1]")
lines.append(
"float memoryForm = ((thread float*)&registerForm)[1]")
lines.append(
"*(\(addressSpace.keyword) float*)(dst + combinedAddress) = " +
"memoryForm")
} else {
lines.append(
"vec<T, 2> registerForm = *(thread_elements())")
lines.append(
"*(\(addressSpace.keyword) vec<U, 2>*)(dst + combinedAddress) = " +
"vec<U, 2>(registerForm)")
}
}
return lines
}
func addBlockContents(_ block: [String]) -> [String] {
block.map {
if $0.allSatisfy(\.isWhitespace) {
return " "
} else {
return " \($0);"
}
}
}
// Determine the lines of the 'if' block.
var body: [String] = []
body.append("if (transpose_matrix) {")
body += addBlockContents(createTwoPartAccess(transposed: true))
// Determine the lines of the 'else' block.
if decodingBF16 {
var blockContents: [String]
if action == .load {
blockContents = createOnePartAccess()
} else {
blockContents = createTwoPartAccess(transposed: false)
}
body.append("} else {")
body += addBlockContents(blockContents)
body.append("}")
} else {
body.append("} else if (elements_per_row % 2 != 0) {")
body += addBlockContents(createTwoPartAccess(transposed: false))
body.append("} else {")
body += addBlockContents(createOnePartAccess())
body.append("}")
}
// Create the function body.
for line in body {
output += "\(indentation) \(line)\n"
}
output += "\(indentation)}\n"
return output
}
// Add the first section of the shader.
var output: String = ""
output += """
// -*- Metal -*-
//===-- metal_simdgroup_matrix_storage ------------------------------------===//
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
//===----------------------------------------------------------------------===//
#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE
#define __METAL_SIMDGROUP_MATRIX_STORAGE
#pragma METAL internals : enable
namespace metal
{
template <typename T>
struct simdgroup_matrix_storage {
typedef vec<T, 64> storage_type;
storage_type t;
METAL_FUNC thread vec<T, 2>* thread_elements() thread {
return reinterpret_cast<thread vec<T, 2>*>(&t);
}
METAL_FUNC simdgroup_matrix_storage() thread = default;
METAL_FUNC simdgroup_matrix_storage(vec<T, 2> thread_elements) thread {
*(this->thread_elements()) = thread_elements;
}
METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
} else {
return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
}
}
METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
return src + matrix_origin.x * elements_per_row + matrix_origin.y;
} else {
return src + matrix_origin.y * elements_per_row + matrix_origin.x;
}
}
"""
var desc = MemoryAccessDescriptor()
desc.indentationSpaceCount = 4
for action in [Action.load, .store] {
for addressSpace in [AddressSpace.device, .threadgroup] {
for decodingBF16 in [false, true] {
desc.action = action
desc.addressSpace = addressSpace
desc.decodingBF16 = decodingBF16
output += createMemoryAccess(descriptor: desc)
output += "\n"
}
}
}
// Add the last section of the header.
output += """
template <typename U, typename V>
METAL_FUNC void multiply(simdgroup_matrix_storage<U> a, simdgroup_matrix_storage<V> b, bool accumulate = true) {
if (!accumulate) {
*(thread_elements()) = vec<T, 2>(0);
}
t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
}
};
} // namespace metal
#pragma METAL internals : disable
#endif // __METAL_SIMDGROUP_MATRIX_STORAGE
"""
return output
}
// MARK: - Implementation of Configuration Selection Logic
/// A description of a dense matrix-matrix multiplication.
struct GEMMDescriptor {
/// The number of equally sized multiplications that run in parallel.
/// Batching is out of scope for the reference implementation. However, there
/// should be a guide for clients that wish to modify the shader in ways that
/// increase the compute workload, for example by batching the
/// multiplication of (sub)matrices located at arbitrary pointers in memory
/// (with potentially nonuniform stride or noncontiguous padding).
var batchDimension: Int = 1
/// The dimensions of the input and output matrices.
/// - Parameter M: Number of output rows.
/// - Parameter N: Number of output columns.
/// - Parameter K: Number of loop iterations for the dot products.
///
/// For all practical purposes, one can assume matrix dimensions are 32-bit.
/// I use this quite often in other code. The pointers themselves are 64-bit,
/// but the offsets between different elements are 32-bit. With 4-byte words,
/// this scheme could access up to 16 GB of memory - larger than any array
/// in any reasonable application. Handling larger allocations likely
/// requires consideration of more failure points than just integer
/// overflows.
var matrixDimensions: (M: UInt32, N: UInt32, K: UInt32)?
var memoryPrecisions: (
A: GEMMOperandPrecision, B: GEMMOperandPrecision, C: GEMMOperandPrecision)?
var transposeState: (A: Bool, B: Bool)?
}
struct GEMMKey: Equatable, Hashable {
var batchDimension: Int
var matrixDimensions: SIMD3<UInt32>
var memoryPrecisions: SIMD3<UInt16>
var transposeState: SIMD2<UInt8>
init(copying source: GEMMDescriptor) {
batchDimension = source.batchDimension
matrixDimensions = Self.createMatrixDimensions(source.matrixDimensions)
memoryPrecisions = GEMMKernelKey.createPrecisions(source.memoryPrecisions)
transposeState = GEMMKernelKey.createTransposeState(source.transposeState)
}
@_transparent // performance in -Ounchecked
static func createMatrixDimensions(
_ input: (UInt32, UInt32, UInt32)?
) -> SIMD3<UInt32> {
if let input {
return SIMD3(input.0, input.1, input.2)
} else {
return SIMD3(repeating: .max)
}
}
}
extension GEMMDescriptor: Hashable, Equatable {
static func == (lhs: GEMMDescriptor, rhs: GEMMDescriptor) -> Bool {
let lhsKey = GEMMKey(copying: lhs)
let rhsKey = GEMMKey(copying: rhs)
return lhsKey == rhsKey
}
func hash(into hasher: inout Hasher) {
let key = GEMMKey(copying: self)
hasher.combine(key)
}
}
extension GEMMKernelDescriptor {
/// Initialize the kernel descriptor using another descriptor, which just
/// specifies the problem size. Then, forget the information about problem
/// size. It will not be needed until the very far future, when the user
/// retrieves a `MTLLibrary` from the cache and sets some Metal function
/// constants.
///
/// One might initialize a `GEMMKernelDescriptor` this way whenever an
/// arbitrary matrix multiplication is requested. The generated descriptor
/// itself could be a key in the KV cache. With this shader cache design, you
/// must minimize the latency of actions like `MTLDevice` materialization and
/// core count queries.
///
/// Acceptable latency: no more than 1 μs per invocation.
init(descriptor: GEMMDescriptor) {
guard let matrixDimensions = descriptor.matrixDimensions,
let memoryPrecisions = descriptor.memoryPrecisions,
let transposeState = descriptor.transposeState else {
fatalError("Descriptor was incomplete.")
}
// Select the only GPU on an Apple silicon system.
//
// NOTE: To avoid potentially costly API calls, you may wish to cache the
// MTLDevice object or enter a previously created one. The core count
// could also be cached on macOS.
//
// Typical latency to initiate a Metal device, provided the function has
// been called numerous times prior:
// - macOS 14
// - Swift debug mode, Metal API validation on: ≥33 μs
// - Swift release mode, Metal API validation off: ≥38 μs
// - iOS 17
// - Swift debug mode, Metal API validation on: ≥0 μs
// - Swift release mode, Metal API validation off: ≥0 μs
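// A minimal sketch of such a cache (the 'MetalContext' name is hypothetical,
// not part of this reference implementation):
//
//   enum MetalContext {
//     static let device = MTLCreateSystemDefaultDevice()!
//   }
//   let mtlDevice = MetalContext.device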
let mtlDevice = MTLCreateSystemDefaultDevice()!
// Trim the device name to something easier to process.
//
// M1 Max: Apple M1 Max -> M1
// M4: Apple M4 GPU -> M4
func createDeviceName() -> String {
let deviceName = mtlDevice.name
var splits = deviceName.split(separator: " ").map(String.init)
splits.removeAll(where: { $0.starts(with: "Apple") })
splits.removeAll(where: { $0.starts(with: "GPU") })
// Iterate over the space-separated words.
var matchingSplitIDs: [UInt32] = []
for splitID in splits.indices {
// Screen out obvious non-candidates.
let split = splits[splitID]
guard split.starts(with: "A") || split.starts(with: "M") else {
continue
}
// Extract the second character.
guard split.count > 1 else {
continue
}
let secondCharacterInt8 = split.utf8CString[1]
let secondCharacterUInt32 = UInt32(secondCharacterInt8)
let secondCharacterUnicode = Unicode.Scalar(secondCharacterUInt32)!
let secondCharacter = Character(secondCharacterUnicode)
// If the second character is numeric, the candidate passes.
if secondCharacter.isWholeNumber {
matchingSplitIDs.append(UInt32(splitID))
}
}
guard matchingSplitIDs.count == 1 else {
fatalError("Failed to locate device name.")
}
let splitID = matchingSplitIDs[0]
return splits[Int(splitID)]
}
let deviceName = createDeviceName()
// Find the core count.
#if os(macOS)
// Typical latency to query IORegistry, provided the function has been
// called numerous times prior:
// - macOS 14
// - Swift debug mode, Metal API validation on: ≥9 μs
// - Swift release mode, Metal API validation off: ≥9 μs
let coreCount = findCoreCount()
#elseif os(iOS)
var coreCount: Int
if deviceName.starts(with: "A") {
if mtlDevice.supportsFamily(.apple9) {
coreCount = 6
} else {
coreCount = 5
}
} else {
coreCount = 10
}
#endif
// Select the register precisions.
var registerPrecisionA = memoryPrecisions.A
var registerPrecisionB = memoryPrecisions.B
var registerPrecisionC = GEMMOperandPrecision.FP32
if memoryPrecisions.A == .FP16,
memoryPrecisions.B == .FP16,
memoryPrecisions.C == .FP16 {
// If FP16 is causing accuracy issues, you can change this to FP32. Note
// that doing so cuts out a very important part of the performance
// spectrum. It is only FP16xFP16->FP16 that reaches peak performance.
// This statement applies to both the M1 and M3 architectures.
//
// FP16xFP16 into FP16 accumulator triggers this instruction:
// https://github.com/dougallj/applegpu/blob/aeb81519159246d70c56d3f77adb4bc9cca7aa0d/applegpu.py#L3232-L3244
//
// FP16xFP16/BF16xBF16 into FP32 accumulator triggers this instruction:
// https://github.com/dougallj/applegpu/blob/aeb81519159246d70c56d3f77adb4bc9cca7aa0d/applegpu.py#L3195-L3207
//
// No other input/output register types map to a native instruction.
//
// I would recommend changing the accumulator precision on a case-by-case
// (operation-by-operation) basis. Provide some mechanism in the high-level
// API to control certain low-level features, without harming execution
// latency and without imposing technical debt on the high-level API.
// Definitely NOT a global flag that forces all matrices to change from
// FP16 -> FP32.
registerPrecisionC = GEMMOperandPrecision.FP16
}
if !mtlDevice.supportsFamily(.apple9) {
if memoryPrecisions.A == .BF16 {
registerPrecisionA = .FP32
}
if memoryPrecisions.B == .BF16 {
registerPrecisionB = .FP32
}
}
// Set the properties of the 'GEMMKernelDescriptor' object.
self.memoryPrecisions = memoryPrecisions
if mtlDevice.supportsFamily(.apple9) {
self.preferAsyncLoad = false
} else {
self.preferAsyncLoad = true
}
self.registerPrecisions = (
registerPrecisionA,
registerPrecisionB,
registerPrecisionC)
if !mtlDevice.supportsFamily(.apple9) {
self.splits = (2, 2)
} else {
self.splits = (1, 1)
}
self.transposeState = transposeState
// Set the properties that deal with block size.
setBlockDimensions(
mtlDevice: mtlDevice,
coreCount: coreCount,
matrixDimensions: matrixDimensions,
batchDimension: descriptor.batchDimension)
}
// Implementation of the block size selection heuristic.
//
// This function initializes the 'blockDimensions' and
// 'paddedBlockDimensions' properties.
private mutating func setBlockDimensions(
mtlDevice: MTLDevice,
coreCount: Int,
matrixDimensions: (M: UInt32, N: UInt32, K: UInt32),
batchDimension: Int
) {
guard let memoryPrecisions,
let transposeState else {
fatalError("Some properties were not set.")
}
guard !mtlDevice.supportsFamily(.apple9) else {
self.blockDimensions = (32, 32, 8)
return
}
// Find the actual number of threadgroups, with a large block size.
func ceilDivide(_ target: UInt32, _ granularity: UInt16) -> UInt32 {
(target + UInt32(granularity) - 1) / UInt32(granularity)
}
var actualGroups: Int = 1
actualGroups *= Int(ceilDivide(matrixDimensions.M, 48))
actualGroups *= Int(ceilDivide(matrixDimensions.N, 48))
actualGroups *= Int(batchDimension)
// Does the kernel use 48x48x24xFP32 (9 KB) or 48x48x32xFP16/BF16 (6 KB)?
func requiresLargeAllocation(_ precision: GEMMOperandPrecision) -> Bool {
switch precision {
case .FP32: return true
case .FP16: return false
case .BF16: return false
}
}
var useLargeAllocation = false
if requiresLargeAllocation(memoryPrecisions.A) {
useLargeAllocation = true
}
if requiresLargeAllocation(memoryPrecisions.B) {
useLargeAllocation = true
}
if requiresLargeAllocation(memoryPrecisions.C) {
useLargeAllocation = true
}
// Branch on whether the allocation is large / target occupancy is low.
if useLargeAllocation {
let idealGroups = coreCount * 6
if actualGroups <= idealGroups {
self.blockDimensions = (32, 32, 32)
} else {
self.blockDimensions = (48, 48, 24)
// This is verified to be optimal for:
// - (memA, memB, memC) = (FP32, FP32, FP32)
// - (memA, memB, memC) = (FP16, FP16, FP32)
// - (memA, memB, memC) = (FP16, FP32, FP32)
// - (memA, memB, memC) = (FP16, FP32, FP16)
switch transposeState {
case (false, false):
self.paddedBlockDimensions = ((48, 24), (24, 48), (48, 48))
case (false, true):
let paddedBK = (memoryPrecisions.B == .FP32) ? UInt16(28) : 24
self.paddedBlockDimensions = ((48, 24), (paddedBK, 48), (48, 48))
case (true, false):
let paddedAM = (memoryPrecisions.A == .FP32) ? UInt16(52) : 56
self.paddedBlockDimensions = ((paddedAM, 24), (24, 48), (48, 48))
case (true, true):
let paddedAM = (memoryPrecisions.A == .FP32) ? UInt16(52) : 56
self.paddedBlockDimensions = ((paddedAM, 24), (24, 48), (48, 48))
}
}
} else {
let idealGroups = coreCount * 9
if actualGroups <= idealGroups {
blockDimensions = (32, 32, 32)
} else {
blockDimensions = (48, 48, 32)
}
}
}
}
#if os(macOS)
/// Finds the core count on macOS devices, using IORegistry.
///
/// Source: [AppleGPUInfo](https://github.com/philipturner/applegpuinfo)
///
/// This code was generated by GPT-4 a few days after launch (early 2023).
/// Since then, it has undergone extensive human review and real-world testing.
/// It proved that proto-AGI could be a practically useful tool, in this case
/// assisting with code creation.
func findCoreCount() -> Int {
// Create a matching dictionary with "AGXAccelerator" class name
let matchingDict = IOServiceMatching("AGXAccelerator")
// Get an iterator for matching services
var iterator: io_iterator_t = 0
do {
let io_registry_error =
IOServiceGetMatchingServices(
kIOMainPortDefault, matchingDict, &iterator)
guard io_registry_error == 0 else {
fatalError(
"Encountered IORegistry error code \(io_registry_error)")
}
}
// Get the first (and only) GPU entry from the iterator
let gpuEntry = IOIteratorNext(iterator)
// Check if the entry is valid
if gpuEntry == MACH_PORT_NULL {
fatalError(
"Error getting GPU entry at \(#file):\(#line - 5)")
}
// Release the iterator
IOObjectRelease(iterator)
// Get the "gpu-core-count" property from gpuEntry
let key = "gpu-core-count"
let options: IOOptionBits = 0 // No options needed
let gpuCoreCount = IORegistryEntrySearchCFProperty(
gpuEntry, kIOServicePlane, key as CFString, nil, options)
// Check if the property is valid
if gpuCoreCount == nil {
fatalError(
"Error getting gpu-core-count property at \(#file):\(#line - 6)")
}
// Cast the property to CFNumberRef
let gpuCoreCountNumber = gpuCoreCount as! CFNumber
// Check if the number type is sInt64
let type = CFNumberGetType(gpuCoreCountNumber)
if type != .sInt64Type {
fatalError(
"Error: gpu-core-count is not sInt64 at \(#file):\(#line - 3)")
}
// Get the value of the number as Int64
var value: Int64 = 0
let result = CFNumberGetValue(gpuCoreCountNumber, type, &value)
// Check for errors
if result == false {
fatalError(
" Error getting value of gpu-core-count at \(#file):\(#line - 5)")
}
return Int(value)
}
#endif
/// A reference implementation of shader caching.
///
/// One good design for a shader caching mechanism:
/// - Two key-value caches.
/// - The first caches `MTLLibrary` objects.
/// - Large latency
/// - Small number of combinatorial possibilities, likely to be shared by
/// matrices with a different size.
/// - Don't bother with serializing Metal binary archives to disk. You are
/// already utilizing the system-wide Metal shader cache.
/// - The second caches `MTLComputePipelineState` objects.
/// - Instantiations of the `MTLLibrary` with different function constants.
/// - Less latency than compiling from source, but still non-negligible. You
/// can't spawn a new PSO during every call to a matrix multiplication.
extension GEMMKernel {
/// WARNING: Not thread safe. But will the DSL interpreter even use
/// multithreading?
static var libraryCache: [
GEMMKernelDescriptor: GEMMKernel] = [:]
/// WARNING: Not thread safe. But will the DSL interpreter even use
/// multithreading?
static var pipelineCache: [
GEMMDescriptor: (GEMMKernel, MTLComputePipelineState)] = [:]
}
/// Implementation of the logic for choosing between 'device' and
/// 'threadgroup' store.
func retrieveGEMMKernel(
descriptor gemmDesc: GEMMDescriptor
) -> (GEMMKernel, MTLComputePipelineState) {
// Perform the early return before anything with high latency.
if let value = GEMMKernel.pipelineCache[gemmDesc] {
return value
}
func createKernel(descriptor: GEMMKernelDescriptor) -> GEMMKernel {
guard descriptor.preferAsyncStore != nil else {
fatalError("Prefer async store was not set.")
}
if let previous = GEMMKernel.libraryCache[descriptor] {
return previous
} else {
return GEMMKernel(descriptor: descriptor)
}
}
// Create an MTLDevice object, a function call with very high latency.
let device = MTLCreateSystemDefaultDevice()!
func createPipeline(library: MTLLibrary) -> MTLComputePipelineState {
// Set the function constants.
let constants = MTLFunctionConstantValues()
var M = gemmDesc.matrixDimensions!.M
var N = gemmDesc.matrixDimensions!.N
var K = gemmDesc.matrixDimensions!.K
constants.setConstantValue(&M, type: .uint, index: 0)
constants.setConstantValue(&N, type: .uint, index: 1)
constants.setConstantValue(&K, type: .uint, index: 2)
let function = try! library.makeFunction(
name: "gemm", constantValues: constants)
let pipeline = try! device.makeComputePipelineState(function: function)
return pipeline
}
var kernelDesc = GEMMKernelDescriptor(descriptor: gemmDesc)
kernelDesc.device = device
if device.supportsFamily(.apple9) {
kernelDesc.preferAsyncStore = false
} else {
guard let blockDimensions = kernelDesc.blockDimensions else {
fatalError("Block dimensions were not set.")
}
if blockDimensions == (48, 48, 32) {
kernelDesc.preferAsyncStore = nil
} else {
kernelDesc.preferAsyncStore = true
}
}
var output: (GEMMKernel, MTLComputePipelineState)
if kernelDesc.preferAsyncStore != nil {
let kernel = createKernel(descriptor: kernelDesc)
let pipeline = createPipeline(library: kernel.library)
output = (kernel, pipeline)
GEMMKernel.libraryCache[kernelDesc] = kernel
} else {
var candidates: [
(kernelDesc: GEMMKernelDescriptor,
kernel: GEMMKernel,
pipeline: MTLComputePipelineState)
] = []
for candidateID in 0..<4 {
var blockDimensions: (M: UInt16, N: UInt16, K: UInt16)
var preferAsyncStore: Bool
switch candidateID {
case 0:
blockDimensions = (48, 48, 32)
preferAsyncStore = false
case 1:
blockDimensions = (48, 48, 40)
preferAsyncStore = false
case 2:
blockDimensions = (48, 48, 32)
preferAsyncStore = true
case 3:
blockDimensions = (48, 48, 40)
preferAsyncStore = true
default:
fatalError("This should never happen.")
}
// Set the data that's unique to this variant.
var newKernelDesc = kernelDesc
newKernelDesc.blockDimensions = blockDimensions
newKernelDesc.preferAsyncStore = preferAsyncStore
let kernel = createKernel(descriptor: newKernelDesc)
let pipeline = createPipeline(library: kernel.library)
candidates.append((newKernelDesc, kernel, pipeline))
GEMMKernel.libraryCache[newKernelDesc] = kernel
}
// Find the maximum occupancy.
var maximumOccupancy: Int = -1
for candidate in candidates {
let occupancy = candidate.pipeline.maxTotalThreadsPerThreadgroup
maximumOccupancy = max(maximumOccupancy, occupancy)
}
candidates.removeAll(where: {
$0.pipeline.maxTotalThreadsPerThreadgroup != maximumOccupancy
})
    // All remaining candidates have the same (maximum) occupancy. Take the
    // last one; the ordering above favors async store and the larger K block
    // dimension among the ties.
    let candidate = candidates.last!
kernelDesc = candidate.kernelDesc
output = (candidate.kernel, candidate.pipeline)
}
// Save the output to the cache.
GEMMKernel.pipelineCache[gemmDesc] = output
return output
}
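// A minimal usage sketch (never called in this file). The first call with a
// given descriptor pays the full cost of device creation, shader compilation,
// and PSO creation; the second call returns early from
// 'GEMMKernel.pipelineCache'.
func exampleKernelRetrieval() {
  var gemmDesc = GEMMDescriptor()
  let n: UInt32 = 256
  gemmDesc.matrixDimensions = (M: n, N: n, K: n)
  let precision = GEMMOperandPrecision.FP32
  gemmDesc.memoryPrecisions = (precision, precision, precision)
  gemmDesc.transposeState = (false, false)

  // Cold path: compiles the MTLLibrary and the pipeline state.
  let (coldKernel, coldPipeline) = retrieveGEMMKernel(descriptor: gemmDesc)

  // Warm path: hits the pipeline cache keyed by the GEMMDescriptor.
  let (warmKernel, warmPipeline) = retrieveGEMMKernel(descriptor: gemmDesc)
  _ = (coldKernel, coldPipeline, warmKernel, warmPipeline)
}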
// MARK: - Profiling
/// A continuous (integration) test of both correctness and performance. This
/// test completes with low latency (\<1 second) for rapid feedback during
/// iterative design.
///
/// Returns:
/// - lane 0: maximum achieved performance in GFLOPS
/// - lane 1: occupancy in threads/core
///
/// A short sketch of consuming this statistic follows the function body.
func profileProblemSize(
descriptor: GEMMDescriptor
) -> SIMD2<Int> {
guard let matrixDimensions = descriptor.matrixDimensions,
matrixDimensions.M == matrixDimensions.N,
matrixDimensions.M == matrixDimensions.K else {
fatalError("Matrix dimensions were invalid.")
}
let problemSize = Int(matrixDimensions.M)
// Allocate FP32 memory for the operands.
var A = [Float](repeating: .zero, count: problemSize * problemSize)
var B = [Float](repeating: .zero, count: problemSize * problemSize)
var C = [Float](repeating: .zero, count: problemSize * problemSize)
// Initialize A as the 2nd-order periodic Laplacian.
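  // For example, with problemSize = 4, A is:
  //
  //   [ -2   1   0   1 ]
  //   [  1  -2   1   0 ]
  //   [  0   1  -2   1 ]
  //   [  1   0   1  -2 ]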
for diagonalID in 0..<problemSize {
let diagonalAddress = diagonalID * problemSize + diagonalID
A[diagonalAddress] = -2
let leftColumnID = (diagonalID + problemSize - 1) % problemSize
let leftSubDiagonalAddress = diagonalID * problemSize + leftColumnID
A[leftSubDiagonalAddress] = 1
let rightColumnID = (diagonalID + problemSize + 1) % problemSize
let rightSubDiagonalAddress = diagonalID * problemSize + rightColumnID
A[rightSubDiagonalAddress] = 1
}
// Initialize B to random numbers.
for rowID in 0..<problemSize {
for columnID in 0..<problemSize {
let address = rowID * problemSize + columnID
let entry = Float.random(in: 0..<1)
B[address] = entry
}
}
// Since the Laplacian is symmetric, we swap roles of the matrices to test
// transposition of the left-hand side.
//
// Note that the test cannot cover correctness of A and B transposition
// simultaneously. Instead, test the correctness in isolation
// (AB, AB^T, A^T B). Performance can be tested in all four permutations
// (AB, AB^T, A^T B, A^T B^T).
if descriptor.transposeState!.A {
swap(&A, &B)
}
// Initialize the context.
let device = MTLCreateSystemDefaultDevice()!
let commandQueue = device.makeCommandQueue()!
let context = (device: device, commandQueue: commandQueue)
func createBuffer(
_ originalData: [Float],
_ precision: GEMMOperandPrecision
) -> MTLBuffer {
// Add random numbers to expose out-of-bounds accesses.
var augmentedData = originalData
for _ in 0..<originalData.count {
let randomNumber = Float.random(in: -2...2)
augmentedData.append(randomNumber)
}
// Allocate enough memory to store everything in Float32.
let bufferSize = augmentedData.count * 4
let buffer = context.device.makeBuffer(length: bufferSize)!
// Copy the data into the buffer.
switch precision {
case .FP32:
let pointer = buffer.contents().assumingMemoryBound(to: Float.self)
for i in augmentedData.indices {
pointer[i] = augmentedData[i]
}
case .FP16:
let pointer = buffer.contents().assumingMemoryBound(to: Float16.self)
for i in augmentedData.indices {
pointer[i] = Float16(augmentedData[i])
}
    case .BF16:
      // Store BF16 by truncating the FP32 bit pattern: keep the upper 16 bits
      // (sign, exponent, and the top 7 mantissa bits). For example,
      // Float(1.0) = 0x3F800000 truncates to the BF16 encoding 0x3F80.
let pointer = buffer.contents().assumingMemoryBound(to: UInt16.self)
for i in augmentedData.indices {
let value32 = augmentedData[i].bitPattern
let value16 = unsafeBitCast(value32, to: SIMD2<UInt16>.self)[1]
pointer[i] = value16
}
}
return buffer
}
// Multiply A with B.
var maxGFLOPS: Int = .zero
var occupancy: Int = .zero
do {
// Generate the kernel.
let (kernel, pipeline) = retrieveGEMMKernel(descriptor: descriptor)
occupancy = pipeline.maxTotalThreadsPerThreadgroup
// Create the buffers.
let bufferA = createBuffer(A, descriptor.memoryPrecisions!.A)
let bufferB = createBuffer(B, descriptor.memoryPrecisions!.B)
let bufferC = createBuffer(C, descriptor.memoryPrecisions!.C)
// Profile the latency of matrix multiplication.
for _ in 0..<15 {
let duplicatedCommandCount: Int = 20
// Encode the GPU command.
let commandBuffer = context.commandQueue.makeCommandBuffer()!
let encoder = commandBuffer.makeComputeCommandEncoder()!
encoder.setComputePipelineState(pipeline)
encoder.setThreadgroupMemoryLength(
Int(kernel.threadgroupMemoryAllocation), index: 0)
encoder.setBuffer(bufferA, offset: 0, index: 0)
encoder.setBuffer(bufferB, offset: 0, index: 1)
encoder.setBuffer(bufferC, offset: 0, index: 2)
for _ in 0..<duplicatedCommandCount {
func ceilDivide(_ target: Int, _ granularity: UInt16) -> Int {
(target + Int(granularity) - 1) / Int(granularity)
}
let gridSize = MTLSize(
width: ceilDivide(problemSize, kernel.blockDimensions.N),
height: ceilDivide(problemSize, kernel.blockDimensions.M),
depth: 1)
let groupSize = MTLSize(
width: Int(kernel.threadgroupSize),
height: 1,
depth: 1)
encoder.dispatchThreadgroups(
gridSize, threadsPerThreadgroup: groupSize)
}
encoder.endEncoding()
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
// Determine the time taken.
let start = commandBuffer.gpuStartTime
let end = commandBuffer.gpuEndTime
let latency = end - start
// Determine the amount of work done.
var operations = 2 * problemSize * problemSize * problemSize
operations = operations * duplicatedCommandCount
let gflops = Int(Double(operations) / Double(latency) / 1e9)
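      // For example, at problemSize = 1488, one dispatch performs
      // 2 * 1488^3 ≈ 6.6e9 FLOPs, so the 20 duplicated dispatches perform
      // ≈ 1.3e11 FLOPs. A hypothetical GPU time of 16 ms would therefore
      // register as roughly 8200 GFLOPS.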
// Report the results.
// let latencyMicroseconds = Int(latency / 1e-6)
// print(latencyMicroseconds, "μs", gflops, "GFLOPS")
maxGFLOPS = max(maxGFLOPS, gflops)
}
// Copy the results to C.
do {
let precision = descriptor.memoryPrecisions!.C
let raw = bufferC.contents()
for rowID in 0..<problemSize {
for columnID in 0..<problemSize {
let address = rowID * problemSize + columnID
var entry32: Float
switch precision {
case .FP32:
let casted = raw.assumingMemoryBound(to: Float.self)
entry32 = casted[address]
case .FP16:
let casted = raw.assumingMemoryBound(to: Float16.self)
let entry16 = casted[address]
entry32 = Float(entry16)
          case .BF16:
            // Decode BF16 by placing the 16 stored bits into the upper half
            // of an FP32 bit pattern; the lower 16 mantissa bits are zero.
let casted = raw.assumingMemoryBound(to: UInt16.self)
let entry16 = casted[address]
let entry16x2 = SIMD2<UInt16>(.zero, entry16)
entry32 = unsafeBitCast(entry16x2, to: Float.self)
}
C[address] = entry32
}
}
}
}
// Choose an error threshold.
func createErrorThreshold(precision: GEMMOperandPrecision) -> Float {
switch precision {
case .FP32: return 1e-5
case .FP16: return 5e-3
case .BF16: return 5e-2
}
}
var errorThreshold: Float = 0
do {
let memoryPrecisions = descriptor.memoryPrecisions!
let thresholdA = createErrorThreshold(precision: memoryPrecisions.A)
let thresholdB = createErrorThreshold(precision: memoryPrecisions.B)
let thresholdC = createErrorThreshold(precision: memoryPrecisions.C)
errorThreshold = max(errorThreshold, thresholdA)
errorThreshold = max(errorThreshold, thresholdB)
errorThreshold = max(errorThreshold, thresholdC)
}
// Check the results.
var errorCount: Int = .zero
for m in 0..<problemSize {
for n in 0..<problemSize {
// Find the source row IDs.
let leftRowID = (m + problemSize - 1) % problemSize
let centerRowID = m
let rightRowID = (m + problemSize + 1) % problemSize
// Find the source scalars.
var leftSource: Float
var centerSource: Float
var rightSource: Float
if descriptor.transposeState!.A {
leftSource = A[leftRowID * problemSize + n]
centerSource = A[centerRowID * problemSize + n]
rightSource = A[rightRowID * problemSize + n]
} else if descriptor.transposeState!.B {
leftSource = B[n * problemSize + leftRowID]
centerSource = B[n * problemSize + centerRowID]
rightSource = B[n * problemSize + rightRowID]
} else {
leftSource = B[leftRowID * problemSize + n]
centerSource = B[centerRowID * problemSize + n]
rightSource = B[rightRowID * problemSize + n]
}
// Find the expected result.
let expected = leftSource - 2 * centerSource + rightSource
// Find the actual result.
var actual: Float
if descriptor.transposeState!.A {
actual = C[n * problemSize + m]
} else {
actual = C[m * problemSize + n]
}
// Report whether it is correct.
let error = (expected - actual).magnitude
if error > errorThreshold {
if errorCount < 10 {
print("error: \(error) / ~1.000")
errorCount += 1
}
}
}
}
return SIMD2(maxGFLOPS, occupancy)
}
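// A short sketch (never called in this file) of consuming the profiling
// statistic; 'runApplication' below does the same with formatted console
// output.
func exampleProfileConsumption() {
  var gemmDesc = GEMMDescriptor()
  let n: UInt32 = 512
  gemmDesc.matrixDimensions = (M: n, N: n, K: n)
  let precision = GEMMOperandPrecision.FP16
  gemmDesc.memoryPrecisions = (precision, precision, precision)
  gemmDesc.transposeState = (false, false)

  let statistic = profileProblemSize(descriptor: gemmDesc)
  let gflops = statistic[0]    // lane 0: maximum achieved GFLOPS
  let occupancy = statistic[1] // lane 1: occupancy in threads/core
  print("FP16 512x512x512:", gflops, "GFLOPS at", occupancy, "threads/core")
}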
/// The workspace for scripting everything during kernel design. Invoked by
/// SwiftUI in the iOS app for M4 testing; a hypothetical call site is
/// sketched after the function body.
func runApplication() {
struct TestDescriptor {
var precision: GEMMOperandPrecision?
var problemSize: Int?
var transposeState: (Bool, Bool)?
}
func runTest(descriptor: TestDescriptor) {
guard let precision = descriptor.precision,
let problemSize = descriptor.problemSize,
let transposeState = descriptor.transposeState else {
fatalError("Descriptor was incomplete.")
}
// Set up the kernel.
var gemmDesc = GEMMDescriptor()
let n = UInt32(problemSize)
gemmDesc.matrixDimensions = (M: n, N: n, K: n)
gemmDesc.memoryPrecisions = (precision, precision, precision)
gemmDesc.transposeState = descriptor.transposeState
// Test the kernel.
let statistic = profileProblemSize(descriptor: gemmDesc)
// Report the results.
do {
var repr = "\(problemSize)"
while repr.count < 4 {
repr = " " + repr
}
print("problemSize = \(repr)", terminator: " | ")
}
if transposeState.0 {
print("A^T", terminator: " ")
} else {
print("A ", terminator: " ")
}
if transposeState.1 {
print("B^T", terminator: " | ")
} else {
print("B ", terminator: " | ")
}
for laneID in [Int(1), Int(0)] {
var repr = "\(statistic[laneID])"
while repr.count < 4 {
repr = " " + repr
}
// Log the number to the console.
if laneID == 0 {
print(repr, terminator: " GFLOPS")
} else {
print(repr, terminator: " threads/core | ")
}
}
print("")
}
// Correctness tests.
do {
let problemSizes: [Int] = [
7, 8, 9, 10,
15, 16, 17, 18,
23, 24, 25,
31, 32, 33,
47, 48, 49,
63, 64, 65,
103, 104, 112,
126, 127, 128, 129,
130, 131,
135, 136, 137,
143, 144, 145,
151, 152, 153,
]
let transposeStates: [(Bool, Bool)] = [
(false, false),
(false, true),
(true, false),
]
print()
print("Correctness tests:")
for problemSize in problemSizes {
for transposeState in transposeStates {
var testDescriptor = TestDescriptor()
testDescriptor.precision = .FP32
testDescriptor.problemSize = problemSize
testDescriptor.transposeState = transposeState
runTest(descriptor: testDescriptor)
}
}
}
// My workspace. Edit this code to run actual tests.
do {
let transposeStates: [(Bool, Bool)] = [
(false, false),
(false, true),
(true, false),
(true, true),
]
// Working on investigating BF16 performance with large matrices.
print()
print("Performance tests:")
for problemSize in 1488...1489 {
for transposeState in transposeStates {
var testDescriptor = TestDescriptor()
testDescriptor.precision = .BF16
testDescriptor.problemSize = problemSize
testDescriptor.transposeState = transposeState
runTest(descriptor: testDescriptor)
}
}
}
}
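// For reference, the SwiftUI call site mentioned above might look like the
// hypothetical sketch below (it assumes 'import SwiftUI' and is not part of
// this file):
//
//   struct ContentView: View {
//     var body: some View {
//       Button("Run GEMM tests") {
//         runApplication()
//       }
//     }
//   }
//
// A macOS command-line target could instead call 'runApplication()' directly
// from top-level code.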