Last active
May 29, 2024 14:57
-
-
Save philipturner/40052a700a448b9356b998154cd7e4cd to your computer and use it in GitHub Desktop.
Investigating the performance of low- and mixed-precision computations after dynamic caching
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// main.swift | |
// M4LowPrecisionMath | |
// | |
// Created by Philip Turner on 5/28/24. | |
// | |
import Metal | |
// Investigating the performance of low- and mixed-precision computations after | |
// dynamic caching. | |
// | |
// ========================================================================== // | |
// Introduction | |
// ========================================================================== // | |
// | |
// Mixed precision is a method for reducing the memory bandwidth of Float32 | |
// workloads. On some application-specific chips, there are dedicated FP16, | |
// BF16, or FP8 units with higher throughput than FP32. These units typically | |
// execute only FMA instructions or block matrix multiplications. For example, | |
// the transcendental activation function in a neural network would | |
// temporarily expand the 16-bit number to Float32. | |
// | |
// The AMX coprocessor can execute 16-bit instructions at twice the rate of | |
// 32-bit instructions. This low-precision capability would be extremely | |
// useful, reaching 50-100% performance of the GPU. However, it is not | |
// accessible through any public API. In addition, low-precision AI | |
// inference is easy to code on the GPU. For all practical purposes, the CPU | |
// is restricted to 32-bit types. | |
// | |
// This investigation should answer some questions about mixed precision on | |
// Apple GPUs. | |
// - Does the optimal block size change when the precision decreases from | |
// Float32 to Float16? | |
// - How well does the MSL built-in type for BFloat16 perform, compared to | |
// Float16? | |
// - How fast are the MPSMatrix and MPSGraph implementations of low-precision | |
// matrix multiplication? | |
// - Can MPS achieve double the performance by running two different | |
// matrix multiplications simultaneously, each with a different precision? | |
// - Can a kernel be engineered to exploit the FP16 and FP32 units | |
// simultaneously? | |
// | |
// ========================================================================== // | |
// Methods | |
// ========================================================================== // | |
// | |
// No async copies. | |
// 32 threads/threadgroup. | |
// | |
// ========================================================================== // | |
// Results (Raw Data) | |
// - Does the optimal block size change when the precision decreases from | |
// Float32 to Float16? | |
// - How well does the MSL built-in type for BFloat16 perform, compared to | |
// Float16? | |
// ========================================================================== // | |
// | |
// A, B, and C are Float32. | |
// | |
// M1 Max | |
// - problemSize = 256 | 913 GFLOPS (32x32x8) | |
// - problemSize = 384 | 2931 GFLOPS (32x32x8) | |
// - problemSize = 512 | 5342 GFLOPS (32x32x8) | |
// - problemSize = 640 | 5463 GFLOPS (32x32x8) | |
// - problemSize = 768 | 6160 GFLOPS (48x48x8) | |
// - problemSize = 896 | 6643 GFLOPS (48x48x8) | |
// - problemSize = 1024 | 7596 GFLOPS (48x48x8) | |
// - problemSize = 1152 | 7676 GFLOPS (48x48x8) | |
// - problemSize = 1280 | 7712 GFLOPS (48x48x8) | |
// - problemSize = 1408 | 7747 GFLOPS (48x48x8) | |
// - problemSize = 1536 | 8392 GFLOPS (48x48x8) | |
// | |
// M4 | |
// - problemSize = 256 | 1195 GFLOPS (32x32x8) | |
// - problemSize = 384 | 1729 GFLOPS (32x32x8) | |
// - problemSize = 512 | 2549 GFLOPS (32x32x8) | |
// - problemSize = 640 | 2983 GFLOPS (32x32x8) | |
// - problemSize = 768 | 3036 GFLOPS (32x32x8) | |
// - problemSize = 896 | 3044 GFLOPS (32x32x8) | |
// - problemSize = 1024 | 3074 GFLOPS (32x32x8) | |
// - problemSize = 1152 | 3123 GFLOPS (32x32x8) | |
// - problemSize = 1280 | 3134 GFLOPS (32x32x8) | |
// - problemSize = 1408 | 3167 GFLOPS (32x32x8) | |
// - problemSize = 1536 | 3174 GFLOPS (32x32x8) | |
// | |
// A, B, and C are Float16. | |
// | |
// M1 Max | |
// - problemSize = 256 | 1390 GFLOPS (32x32x8) | |
// - problemSize = 384 | 3047 GFLOPS (32x32x8) | |
// - problemSize = 512 | 5578 GFLOPS (32x32x8) | |
// - problemSize = 640 | 6566 GFLOPS (32x32x8) | |
// - problemSize = 768 | 6022 GFLOPS (32x32x8) | |
// - problemSize = 896 | 7255 GFLOPS (48x48x8) | |
// - problemSize = 1024 | 7893 GFLOPS (48x48x8) | |
// - problemSize = 1152 | 8905 GFLOPS (48x48x8) | |
// - problemSize = 1280 | 8639 GFLOPS (48x48x8) | |
// - problemSize = 1408 | 8099 GFLOPS (48x48x8) | |
// - problemSize = 1536 | 9310 GFLOPS (48x48x8) | |
// | |
// M4 | |
// - problemSize = 256 | 857 GFLOPS (32x32x8) | |
// - problemSize = 384 | 1912 GFLOPS (32x32x8) | |
// - problemSize = 512 | 2688 GFLOPS (32x32x8) | |
// - problemSize = 640 | 3306 GFLOPS (32x32x8) | |
// - problemSize = 768 | 3347 GFLOPS (32x32x8) | |
// - problemSize = 896 | 3383 GFLOPS (32x32x8) | |
// - problemSize = 1024 | 3428 GFLOPS (32x32x8) | |
// - problemSize = 1152 | 3471 GFLOPS (32x32x8) | |
// - problemSize = 1280 | 3491 GFLOPS (32x32x8) | |
// - problemSize = 1408 | 3517 GFLOPS (32x32x8) | |
// - problemSize = 1536 | 3532 GFLOPS (32x32x8) | |
// | |
// A, B, and C are BFloat16, through the MSL keyword 'bfloat'. | |
// | |
// M1 Max | |
// - problemSize = 256 | 604 GFLOPS (32x32x8) | |
// - problemSize = 384 | 1992 GFLOPS (32x32x8) | |
// - problemSize = 512 | 2822 GFLOPS (32x32x8) | |
// - problemSize = 640 | 3311 GFLOPS (32x32x8) | |
// - problemSize = 768 | 3879 GFLOPS (32x32x8) | |
// - problemSize = 896 | 3776 GFLOPS (32x32x8) | |
// - problemSize = 1024 | 4127 GFLOPS (32x32x8) | |
// - problemSize = 1152 | 3822 GFLOPS (32x32x8) | |
// - problemSize = 1280 | 4141 GFLOPS (32x32x8) | |
// - problemSize = 1408 | 4130 GFLOPS (32x32x8) | |
// - problemSize = 1536 | 4258 GFLOPS (32x32x8) | |
// | |
// M4 | |
// - problemSize = 256 | 1171 GFLOPS (32x32x8) | |
// - problemSize = 384 | 1709 GFLOPS (32x32x8) | |
// - problemSize = 512 | 2736 GFLOPS (32x32x8) | |
// - problemSize = 640 | 2887 GFLOPS (32x32x8) | |
// - problemSize = 768 | 2945 GFLOPS (32x32x8) | |
// - problemSize = 896 | 2960 GFLOPS (32x32x8) | |
// - problemSize = 1024 | 3006 GFLOPS (32x32x8) | |
// - problemSize = 1152 | 3035 GFLOPS (32x32x8) | |
// - problemSize = 1280 | 3054 GFLOPS (32x32x8) | |
// - problemSize = 1408 | 3081 GFLOPS (32x32x8) | |
// - problemSize = 1536 | 3086 GFLOPS (32x32x8) | |
// | |
// A and B are BFloat16, C is Float32. | |
// | |
// M1 Max | |
// - problemSize = 256 | 675 GFLOPS (32x32x8) | |
// - problemSize = 384 | 2719 GFLOPS (32x32x8) | |
// - problemSize = 512 | 4840 GFLOPS (32x32x8) | |
// - problemSize = 640 | 5954 GFLOPS (32x32x8) | |
// - problemSize = 768 | 6260 GFLOPS (32x32x8) | |
// - problemSize = 896 | 6805 GFLOPS (48x48x8) | |
// - problemSize = 1024 | 6590 GFLOPS (48x48x8) | |
// - problemSize = 1152 | 7707 GFLOPS (48x48x8) | |
// - problemSize = 1280 | 7469 GFLOPS (48x48x8) | |
// - problemSize = 1408 | 7296 GFLOPS (48x48x8) | |
// - problemSize = 1536 | 8247 GFLOPS (48x48x8) | |
// | |
// M4 | |
// - problemSize = 256 | 1268 GFLOPS (32x32x8) | |
// - problemSize = 384 | 1874 GFLOPS (32x32x8) | |
// - problemSize = 512 | 2986 GFLOPS (32x32x8) | |
// - problemSize = 640 | 3254 GFLOPS (32x32x8) | |
// - problemSize = 768 | 3310 GFLOPS (32x32x8) | |
// - problemSize = 896 | 3333 GFLOPS (32x32x8) | |
// - problemSize = 1024 | 3348 GFLOPS (32x32x8) | |
// - problemSize = 1152 | 3416 GFLOPS (32x32x8) | |
// - problemSize = 1280 | 3438 GFLOPS (32x32x8) | |
// - problemSize = 1408 | 3463 GFLOPS (32x32x8) | |
// - problemSize = 1536 | 3477 GFLOPS (32x32x8) | |
// | |
// A and B are BFloat16, C is Float16. | |
// | |
// M1 Max | |
// - problemSize = 256 | 1120 GFLOPS (32x32x8) | |
// - problemSize = 384 | 2387 GFLOPS (32x32x8) | |
// - problemSize = 512 | 4629 GFLOPS (32x32x8) | |
// - problemSize = 640 | 5736 GFLOPS (32x32x8) | |
// - problemSize = 768 | 6455 GFLOPS (32x32x8) | |
// - problemSize = 896 | 6515 GFLOPS (48x48x8) | |
// - problemSize = 1024 | 6643 GFLOPS (48x48x8) | |
// - problemSize = 1152 | 7459 GFLOPS (48x48x8) | |
// - problemSize = 1280 | 6646 GFLOPS (48x48x8) | |
// - problemSize = 1408 | 6756 GFLOPS (48x48x8) | |
// - problemSize = 1536 | 7596 GFLOPS (48x48x8) | |
// | |
// M4 | |
// - problemSize = 256 | 1176 GFLOPS (32x32x8) | |
// - problemSize = 384 | 1895 GFLOPS (32x32x8) | |
// - problemSize = 512 | 2723 GFLOPS (32x32x8) | |
// - problemSize = 640 | 2917 GFLOPS (32x32x8) | |
// - problemSize = 768 | 2950 GFLOPS (32x32x8) | |
// - problemSize = 896 | 2969 GFLOPS (32x32x8) | |
// - problemSize = 1024 | 3002 GFLOPS (32x32x8) | |
// - problemSize = 1152 | 3033 GFLOPS (32x32x8) | |
// - problemSize = 1280 | 3054 GFLOPS (32x32x8) | |
// - problemSize = 1408 | 3079 GFLOPS (32x32x8) | |
// - problemSize = 1536 | 3088 GFLOPS (32x32x8) | |
// | |
// A and B are BFloat16 through MSL keyword. | |
// C is encoded from Float32 in registers to BFloat16 in RAM. | |
// | |
// M1 Max | |
// - problemSize = 256 | 825 GFLOPS (32x32x8) | |
// - problemSize = 384 | 2726 GFLOPS (32x32x8) | |
// - problemSize = 512 | 4884 GFLOPS (32x32x8) | |
// - problemSize = 640 | 5860 GFLOPS (32x32x8) | |
// - problemSize = 768 | 6253 GFLOPS (32x32x8) | |
// - problemSize = 896 | 6823 GFLOPS (48x48x8) | |
// - problemSize = 1024 | 6613 GFLOPS (48x48x8) | |
// - problemSize = 1152 | 7715 GFLOPS (48x48x8) | |
// - problemSize = 1280 | 7458 GFLOPS (48x48x8) | |
// - problemSize = 1408 | 7287 GFLOPS (48x48x8) | |
// - problemSize = 1536 | 8263 GFLOPS (48x48x8) | |
// | |
// M4 | |
// - problemSize = 256 | 1274 GFLOPS (32x32x8) | |
// - problemSize = 384 | 1885 GFLOPS (32x32x8) | |
// - problemSize = 512 | 3004 GFLOPS (32x32x8) | |
// - problemSize = 640 | 3252 GFLOPS (32x32x8) | |
// - problemSize = 768 | 3290 GFLOPS (32x32x8) | |
// - problemSize = 896 | 3314 GFLOPS (32x32x8) | |
// - problemSize = 1024 | 3338 GFLOPS (32x32x8) | |
// - problemSize = 1152 | 3387 GFLOPS (32x32x8) | |
// - problemSize = 1280 | 3414 GFLOPS (32x32x8) | |
// - problemSize = 1408 | 3438 GFLOPS (32x32x8) | |
// - problemSize = 1536 | 3452 GFLOPS (32x32x8) | |
// | |
// A and B are decoded from BFloat16 in RAM to Float32 in registers. | |
// C is encoded from Float32 in registers to BFloat16 in RAM. | |
// | |
// M1 Max | |
// - problemSize = 256 | 862 GFLOPS (32x32x8) | |
// - problemSize = 384 | 2807 GFLOPS (32x32x8) | |
// - problemSize = 512 | 4082 GFLOPS (32x32x8) | |
// - problemSize = 640 | 6434 GFLOPS (32x32x8) | |
// - problemSize = 768 | 6206 GFLOPS (32x32x8) | |
// - problemSize = 896 | 7056 GFLOPS (48x48x8) | |
// - problemSize = 1024 | 7487 GFLOPS (48x48x8) | |
// - problemSize = 1152 | 7769 GFLOPS (48x48x8) | |
// - problemSize = 1280 | 7686 GFLOPS (48x48x8) | |
// - problemSize = 1408 | 7760 GFLOPS (48x48x8) | |
// - problemSize = 1536 | 8684 GFLOPS (48x48x8) | |
// | |
// M4 | |
// - problemSize = 256 | 1240 GFLOPS (32x32x8) | |
// - problemSize = 384 | 1878 GFLOPS (32x32x8) | |
// - problemSize = 512 | 2695 GFLOPS (32x32x8) | |
// - problemSize = 640 | 2910 GFLOPS (32x32x8) | |
// - problemSize = 768 | 2939 GFLOPS (32x32x8) | |
// - problemSize = 896 | 2939 GFLOPS (32x32x8) | |
// - problemSize = 1024 | 2984 GFLOPS (32x32x8) | |
// - problemSize = 1152 | 3006 GFLOPS (32x32x8) | |
// - problemSize = 1280 | 3022 GFLOPS (32x32x8) | |
// - problemSize = 1408 | 3047 GFLOPS (32x32x8) | |
// - problemSize = 1536 | 3060 GFLOPS (32x32x8) | |
// | |
// ========================================================================== // | |
// Results (Summary) | |
// - Does the optimal block size change when the precision decreases from | |
// Float32 to Float16? | |
// - How well does the MSL built-in type for BFloat16 perform, compared to | |
// Float16? | |
// ========================================================================== // | |
// | |
// For practical arithmetic intensities (N < 1500), the optimal block size | |
// is always 48x48 on M1, 32x32 on M4. This rule holds regardless of the | |
// precision, or whether async copies are being used. | |
// | |
// The statistics were augmented by a few more combinations of input/output | |
// precision. For brevity, the previous section does not include the raw data | |
// for these combinations. | |
// | |
// Accumulate in FP16. | |
// | |
// M1 Max | |
// - A = FP16, B = FP16, C = FP16 | 9310 GFLOPS | |
// - A = BF16, B = BF16, C = FP16 | 7596 GFLOPS | |
// - A = FP32, B = FP32, C = FP16 | 7869 GFLOPS | |
// | |
// M4 | |
// - A = FP16, B = FP16, C = FP16 | 3532 GFLOPS | |
// - A = BF16, B = BF16, C = FP16 | 3088 GFLOPS | |
// - A = FP32, B = FP32, C = FP16 | 2914 GFLOPS | |
// | |
// Accumulate in BF16. | |
// | |
// M1 Max | |
// - A = BF16, B = BF16, C = BF16 | 4258 GFLOPS (32x32x8) | |
// - A = BF16, B = BF16, C = eBF16 | 8263 GFLOPS | |
// - A = eBF16, B = eBF16, C = eBF16 | 8684 GFLOPS | |
// | |
// M4 | |
// - A = BF16, B = BF16, C = BF16 | 3086 GFLOPS | |
// - A = BF16, B = BF16, C = eBF16 | 3452 GFLOPS | |
// - A = eBF16, B = eBF16, C = eBF16 | 3060 GFLOPS | |
// | |
// Accumulate in FP32. | |
// | |
// M1 Max | |
// - A = FP16, B = FP16, C = FP32 | 8952 GFLOPS | |
// - A = BF16, B = BF16, C = FP32 | 8247 GFLOPS | |
// - A = eBF16, B = eBF16, C = FP32 | 8675 GFLOPS | |
// - A = FP32, B = FP32, C = FP32 | 8392 GFLOPS | |
// | |
// M4 | |
// - A = FP16, B = FP16, C = FP32 | 3477 GFLOPS | |
// - A = BF16, B = BF16, C = FP32 | 3477 GFLOPS | |
// - A = eBF16, B = eBF16, C = FP32 | 3080 GFLOPS | |
// - A = FP32, B = FP32, C = FP32 | 3174 GFLOPS | |
// | |
// ========================================================================== // | |
// Results (Raw Data) | |
// - How fast are the MPSMatrix and MPSGraph implementations of low-precision | |
// matrix multiplication? | |
// - Can MPS achieve double the performance by running two different | |
// matrix multiplications simultaneously, each with a different precision? | |
// ========================================================================== // | |
// | |
// For brevity, this section only shows kernels where all 3 operands have the | |
// same precision. | |
// | |
// MPSMatrix, Float32. | |
// | |
// M1 Max | |
// - problemSize = 256 | 484 GFLOPS | |
// - problemSize = 384 | 1587 GFLOPS | |
// - problemSize = 512 | 3475 GFLOPS | |
// - problemSize = 640 | 5830 GFLOPS | |
// - problemSize = 768 | 5936 GFLOPS | |
// - problemSize = 896 | 6738 GFLOPS | |
// - problemSize = 1024 | 8050 GFLOPS | |
// - problemSize = 1152 | 7345 GFLOPS | |
// - problemSize = 1280 | 7444 GFLOPS | |
// - problemSize = 1408 | 7758 GFLOPS | |
// - problemSize = 1536 | 8055 GFLOPS | |
// | |
// M4 | |
// - problemSize = 256 | 460 GFLOPS | |
// - problemSize = 384 | 1036 GFLOPS | |
// - problemSize = 512 | 1569 GFLOPS | |
// - problemSize = 640 | 3028 GFLOPS | |
// - problemSize = 768 | 3087 GFLOPS | |
// - problemSize = 896 | 3086 GFLOPS | |
// - problemSize = 1024 | 3125 GFLOPS | |
// - problemSize = 1152 | 3152 GFLOPS | |
// - problemSize = 1280 | 3134 GFLOPS | |
// - problemSize = 1408 | 3150 GFLOPS | |
// - problemSize = 1536 | 3129 GFLOPS | |
// | |
// MPSMatrix, Float16. | |
// | |
// M1 Max | |
// - problemSize = 256 | 509 GFLOPS | |
// - problemSize = 384 | 2285 GFLOPS | |
// - problemSize = 512 | 3260 GFLOPS | |
// - problemSize = 640 | 5398 GFLOPS | |
// - problemSize = 768 | 5647 GFLOPS | |
// - problemSize = 896 | 6231 GFLOPS | |
// - problemSize = 1024 | 7364 GFLOPS | |
// - problemSize = 1152 | 6774 GFLOPS | |
// - problemSize = 1280 | 6983 GFLOPS | |
// - problemSize = 1408 | 7122 GFLOPS | |
// - problemSize = 1536 | 7487 GFLOPS | |
// | |
// M4 | |
// - problemSize = 256 | 601 GFLOPS | |
// - problemSize = 384 | 1175 GFLOPS | |
// - problemSize = 512 | 2030 GFLOPS | |
// - problemSize = 640 | 3349 GFLOPS | |
// - problemSize = 768 | 3400 GFLOPS | |
// - problemSize = 896 | 3444 GFLOPS | |
// - problemSize = 1024 | 3479 GFLOPS | |
// - problemSize = 1152 | 3489 GFLOPS | |
// - problemSize = 1280 | 3504 GFLOPS | |
// - problemSize = 1408 | 3503 GFLOPS | |
// - problemSize = 1536 | 3496 GFLOPS | |
// | |
// MPSMatrix does not support BFloat16. | |
// | |
// MPSGraph, Float32. | |
// | |
// M1 Max | |
// - problemSize = 256 | 439 GFLOPS | |
// - problemSize = 384 | 878 GFLOPS | |
// - problemSize = 512 | 2727 GFLOPS | |
// - problemSize = 640 | 4549 GFLOPS | |
// - problemSize = 768 | 5521 GFLOPS | |
// - problemSize = 896 | 6316 GFLOPS | |
// - problemSize = 1024 | 7000 GFLOPS | |
// - problemSize = 1152 | 7320 GFLOPS | |
// - problemSize = 1280 | 7596 GFLOPS | |
// - problemSize = 1408 | 7597 GFLOPS | |
// - problemSize = 1536 | 8274 GFLOPS | |
// | |
// M4 | |
// - problemSize = 256 | 590 GFLOPS | |
// - problemSize = 384 | 1105 GFLOPS | |
// - problemSize = 512 | 2051 GFLOPS | |
// - problemSize = 640 | 2798 GFLOPS | |
// - problemSize = 768 | 2937 GFLOPS | |
// - problemSize = 896 | 2986 GFLOPS | |
// - problemSize = 1024 | 3048 GFLOPS | |
// - problemSize = 1152 | 3100 GFLOPS | |
// - problemSize = 1280 | 3083 GFLOPS | |
// - problemSize = 1408 | 3108 GFLOPS | |
// - problemSize = 1536 | 3100 GFLOPS | |
// | |
// MPSGraph, Float16. | |
// | |
// M1 Max | |
// - problemSize = 256 | 399 GFLOPS | |
// - problemSize = 384 | 1254 GFLOPS | |
// - problemSize = 512 | 2951 GFLOPS | |
// - problemSize = 640 | 4630 GFLOPS | |
// - problemSize = 768 | 5858 GFLOPS | |
// - problemSize = 896 | 6606 GFLOPS | |
// - problemSize = 1024 | 7051 GFLOPS | |
// - problemSize = 1152 | 7542 GFLOPS | |
// - problemSize = 1280 | 7942 GFLOPS | |
// - problemSize = 1408 | 8099 GFLOPS | |
// - problemSize = 1536 | 8418 GFLOPS | |
// | |
// M4, Level 0 | |
// - problemSize = 256 | 710 GFLOPS | |
// - problemSize = 384 | 1462 GFLOPS | |
// - problemSize = 512 | 2337 GFLOPS | |
// - problemSize = 640 | 2578 GFLOPS | |
// - problemSize = 768 | 2708 GFLOPS | |
// - problemSize = 896 | 2736 GFLOPS | |
// - problemSize = 1024 | 2820 GFLOPS | |
// - problemSize = 1152 | 2897 GFLOPS | |
// - problemSize = 1280 | 2954 GFLOPS | |
// - problemSize = 1408 | 3006 GFLOPS | |
// - problemSize = 1536 | 3041 GFLOPS | |
// | |
// M4, Level 1 | |
// - problemSize = 256 | 735 GFLOPS | |
// - problemSize = 384 | 1488 GFLOPS | |
// - problemSize = 512 | 2332 GFLOPS | |
// - problemSize = 640 | 2586 GFLOPS | |
// - problemSize = 768 | 2753 GFLOPS | |
// - problemSize = 896 | 2734 GFLOPS | |
// - problemSize = 1024 | 2826 GFLOPS | |
// - problemSize = 1152 | 3843 GFLOPS | |
// - problemSize = 1280 | 4404 GFLOPS | |
// - problemSize = 1408 | 4786 GFLOPS | |
// - problemSize = 1536 | 5204 GFLOPS | |
// - problemSize = 2048 | 5916 GFLOPS | |
// - problemSize = 3072 | 5921 GFLOPS | |
// - problemSize = 3584 | 6573 GFLOPS | |
// - problemSize = 4096 | 6631 GFLOPS | |
// - problemSize = 4608 | 3883 GFLOPS | |
// - problemSize = 5120 | 3755 GFLOPS | |
// | |
// MPSGraph, BFloat16. | |
// | |
// M1 Max | |
// - problemSize = 256 | 325 GFLOPS | |
// - problemSize = 384 | 1132 GFLOPS | |
// - problemSize = 512 | 1987 GFLOPS | |
// - problemSize = 640 | 2352 GFLOPS | |
// - problemSize = 768 | 3311 GFLOPS | |
// - problemSize = 896 | 2996 GFLOPS | |
// - problemSize = 1024 | 3325 GFLOPS | |
// - problemSize = 1152 | 3691 GFLOPS | |
// - problemSize = 1280 | 3908 GFLOPS | |
// - problemSize = 1408 | 3937 GFLOPS | |
// - problemSize = 1536 | 4047 GFLOPS | |
// | |
// M4 | |
// - problemSize = 256 | 774 GFLOPS | |
// - problemSize = 384 | 1419 GFLOPS | |
// - problemSize = 512 | 2302 GFLOPS | |
// - problemSize = 640 | 2588 GFLOPS | |
// - problemSize = 768 | 2726 GFLOPS | |
// - problemSize = 896 | 2755 GFLOPS | |
// - problemSize = 1024 | 2836 GFLOPS | |
// - problemSize = 1152 | 2903 GFLOPS | |
// - problemSize = 1280 | 2950 GFLOPS | |
// - problemSize = 1408 | 2999 GFLOPS | |
// - problemSize = 1536 | 3041 GFLOPS | |
// | |
// MPSGraph, simultaneous Float16 and Float32. | |
// | |
// M1 Max | |
// - problemSize = 256 | 630 GFLOPS | |
// - problemSize = 384 | 1971 GFLOPS | |
// - problemSize = 512 | 4434 GFLOPS | |
// - problemSize = 640 | 5283 GFLOPS | |
// - problemSize = 768 | 6228 GFLOPS | |
// - problemSize = 896 | 6859 GFLOPS | |
// - problemSize = 1024 | 7330 GFLOPS | |
// - problemSize = 1152 | 7611 GFLOPS | |
// - problemSize = 1280 | 7936 GFLOPS | |
// - problemSize = 1408 | 7890 GFLOPS | |
// - problemSize = 1536 | 8331 GFLOPS | |
// | |
// M4, Level 0 | |
// - problemSize = 256 | 911 GFLOPS | |
// - problemSize = 384 | 1820 GFLOPS | |
// - problemSize = 512 | 2501 GFLOPS | |
// - problemSize = 640 | 2629 GFLOPS | |
// - problemSize = 768 | 2721 GFLOPS | |
// - problemSize = 896 | 2856 GFLOPS | |
// - problemSize = 1024 | 2940 GFLOPS | |
// - problemSize = 1152 | 2994 GFLOPS | |
// - problemSize = 1280 | 3031 GFLOPS | |
// - problemSize = 1408 | 3065 GFLOPS | |
// - problemSize = 1536 | 3074 GFLOPS | |
// | |
// M4, Level 1 | |
// - problemSize = 256 | 952 GFLOPS | |
// - problemSize = 384 | 1935 GFLOPS | |
// - problemSize = 512 | 2515 GFLOPS | |
// - problemSize = 640 | 2636 GFLOPS | |
// - problemSize = 768 | 2749 GFLOPS | |
// - problemSize = 896 | 2859 GFLOPS | |
// - problemSize = 1024 | 2956 GFLOPS | |
// - problemSize = 1152 | 3461 GFLOPS | |
// - problemSize = 1280 | 3584 GFLOPS | |
// - problemSize = 1408 | 3759 GFLOPS | |
// - problemSize = 1536 | 3834 GFLOPS | |
// - problemSize = 2048 | 4100 GFLOPS | |
// - problemSize = 3072 | 4114 GFLOPS | |
// - problemSize = 3584 | 4238 GFLOPS | |
// - problemSize = 4096 | 4261 GFLOPS | |
// - problemSize = 4608 | 3374 GFLOPS | |
// - problemSize = 5120 | 3276 GFLOPS | |
// | |
// MPSGraph, simultaneous Float16 and BFloat16. | |
// | |
// M1 Max | |
// - problemSize = 256 | 536 GFLOPS | |
// - problemSize = 384 | 1755 GFLOPS | |
// - problemSize = 512 | 2812 GFLOPS | |
// - problemSize = 640 | 3322 GFLOPS | |
// - problemSize = 768 | 4420 GFLOPS | |
// - problemSize = 896 | 4215 GFLOPS | |
// - problemSize = 1024 | 4609 GFLOPS | |
// - problemSize = 1152 | 5028 GFLOPS | |
// - problemSize = 1280 | 5294 GFLOPS | |
// - problemSize = 1408 | 5338 GFLOPS | |
// - problemSize = 1536 | 5472 GFLOPS | |
// | |
// M4, Level 0 | |
// - problemSize = 256 | 933 GFLOPS | |
// - problemSize = 384 | 1899 GFLOPS | |
// - problemSize = 512 | 2429 GFLOPS | |
// - problemSize = 640 | 2540 GFLOPS | |
// - problemSize = 768 | 2639 GFLOPS | |
// - problemSize = 896 | 2734 GFLOPS | |
// - problemSize = 1024 | 2834 GFLOPS | |
// - problemSize = 1152 | 2909 GFLOPS | |
// - problemSize = 1280 | 2963 GFLOPS | |
// - problemSize = 1408 | 3013 GFLOPS | |
// - problemSize = 1536 | 3053 GFLOPS | |
// | |
// M4, Level 1 | |
// - problemSize = 256 | 949 GFLOPS | |
// - problemSize = 384 | 1976 GFLOPS | |
// - problemSize = 512 | 2453 GFLOPS | |
// - problemSize = 640 | 2552 GFLOPS | |
// - problemSize = 768 | 2629 GFLOPS | |
// - problemSize = 896 | 2736 GFLOPS | |
// - problemSize = 1024 | 2853 GFLOPS | |
// - problemSize = 1152 | 3288 GFLOPS | |
// - problemSize = 1280 | 3484 GFLOPS | |
// - problemSize = 1408 | 3658 GFLOPS | |
// - problemSize = 1536 | 3647 GFLOPS | |
// - problemSize = 2048 | 4126 GFLOPS | |
// - problemSize = 3072 | 4252 GFLOPS | |
// - problemSize = 3584 | 4429 GFLOPS | |
// - problemSize = 4096 | 4476 GFLOPS | |
// - problemSize = 4608 | 3624 GFLOPS | |
// - problemSize = 5120 | 3568 GFLOPS | |
// | |
// MPSGraph, simultaneous BFloat16 and Float32. | |
// | |
// M1 Max | |
// - problemSize = 256 | 547 GFLOPS | |
// - problemSize = 384 | 1764 GFLOPS | |
// - problemSize = 512 | 2966 GFLOPS | |
// - problemSize = 640 | 3346 GFLOPS | |
// - problemSize = 768 | 4448 GFLOPS | |
// - problemSize = 896 | 4204 GFLOPS | |
// - problemSize = 1024 | 4613 GFLOPS | |
// - problemSize = 1152 | 4999 GFLOPS | |
// - problemSize = 1280 | 5222 GFLOPS | |
// - problemSize = 1408 | 5223 GFLOPS | |
// - problemSize = 1536 | 5438 GFLOPS | |
// | |
// M4, Level 0 | |
// - problemSize = 256 | 936 GFLOPS | |
// - problemSize = 384 | 1886 GFLOPS | |
// - problemSize = 512 | 2499 GFLOPS | |
// - problemSize = 640 | 2632 GFLOPS | |
// - problemSize = 768 | 2749 GFLOPS | |
// - problemSize = 896 | 2852 GFLOPS | |
// - problemSize = 1024 | 2942 GFLOPS | |
// - problemSize = 1152 | 2999 GFLOPS | |
// - problemSize = 1280 | 3018 GFLOPS | |
// - problemSize = 1408 | 3062 GFLOPS | |
// - problemSize = 1536 | 3076 GFLOPS | |
// | |
// M4, Level 1 | |
// - problemSize = 256 | 1017 GFLOPS | |
// - problemSize = 384 | 1905 GFLOPS | |
// - problemSize = 512 | 2534 GFLOPS | |
// - problemSize = 640 | 2642 GFLOPS | |
// - problemSize = 768 | 2742 GFLOPS | |
// - problemSize = 896 | 2861 GFLOPS | |
// - problemSize = 1024 | 2955 GFLOPS | |
// - problemSize = 1152 | 3003 GFLOPS | |
// - problemSize = 1280 | 3021 GFLOPS | |
// - problemSize = 1408 | 3059 GFLOPS | |
// - problemSize = 1536 | 3075 GFLOPS | |
// | |
// ========================================================================== // | |
// Results (Raw Data) | |
// - Can a kernel be engineered to exploit the FP16 and FP32 units | |
// simultaneously? | |
// ========================================================================== // | |
// | |
// I tried running a kernel with two execution paths: one was a standard FP32 | |
// GEMM, the other a standard FP16 GEMM. This did not provide any speedup on | |
// M4. I tried again with various mixtures of precisions, including BF16. Still | |
// no speedup. I am starting to doubt Apple's claim about 2x performance for | |
// workloads with both FP16 and FP32. | |
// | |
// As a final attempt, I wrote a kernel where the accumulator is half-FP16, | |
// half-FP32. All operands are FP16 in memory. The FP32 part of the accumulator | |
// decompresses some FP16 values in memory to FP32 before multiplication. | |
// | |
// M1 Max | |
// - problemSize = 256 | 816 GFLOPS | |
// - problemSize = 384 | 2798 GFLOPS | |
// - problemSize = 512 | 5289 GFLOPS | |
// - problemSize = 640 | 6374 GFLOPS | |
// - problemSize = 768 | 6201 GFLOPS | |
// - problemSize = 896 | 6494 GFLOPS | |
// - problemSize = 1024 | 7790 GFLOPS | |
// - problemSize = 1152 | 5764 GFLOPS | |
// - problemSize = 1280 | 6813 GFLOPS | |
// - problemSize = 1408 | 7409 GFLOPS | |
// - problemSize = 1536 | 7225 GFLOPS | |
// | |
// M4 | |
// - problemSize = 256 | 1224 GFLOPS | |
// - problemSize = 384 | 1766 GFLOPS | |
// - problemSize = 512 | 2617 GFLOPS | |
// - problemSize = 640 | 3057 GFLOPS | |
// - problemSize = 768 | 3093 GFLOPS | |
// - problemSize = 896 | 3135 GFLOPS | |
// - problemSize = 1024 | 3141 GFLOPS | |
// - problemSize = 1152 | 3188 GFLOPS | |
// - problemSize = 1280 | 3203 GFLOPS | |
// - problemSize = 1408 | 3229 GFLOPS | |
// - problemSize = 1536 | 3243 GFLOPS | |
// | |
// This still does not achieve the 2x performance boost I had theorized. One | |
// very last stretch: make the LHS be stored as both FP16 and FP32 in memory. | |
// When data needs to be accumulated in FP32, it will read values as FP32 with | |
// no decompression. That should ensure the compiler doesn't convert desired | |
// FP32 instructions into undesired FP16 instructions. | |
// | |
// M1 Max | |
// - problemSize = 256 | 863 GFLOPS | |
// - problemSize = 384 | 2812 GFLOPS | |
// - problemSize = 512 | 5146 GFLOPS | |
// - problemSize = 640 | 6263 GFLOPS | |
// - problemSize = 768 | 6879 GFLOPS | |
// - problemSize = 896 | 6488 GFLOPS | |
// - problemSize = 1024 | 7870 GFLOPS | |
// - problemSize = 1152 | 5959 GFLOPS | |
// - problemSize = 1280 | 6883 GFLOPS | |
// - problemSize = 1408 | 7155 GFLOPS | |
// - problemSize = 1536 | 7195 GFLOPS | |
// | |
// M4 | |
// - problemSize = 256 | 792 GFLOPS | |
// - problemSize = 384 | 1319 GFLOPS | |
// - problemSize = 512 | 2440 GFLOPS | |
// - problemSize = 640 | 3019 GFLOPS | |
// - problemSize = 768 | 3054 GFLOPS | |
// - problemSize = 896 | 3060 GFLOPS | |
// - problemSize = 1024 | 3096 GFLOPS | |
// - problemSize = 1152 | 3148 GFLOPS | |
// - problemSize = 1280 | 3150 GFLOPS | |
// - problemSize = 1408 | 3184 GFLOPS | |
// - problemSize = 1536 | 3189 GFLOPS | |
// | |
// ========================================================================== // | |
// Discussion | |
// ========================================================================== // | |
// | |
// In both architectures, FP16 achieves better performance than FP32/BF16 no | |
// matter the calculation. FP32 can get very close, provided the operands are | |
// streamed from L1 as 16-bit types (to minimize the bandwidth bottleneck). | |
// Lower-precision data types do not improve performance much by increasing | |
// occupancy. Otherwise, we would see 16-bit types being optimal at larger | |
// block sizes than 32-bit types. | |
// | |
// MPS is also under-optimized for 16-bit data types on M3/M4. MPSGraph does | |
// not exceed 3000 GFLOPS on the M4 for either FP16 or BF16. It does not exceed | |
// 3100 GFLOPS for FP32. MPSMatrix does have a performance improvement for | |
// half, but it crashes on bfloat. | |
// | |
// GFLOPS | FP16 | BF16 | FP32 | | |
// ------------- | ----- | ----- | ----- | | |
// MPSGraph | 3000 | 3000 | 3100 | | |
// MPSMatrix | 3500 | error | 3100 | | |
// Custom Kernel | 3500 | 3500* | 3100 | | |
// | |
// *Only achieved when the accumulator is FP32. Otherwise, BF16 performance is | |
// 3000 GFLOPS. | |
// | |
// ===== The attempt to achieve 7000 GFLOPS on M4 ===== | |
// | |
// This did not work. The M3/M4 architecture is basically the same as M1/M2, in | |
// terms of ALU throughput. The BF16 support is mediocre, only implemented as | |
// an optimization to BF16 -> FP32 decoding. Furthermore, BF16 only reaches | |
// maximum performance when the accumulator is FP32. I can get similar | |
// performance by just emulating BF16 (within a factor of 1.1x). | |
// | |
// I did achieve 6500 GFLOPS at MPSGraph optimization level 1. Only when all of | |
// the data was Float16. That is because of the well-known delegation of | |
// galactic FP16 matrix multiplies to the neural engine. Not because it is | |
// getting ~7000 GFLOPS out of the GPU. | |
// | |
// ===== MPS BFloat16 performance is dismal on M1 ===== | |
// | |
// GFLOPS | FP16 | BF16 | FP32 | | |
// ------------- | ----- | ----- | ----- | | |
// MPSGraph | 8100 | 4000 | 8300 | | |
// MPSMatrix | 7500 | error | 8100 | | |
// Custom Kernel | 9300 | 8700* | 8400 | | |
// | |
// *Only achieved when the A/B inputs are decoded into Float32 registers | |
// before multiplying. This is technically BF15, as rounding error makes the | |
// final bit end up as garbage. | |
/// Entry point for the benchmark harness.
///
/// Runs a quick smoke test at a small problem size. The full benchmark
/// sweeps below are compiled out with `#if false`; flip a condition to
/// `true` to re-enable the corresponding sweep.
func runApplication() {
  print("Hello, console.")
  profileProblemSize(64)

  #if false
  // Standard sweep: N = 256 through 1536 in steps of 128, matching the
  // problem sizes tabulated in the results sections above.
  for problemSize in stride(from: 256, through: 1536, by: 128) {
    profileProblemSize(problemSize)
  }
  #endif

  #if false
  // Large-problem sweep used for the MPSGraph optimization-level runs.
  for problemSize in [2048, 3072, 3072 + 512, 4096, 4096 + 512, 5120] {
    profileProblemSize(problemSize)
  }
  #endif
}
// MARK: - Shader Sources | |
// Metal source for `metal::simdgroup_event`: wraps the undocumented AIR
// async-copy intrinsics (1D and 2D, threadgroup <-> device) behind a small
// event type. This string is spliced into kernel sources and compiled at
// runtime via `makeLibrary(source:options:)`; its contents must stay
// byte-identical to the shader the kernels were written against.
let metalSimdgroupEvent: String = """
// -*- Metal -*-
//===-- metal_simdgroup_event ---------------------------------------------===//
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
//===----------------------------------------------------------------------===//
#ifndef __METAL_SIMDGROUP_EVENT
#define __METAL_SIMDGROUP_EVENT
#if !defined(__HAVE_SIMDGROUP_FUTURE__)
// Invoking the generation of LLVM bitcode for async copies.
//
// %struct._simdgroup_event_t = type opaque
//
struct _simdgroup_event_t;
// Invoking the generation of LLVM bitcode for async copies.
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_1d(
ulong, ulong, threadgroup void *, const device void *, ulong)
__asm("air.simdgroup_async_copy_1d.p3i8.p1i8");
// Invoking the generation of LLVM bitcode for async copies.
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_1d(
ulong, ulong, device void *, const threadgroup void *, ulong)
__asm("air.simdgroup_async_copy_1d.p1i8.p3i8");
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: argmemonly convergent nounwind
// declare %struct._simdgroup_event_t*
// @air.simdgroup_async_copy_2d.p3i8.p1i8(
// i64, i64, i8 addrspace(3)* nocapture writeonly,
// i64, i64, <2 x i64>, i8 addrspace(1)* nocapture readonly,
// i64, i64, <2 x i64>, <2 x i64>, i32)
// local_unnamed_addr #4
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_2d(
ulong, ulong, threadgroup void *,
ulong, ulong, ulong2, const device void *,
ulong, ulong, ulong2, long2, int)
__asm("air.simdgroup_async_copy_2d.p3i8.p1i8");
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: argmemonly convergent nounwind
// declare %struct._simdgroup_event_t*
// @air.simdgroup_async_copy_2d.p1i8.p3i8(
// i64, i64, i8 addrspace(1)* nocapture writeonly,
// i64, i64, <2 x i64>, i8 addrspace(3)* nocapture readonly,
// i64, i64, <2 x i64>, <2 x i64>, i32)
// local_unnamed_addr #4
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_2d(
ulong, ulong, device void *,
ulong, ulong, ulong2, const threadgroup void *,
ulong, ulong, ulong2, long2, int)
__asm("air.simdgroup_async_copy_2d.p1i8.p3i8");
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: convergent nounwind
// declare void
// @air.wait_simdgroup_events(i32, %struct._simdgroup_event_t** nocapture)
// local_unnamed_addr #3
//
void __metal_wait_simdgroup_events(
int, thread _simdgroup_event_t**)
__asm("air.wait_simdgroup_events");
#endif
#pragma METAL internals : enable
namespace metal
{
#if !defined(__HAVE_SIMDGROUP_FUTURE__)
enum class simdgroup_async_copy_clamp_mode {
clamp_to_zero = 0,
clamp_to_edge = 1
};
#endif
struct simdgroup_event {
METAL_FUNC simdgroup_event() thread {}
template <typename T>
METAL_FUNC void async_copy(threadgroup T *dst, const device T *src, ulong n_elements) thread {
event = __metal_simdgroup_async_copy_1d(sizeof(T), alignof(T), reinterpret_cast<threadgroup void *>(dst), reinterpret_cast<const device void *>(src), n_elements);
}
template <typename T>
METAL_FUNC void async_copy(device T *dst, const threadgroup T *src, ulong n_elements) thread {
event = __metal_simdgroup_async_copy_1d(sizeof(T), alignof(T), reinterpret_cast<device void *>(dst), reinterpret_cast<const threadgroup void *>(src), n_elements);
}
template <typename T>
METAL_FUNC void async_copy(threadgroup T *dst, ushort dst_elements_per_row, ushort2 dst_tile_dimensions, const device T *src, uint src_elements_per_row, ushort2 src_tile_dimensions, bool transpose_matrix = false, simdgroup_async_copy_clamp_mode clamp_mode = simdgroup_async_copy_clamp_mode::clamp_to_zero) thread {
if (transpose_matrix) {
src_tile_dimensions = src_tile_dimensions.yx;
dst_tile_dimensions = dst_tile_dimensions.yx;
}
event = __metal_simdgroup_async_copy_2d(sizeof(T), alignof(T), reinterpret_cast<threadgroup void *>(dst), ushort(dst_elements_per_row), 1, ulong2(dst_tile_dimensions), reinterpret_cast<const device void *>(src), uint(src_elements_per_row), 1, ulong2(src_tile_dimensions), long2(0), static_cast<int>(clamp_mode));
}
template <typename T>
METAL_FUNC void async_copy(device T *dst, uint dst_elements_per_row, ushort2 dst_tile_dimensions, const threadgroup T *src, ushort src_elements_per_row, ushort2 src_tile_dimensions, bool transpose_matrix = false) thread {
if (transpose_matrix) {
src_tile_dimensions = src_tile_dimensions.yx;
dst_tile_dimensions = dst_tile_dimensions.yx;
}
event = __metal_simdgroup_async_copy_2d(sizeof(T), alignof(T), reinterpret_cast<device void *>(dst), uint(dst_elements_per_row), 1, ulong2(dst_tile_dimensions), reinterpret_cast<const threadgroup void *>(src), ushort(src_elements_per_row), 1, ulong2(src_tile_dimensions), long2(0), 0);
}
METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
#if defined(__HAVE_SIMDGROUP_FUTURE__)
__metal_wait_simdgroup_events(count, reinterpret_cast<thread __metal_simdgroup_event_t*>(events));
#else
__metal_wait_simdgroup_events(count, reinterpret_cast<thread _simdgroup_event_t**>(events));
#endif
}
private:
// Invoking the generation of LLVM bitcode for async copies.
//
// %"struct.metal::simdgroup_event" = type { %struct._simdgroup_event_t* }
//
#if defined(__HAVE_SIMDGROUP_FUTURE__)
__metal_simdgroup_event_t event;
#else
thread _simdgroup_event_t* event;
#endif
};
} // namespace metal
#pragma METAL internals : disable
#endif // __METAL_SIMDGROUP_EVENT
"""
// Metal source for `metal::simdgroup_matrix_storage<T>`: a register-resident
// 8x8 matrix tile with explicit load/store (including BF16/FP16 conversion
// helpers) and a mixed-precision `multiply`. Interpolated into the GEMM
// shader string below and compiled at runtime; its contents must stay
// byte-identical to the shader the kernels were written against.
let metalSimdgroupMatrixStorage: String = """
// -*- Metal -*-
//===-- metal_simdgroup_matrix_storage ------------------------------------===//
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
//===----------------------------------------------------------------------===//
#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE
#define __METAL_SIMDGROUP_MATRIX_STORAGE
// Contains C++ symbols accessible to a developer through automatic code
// completion in Xcode 14.2. Formatted with the same style as the Metal Standard
// Library for consistency with other Metal code.
#if defined(__HAVE_SIMDGROUP_MATRIX__)
#pragma METAL internals : enable
namespace metal
{
template <typename T>
struct simdgroup_matrix_storage {
typedef vec<T, 64> storage_type;
storage_type t;
METAL_FUNC thread vec<T, 2>* thread_elements() thread {
return reinterpret_cast<thread vec<T, 2>*>(&t);
}
METAL_FUNC simdgroup_matrix_storage() thread = default;
METAL_FUNC simdgroup_matrix_storage(vec<T, 2> thread_elements) thread {
*(this->thread_elements()) = thread_elements;
}
METAL_FUNC static ushort2 offset(ushort thread_index_in_simdgroup) {
// https://patents.google.com/patent/US11256518B2
ushort lane_id = thread_index_in_simdgroup;
ushort quad_id = lane_id / 4;
constexpr ushort QUADRANT_SPAN_M = 4;
constexpr ushort THREADS_PER_QUADRANT = 8;
ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M;
ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2);
ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant;
ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4
ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2
ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant;
return ushort2(N_in_simd, M_in_simd);
}
METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
} else {
return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
}
}
METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
return src + matrix_origin.x * elements_per_row + matrix_origin.y;
} else {
return src + matrix_origin.y * elements_per_row + matrix_origin.x;
}
}
// WARNING: All load and store functions assume the X dimension is divisible by 2.
METAL_FUNC void load(const device T *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
*(thread_elements()) = vec<T, 2>(src[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y], src[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y]);
} else {
*(thread_elements()) = *reinterpret_cast<const device vec<T, 2>*>(src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x);
}
}
METAL_FUNC void load(const threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
*(thread_elements()) = vec<T, 2>(src[matrix_origin.x * elements_per_row + matrix_origin.y], src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y]);
} else {
*(thread_elements()) = *reinterpret_cast<const threadgroup vec<T, 2>*>(src + matrix_origin.y * elements_per_row + matrix_origin.x);
}
}
METAL_FUNC void load_first(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
thread_elements()[0][0] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
} else {
thread_elements()[0][0] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
}
}
METAL_FUNC void load_second(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
thread_elements()[0][1] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
} else {
thread_elements()[0][1] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
}
}
METAL_FUNC void load_bfloat(const device bfloat *src, uint elements_per_row, ushort2 matrix_origin) {
auto src_ptr = src + (matrix_origin.y * elements_per_row + matrix_origin.x);
reinterpret_cast<thread ushort4*>(thread_elements())->yw =
*reinterpret_cast<const device ushort2*>(src_ptr);
}
METAL_FUNC void load_half(const device half *src, uint elements_per_row, ushort2 matrix_origin) {
auto src_ptr = src + (matrix_origin.y * elements_per_row + matrix_origin.x);
*thread_elements() = vec<T, 2>(
*reinterpret_cast<const device half2*>(src_ptr));
}
METAL_FUNC void store(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
dst[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
} else {
*reinterpret_cast<device vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
}
}
METAL_FUNC void store_first(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
} else {
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][0];
}
}
METAL_FUNC void store_second(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
} else {
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][1];
}
}
METAL_FUNC void store_bfloat(device bfloat *dst, uint elements_per_row, ushort2 matrix_origin) {
auto dst_ptr = dst + (matrix_origin.y * elements_per_row + matrix_origin.x);
*reinterpret_cast<device ushort2*>(dst_ptr) =
reinterpret_cast<thread ushort4*>(thread_elements())->yw;
}
METAL_FUNC void store_half(device half *dst, uint elements_per_row, ushort2 matrix_origin) {
auto dst_ptr = dst + (matrix_origin.y * elements_per_row + matrix_origin.x);
*reinterpret_cast<device half2*>(dst_ptr) = half2(*thread_elements());
}
METAL_FUNC void store(threadgroup T *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[matrix_origin.x * elements_per_row + matrix_origin.y] = thread_elements()[0][0];
dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = thread_elements()[0][1];
} else {
*reinterpret_cast<threadgroup vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
}
}
template <typename U, typename V>
METAL_FUNC void multiply(simdgroup_matrix_storage<U> a, simdgroup_matrix_storage<V> b, bool accumulate = true) {
if (!accumulate) {
*(thread_elements()) = vec<T, 2>(0);
}
t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
}
};
} // namespace metal
#pragma METAL internals : disable
#endif
#endif // __METAL_SIMDGROUP_MATRIX_STORAGE
"""
// Metal source for the mixed-precision `hgemm` kernel. Splices in
// `metalSimdgroupMatrixStorage` via string interpolation. The kernel
// accumulates the top half of each output tile in Float32 (reading A from
// buffer 3) and the bottom half in Float16 (reading A from buffer 0);
// B is always Float16. Function-constant indices here must match the
// `setConstantValue` calls in `profileProblemSize`. Compiled at runtime;
// contents must stay byte-identical to what the host code was tuned against.
let GEMM = """
//
// GEMM.metal
// MetalFlashAttention
//
// Created by Philip Turner on 6/23/23.
//
#include <metal_stdlib>
\(metalSimdgroupMatrixStorage)
using namespace metal;
// MARK: - Function Constants
// Dimensions of each matrix.
constant uint M [[function_constant(0)]];
constant uint N [[function_constant(1)]];
constant uint K [[function_constant(2)]];
// Whether each matrix is transposed.
constant bool A_trans [[function_constant(10)]];
constant bool B_trans [[function_constant(11)]];
constant uint A_leading_dim = A_trans ? M : K;
constant uint B_leading_dim = B_trans ? K : N;
constant ushort M_simd [[function_constant(200)]];
constant ushort N_simd [[function_constant(201)]];
// Elide work on the edge when matrix dimension < SRAM block dimension.
constant ushort M_modulo = (M % M_simd == 0) ? M_simd : (M % M_simd);
constant ushort N_modulo = (N % N_simd == 0) ? N_simd : (N % N_simd);
constant ushort M_padded = (M < M_simd) ? (M_modulo + 7) / 8 * 8 : M_simd;
constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd;
// MARK: - Utilities
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* A_sram(
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
// A_sram[M_simd][8]
return sram + (matrix_origin.y / 8) * (8 / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* B_sram(
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
// A_sram[8][N_simd]
return sram + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* C_sram(
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
// C_sram[M_simd][N_simd]
return sram + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC void store_accumulator(thread simdgroup_matrix_storage<T> *sram,
device half *C, bool m_is_edge, bool n_is_edge)
{
const ushort m_start = (m_is_edge) ? M_modulo : 0;
const ushort n_start = (n_is_edge) ? N_modulo : 0;
const ushort m_end = (m_is_edge) ? M_simd : M_modulo;
const ushort n_end = (n_is_edge) ? N_simd : N_modulo;
#pragma clang loop unroll(full)
for (ushort m = m_start; m < m_end / 2; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = n_start; n < n_end; n += 8) {
ushort2 origin(n, m);
C_sram(sram, origin)->store_half(C, N, origin);
}
}
}
template <typename T>
METAL_FUNC void store_accumulator_2(thread simdgroup_matrix_storage<T> *sram,
device half *C, bool m_is_edge, bool n_is_edge)
{
const ushort m_start = (m_is_edge) ? M_modulo : 0;
const ushort n_start = (n_is_edge) ? N_modulo : 0;
const ushort m_end = (m_is_edge) ? M_simd : M_modulo;
const ushort n_end = (n_is_edge) ? N_simd : N_modulo;
#pragma clang loop unroll(full)
for (ushort m = m_start / 2; m < m_end; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = n_start; n < n_end; n += 8) {
ushort2 origin(n, m);
C_sram(sram, origin)->store(C, N, origin);
}
}
}
// MARK: - Kernels
kernel void hgemm(device half *A [[buffer(0)]],
device half *B [[buffer(1)]],
device half *C [[buffer(2)]],
device float *A32 [[buffer(3)]],
uint3 gid [[threadgroup_position_in_grid]],
ushort lane_id [[thread_index_in_simdgroup]])
{
simdgroup_matrix_storage<float> A_sram_allocation[1024];
simdgroup_matrix_storage<float> B_sram_allocation[1024];
simdgroup_matrix_storage<float> C_sram_allocation[1024];
simdgroup_matrix_storage<half> A_sram_allocation_2[1024];
simdgroup_matrix_storage<half> C_sram_allocation_2[1024];
ushort2 offset_in_simd = simdgroup_matrix_storage<half>::offset(lane_id);
ushort2 A_offset(0, gid.y * M_simd);
ushort2 B_offset(gid.x * N_simd, 0);
A_offset += offset_in_simd;
B_offset += offset_in_simd;
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded / 2; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto C = C_sram(C_sram_allocation, ushort2(n, m));
*C = simdgroup_matrix_storage<float>(0);
}
}
#pragma clang loop unroll(full)
for (ushort m = M_padded / 2; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto C = C_sram(C_sram_allocation_2, ushort2(n, m));
*C = simdgroup_matrix_storage<half>(0);
}
}
for (uint k = 0; k < K; k += 8) {
auto A32_src = simdgroup_matrix_storage<float>::apply_offset(
A32, A_leading_dim, uint2(A_offset), A_trans);
auto A_src = simdgroup_matrix_storage<half>::apply_offset(
A, A_leading_dim, uint2(A_offset), A_trans);
auto B_src = simdgroup_matrix_storage<half>::apply_offset(
B, B_leading_dim, uint2(B_offset), B_trans);
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded / 2; m += 8) {
ushort2 origin(0, m);
auto A = A_sram(A_sram_allocation, origin);
A->load(A32_src, A_leading_dim, origin);
}
#pragma clang loop unroll(full)
for (ushort m = M_padded / 2; m < M_padded; m += 8) {
ushort2 origin(0, m);
auto A = A_sram(A_sram_allocation_2, origin);
A->load(A_src, A_leading_dim, origin);
}
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, 0);
auto B = B_sram(B_sram_allocation, origin);
B->load_half(B_src, B_leading_dim, origin);
}
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded / 2; m += 8) {
auto A = A_sram(A_sram_allocation, ushort2(0, m));
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto B = B_sram(B_sram_allocation, ushort2(n, 0));
auto C = C_sram(C_sram_allocation, ushort2(n, m));
C->multiply(*A, *B);
}
}
#pragma clang loop unroll(full)
for (ushort m = M_padded / 2; m < M_padded; m += 8) {
auto A = A_sram(A_sram_allocation_2, ushort2(0, m));
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto B = B_sram(B_sram_allocation, ushort2(n, 0));
auto C = C_sram(C_sram_allocation_2, ushort2(n, m));
C->multiply(*A, *B);
}
}
A_offset.x += 8;
B_offset.y += 8;
}
// WARNING: M and N must be divisible by 8.
{
uint2 matrix_origin = uint2(B_offset.x, A_offset.y);
auto C_src = simdgroup_matrix_storage<half>::apply_offset(
C, N, matrix_origin);
store_accumulator(C_sram_allocation, C_src, false, false);
store_accumulator_2(C_sram_allocation_2, C_src, false, false);
const uint M_edge_floor = M - M % M_simd;
const uint N_edge_floor = N - N % N_simd;
if (matrix_origin.y < M_edge_floor) {
store_accumulator(C_sram_allocation, C_src, true, false);
store_accumulator_2(C_sram_allocation_2, C_src, true, false);
}
if (matrix_origin.x < N_edge_floor) {
store_accumulator(C_sram_allocation, C_src, false, true);
store_accumulator_2(C_sram_allocation_2, C_src, false, true);
if (matrix_origin.y < M_edge_floor) {
store_accumulator(C_sram_allocation, C_src, true, true);
store_accumulator_2(C_sram_allocation_2, C_src, true, true);
}
}
}
}
"""
// MARK: - Script | |
/// Benchmarks the custom mixed-precision `hgemm` kernel on square matrices of
/// dimension `problemSize`, then validates the numerical results.
///
/// A is initialized to the 2nd-order periodic Laplacian and B to uniform
/// random values, so each output row of A * B can be checked analytically as
/// the three-point stencil B[row-1] - 2 * B[row] + B[row+1] (periodic).
///
/// - Parameter problemSize: Number of rows/columns of each square matrix.
///   The kernel's store path requires M and N divisible by 8 (see the GEMM
///   source's WARNING).
func profileProblemSize(_ problemSize: Int) {
  var A = [Float](repeating: .zero, count: problemSize * problemSize)
  var B = [Float](repeating: .zero, count: problemSize * problemSize)
  var C = [Float](repeating: .zero, count: problemSize * problemSize)

  // Initialize A as the 2nd-order periodic Laplacian.
  for diagonalID in 0..<problemSize {
    let diagonalAddress = diagonalID * problemSize + diagonalID
    A[diagonalAddress] = -2

    let leftColumnID = (diagonalID + problemSize - 1) % problemSize
    let leftSubDiagonalAddress = diagonalID * problemSize + leftColumnID
    A[leftSubDiagonalAddress] = 1

    let rightColumnID = (diagonalID + problemSize + 1) % problemSize
    let rightSubDiagonalAddress = diagonalID * problemSize + rightColumnID
    A[rightSubDiagonalAddress] = 1
  }

  // Initialize B to random numbers.
  for rowID in 0..<problemSize {
    for columnID in 0..<problemSize {
      let address = rowID * problemSize + columnID
      let entry = Float.random(in: 0..<1)
      B[address] = entry
    }
  }

  do {
    // Initialize the context.
    let context = MTLContext()
    let library = try! context.device.makeLibrary(source: GEMM, options: nil)

    // Set the function constants.
    // NOTE: indices must match the [[function_constant(n)]] slots declared in
    // the GEMM shader source (0-2 dimensions, 10-11 transposes, 200-201 tile
    // sizes).
    let constants = MTLFunctionConstantValues()
    var M: Int = problemSize
    var N: Int = problemSize
    var K: Int = problemSize
    var transpose: Bool = false
    constants.setConstantValue(&M, type: .uint, index: 0)
    constants.setConstantValue(&N, type: .uint, index: 1)
    constants.setConstantValue(&K, type: .uint, index: 2)
    constants.setConstantValue(&transpose, type: .bool, index: 10)
    constants.setConstantValue(&transpose, type: .bool, index: 11)

    // Register tile (block) size per simdgroup.
    var M_simd: UInt16 = 32
    var N_simd: UInt16 = 32
    constants.setConstantValue(&M_simd, type: .ushort, index: 200)
    constants.setConstantValue(&N_simd, type: .ushort, index: 201)

    let function = try! library.makeFunction(
      name: "hgemm", constantValues: constants)
    let pipeline = try! context.device
      .makeComputePipelineState(function: function)

    // Rounds `target / granularity` up to the next integer.
    func ceilDivide(target: Int, granularity: UInt16) -> Int {
      (target + Int(granularity) - 1) / Int(granularity)
    }
    // One threadgroup (a single 32-thread simdgroup) per output tile.
    let gridSize = MTLSize(
      width: ceilDivide(target: N, granularity: N_simd),
      height: ceilDivide(target: M, granularity: M_simd),
      depth: 1)
    let groupSize = MTLSize(
      width: 32,
      height: 1,
      depth: 1)

    // Create the buffers.
    var bufferDesc = MTLBufferDescriptor()
    bufferDesc.context = context
    bufferDesc.problemSize = problemSize

    bufferDesc.data = A
    bufferDesc.dataType = MTLDataType.half
    let bufferA = createMTLBuffer(descriptor: bufferDesc)

    bufferDesc.data = B
    bufferDesc.dataType = MTLDataType.half
    let bufferB = createMTLBuffer(descriptor: bufferDesc)

    bufferDesc.data = C
    bufferDesc.dataType = MTLDataType.half
    let bufferC = createMTLBuffer(descriptor: bufferDesc)

    // Create the second buffer for A.
    // The kernel reads the top half of each tile's A rows from this Float32
    // copy (buffer 3) and the bottom half from the Float16 copy (buffer 0).
    bufferDesc.data = A
    bufferDesc.dataType = MTLDataType.float
    let bufferA32 = createMTLBuffer(descriptor: bufferDesc)

    // Profile the latency of matrix multiplication.
    // bestStatistics[0] = minimum latency in microseconds across trials;
    // bestStatistics[1] = maximum GFLOPS across trials.
    var bestStatistics: SIMD2<Int> = .init(.max, .zero)
    for _ in 0..<15 {
      // Amortize command-buffer overhead by encoding the dispatch 20 times.
      let duplicatedCommandCount: Int = 20

      // Execute the operation.
      let commandBuffer = context.commandQueue.makeCommandBuffer()!
      let encoder = commandBuffer.makeComputeCommandEncoder()!
      encoder.setComputePipelineState(pipeline)
      encoder.setBuffer(bufferA, offset: 0, index: 0)
      encoder.setBuffer(bufferB, offset: 0, index: 1)
      encoder.setBuffer(bufferC, offset: 0, index: 2)
      encoder.setBuffer(bufferA32, offset: 0, index: 3)
      for _ in 0..<duplicatedCommandCount {
        encoder.dispatchThreadgroups(
          gridSize, threadsPerThreadgroup: groupSize)
      }
      encoder.endEncoding()
      commandBuffer.commit()
      commandBuffer.waitUntilCompleted()

      // Determine the time taken.
      let start = commandBuffer.gpuStartTime
      let end = commandBuffer.gpuEndTime
      let latency = end - start
      let latencyMicroseconds = Int(latency / 1e-6)

      // Determine the amount of work done (2*M*N*K FLOPs per GEMM).
      var operations = 2 * problemSize * problemSize * problemSize
      operations = operations * duplicatedCommandCount
      let gflops = Int(Double(operations) / Double(latency) / 1e9)

      // Report the results.
      bestStatistics[0] = min(bestStatistics[0], latencyMicroseconds)
      bestStatistics[1] = max(bestStatistics[1], gflops)
    }

    // Report the results.
    print("//", terminator: " ")
    print("- problemSize =", problemSize, terminator: " ")
    print("|", bestStatistics[1], "GFLOPS", terminator: " ")
    print()

    // Copy the results to C.
    do {
      let rawPointer = bufferC.contents()
      // The kernel stores C as Float16 (store_half / the half accumulator).
      let castedPointer = rawPointer.assumingMemoryBound(to: Float16.self)
      for rowID in 0..<problemSize {
        for columnID in 0..<problemSize {
          let address = rowID * problemSize + columnID
          let entry = castedPointer[address]

          // Truncate the output's precision to 16-bit, while keeping the
          // memory format of Float32.
          // let bitPattern = UInt32(entry) << 16
          // C[address] = Float(bitPattern: bitPattern)

          // Read an accumulator stored as Float16 or Float32.
          C[address] = Float(entry)
        }
      }
    }
  }

  // Check the results against the analytic three-point stencil.
  for m in 0..<problemSize {
    for n in 0..<problemSize {
      // Find the source row IDs.
      let leftRowID = (m + problemSize - 1) % problemSize
      let centerRowID = m
      let rightRowID = (m + problemSize + 1) % problemSize

      // Find the source values.
      let leftSource = B[leftRowID * problemSize + n]
      let centerSource = B[centerRowID * problemSize + n]
      let rightSource = B[rightRowID * problemSize + n]

      // Find the expected and actual values.
      let expected = leftSource - 2 * centerSource + rightSource
      let actual = C[m * problemSize + n]

      // Report the results.
      // Tolerance of 5e-2 accommodates Float16 storage/accumulation error.
      let error = (expected - actual).magnitude
      if error > 5e-2 {
        print("error: \(error) / ~1.000")
      }
    }
  }
}
// MARK: - Utilities | |
/// Bundles the Metal device and command queue shared by every benchmark run.
struct MTLContext {
  var device: MTLDevice
  var commandQueue: MTLCommandQueue

  /// Acquires the system default device and a fresh command queue.
  /// Crashes with a diagnostic (instead of a bare force-unwrap trap) when
  /// Metal is unavailable, e.g. on unsupported hardware or in a sandbox
  /// without GPU access.
  init() {
    guard let device = MTLCreateSystemDefaultDevice() else {
      fatalError("Could not create the system default Metal device.")
    }
    guard let commandQueue = device.makeCommandQueue() else {
      fatalError("Could not create a Metal command queue.")
    }
    self.device = device
    self.commandQueue = commandQueue
  }
}
/// Parameter bag for `createMTLBuffer(descriptor:)`. All fields are optional
/// so callers can fill them incrementally; the factory traps if any is nil.
struct MTLBufferDescriptor {
  // Device/queue pair used to allocate the buffer.
  var context: MTLContext?
  // Source values, always supplied in Float32; converted to `dataType`.
  var data: [Float]?
  // Matrix dimension; `data.count` must equal problemSize * problemSize.
  var problemSize: Int?
  // Storage precision to write into the buffer (.half, .bfloat, or .float).
  var dataType: MTLDataType?
}
/// Allocates a Metal buffer for a square matrix and fills it with the
/// descriptor's Float32 data, converted to the requested storage precision.
///
/// The allocation is always sized for Float32 (`count * 4` bytes), even when
/// the 16-bit encodings only occupy the first half. BF16 is produced by
/// truncating each Float32 to its upper 16 bits — no rounding.
///
/// - Parameter descriptor: Must have all fields populated; traps otherwise.
/// - Returns: A buffer whose leading bytes hold the converted matrix.
func createMTLBuffer(
  descriptor: MTLBufferDescriptor
) -> MTLBuffer {
  guard let context = descriptor.context,
        let data = descriptor.data,
        let problemSize = descriptor.problemSize,
        let dataType = descriptor.dataType else {
    fatalError("Descriptor was invalid.")
  }
  guard data.count == problemSize * problemSize else {
    fatalError("Input was not a square matrix.")
  }

  // Reserve a Float32-sized allocation regardless of storage precision.
  let buffer = context.device.makeBuffer(length: data.count * 4)!

  // Convert and copy the payload according to the requested precision.
  switch dataType {
  case .half:
    // Round each element to IEEE half precision.
    let converted = data.map(Float16.init)
    let destination = buffer.contents().assumingMemoryBound(to: Float16.self)
    destination.initialize(from: converted, count: converted.count)
  case .bfloat:
    // Keep only the upper 16 bits of each Float32 bit pattern (BF16 by
    // truncation, preserving the Float32 exponent range).
    let converted = data.map { UInt16($0.bitPattern >> 16) }
    let destination = buffer.contents().assumingMemoryBound(to: UInt16.self)
    destination.initialize(from: converted, count: converted.count)
  case .float:
    // Full precision: copy the source values verbatim.
    let destination = buffer.contents().assumingMemoryBound(to: Float.self)
    destination.initialize(from: data, count: data.count)
  default:
    fatalError("Unrecognized data type.")
  }

  return buffer
}
#if false
// NOTE(review): Disabled reference code for the MPSMatrix / MPSGraph
// comparisons quoted in the header tables. It is stale relative to the active
// code above: it calls `createMTLBuffer(context:data:problemSize:dataType:)`,
// which does not match the descriptor-based `createMTLBuffer(descriptor:)`
// defined in this file, and `MPSMatrixStorageDescriptor` is not declared
// anywhere visible. Both would need fixing (plus
// `import MetalPerformanceShaders` / `import MetalPerformanceShadersGraph`)
// before re-enabling — TODO confirm against the original gist revision.

/// Pairs a raw MTLBuffer with an MPSMatrix view of the same storage.
struct MPSMatrixStorage {
  var buffer: MTLBuffer
  var matrix: MPSMatrix

  init(descriptor: MPSMatrixStorageDescriptor) {
    guard let context = descriptor.context,
          let data = descriptor.data,
          let problemSize = descriptor.problemSize,
          let dataType = descriptor.dataType else {
      fatalError("Descriptor was invalid.")
    }

    // Create the buffer
    buffer = createMTLBuffer(
      context: context,
      data: data,
      problemSize: problemSize,
      dataType: dataType)

    // Set the descriptor properties.
    // 16-bit element types occupy 2 bytes per element; Float32 occupies 4.
    let elementStride = (dataType == .float32) ? 4 : 2
    let matrixDesc = MPSMatrixDescriptor(
      rows: problemSize,
      columns: problemSize,
      rowBytes: problemSize * elementStride,
      dataType: dataType)

    // Create the matrix object.
    matrix = MPSMatrix(buffer: buffer, descriptor: matrixDesc)
  }
}

/// Parameter bag for `MPSTensorStorage`.
struct MPSTensorStorageDescriptor {
  var context: MTLContext?
  var data: [Float]?
  var problemSize: Int?
  var dataType: MPSDataType?
}

/// Pairs a raw MTLBuffer with an MPSGraphTensorData view of the same storage.
struct MPSTensorStorage {
  var buffer: MTLBuffer
  var tensorData: MPSGraphTensorData

  init(descriptor: MPSTensorStorageDescriptor) {
    guard let context = descriptor.context,
          let data = descriptor.data,
          let problemSize = descriptor.problemSize,
          let dataType = descriptor.dataType else {
      fatalError("Descriptor was invalid.")
    }

    // Create the buffer
    buffer = createMTLBuffer(
      context: context,
      data: data,
      problemSize: problemSize,
      dataType: dataType)

    // Create the tensor data.
    // 16-bit element types occupy 2 bytes per element; Float32 occupies 4.
    let elementStride = (dataType == .float32) ? 4 : 2
    tensorData = MPSGraphTensorData(
      buffer,
      shape: [NSNumber(value: problemSize), NSNumber(value: problemSize)],
      dataType: dataType,
      rowBytes: problemSize * elementStride)
  }
}
#endif
// MARK: - Reference Code | |
#if false | |
// Benchmarks the custom "hgemm" Metal shader compiled from the `GEMM` source
// string: sets function constants for a square problem, dispatches a batch of
// duplicated GEMM commands per command buffer, and folds the best observed
// latency (microseconds) and throughput (GFLOPS) into `bestStatistics`.
//
// Relies on file-scope globals: `problemSize`, `GEMM` (shader source), and
// the input matrices `A`, `B`, `C`. The accumulated `bestStatistics` is not
// reported within this excerpt.
func testCustomShaderPerformance() {
  // Initialize the context.
  let context = MTLContext()
  let library = try! context.device.makeLibrary(source: GEMM, options: nil)

  // Set the function constants: square M = N = K problem, no transposes.
  let constants = MTLFunctionConstantValues()
  var M: Int = problemSize
  var N: Int = problemSize
  var K: Int = problemSize
  var transpose: Bool = false
  constants.setConstantValue(&M, type: .uint, index: 0)
  constants.setConstantValue(&N, type: .uint, index: 1)
  constants.setConstantValue(&K, type: .uint, index: 2)
  constants.setConstantValue(&transpose, type: .bool, index: 10)
  constants.setConstantValue(&transpose, type: .bool, index: 11)

  // Register block sizes: each simdgroup computes a 32x32 output tile.
  var M_simd: UInt16 = 32
  var N_simd: UInt16 = 32
  constants.setConstantValue(&M_simd, type: .ushort, index: 200)
  constants.setConstantValue(&N_simd, type: .ushort, index: 201)
  let function = try! library.makeFunction(
    name: "hgemm", constantValues: constants)
  let pipeline = try! context.device
    .makeComputePipelineState(function: function)

  // Rounds `target` up to a whole number of `granularity`-sized tiles.
  func ceilDivide(target: Int, granularity: UInt16) -> Int {
    (target + Int(granularity) - 1) / Int(granularity)
  }
  // One threadgroup per output tile.
  let gridSize = MTLSize(
    width: ceilDivide(target: N, granularity: N_simd),
    height: ceilDivide(target: M, granularity: M_simd),
    depth: 1)
  // One simdgroup (32 threads) per threadgroup.
  let groupSize = MTLSize(
    width: 32,
    height: 1,
    depth: 1)

  // Create the buffers.
  func createBuffer(data: [Float], lowPrecision: Bool) -> MTLBuffer {
    // Check the size of the input array.
    guard data.count == problemSize * problemSize else {
      fatalError("Input was not a square matrix.")
    }
    // Allocate enough memory to store everything in Float32. Low-precision
    // data therefore only occupies the first half of the allocation.
    let buffer = context.device.makeBuffer(length: data.count * 4)!
    // Branch on whether the data is low-precision.
    if lowPrecision {
      // Compress the precision of the data.
      var compressedData = [UInt16](repeating: .zero, count: data.count)
      for elementID in data.indices {
        let entry = data[elementID]
        // Truncate the input's precision to 16-bit, while keeping the memory
        // format of Float32. (Disabled alternative to the Float16 path below.)
        // let bitPattern = entry.bitPattern
        // compressedData[elementID] = UInt16(bitPattern >> 16)
        // Write an input stored as Float16.
        let bitPattern = Float16(entry).bitPattern
        compressedData[elementID] = bitPattern
      }
      // Copy the data.
      let pointer = buffer.contents().assumingMemoryBound(to: UInt16.self)
      pointer.initialize(from: compressedData, count: compressedData.count)
    } else {
      // Copy the data.
      let pointer = buffer.contents().assumingMemoryBound(to: Float.self)
      pointer.initialize(from: data, count: data.count)
    }
    // Return the buffer object.
    return buffer
  }
  // NOTE(review): the pipeline above is "hgemm" (half precision), yet all
  // three operands are created with `lowPrecision: false` (Float32 layout).
  // Confirm this mismatch is intentional — it looks like a leftover from
  // toggling between precision configurations in this reference code.
  let bufferA = createBuffer(data: A, lowPrecision: false)
  let bufferB = createBuffer(data: B, lowPrecision: false)
  let bufferC = createBuffer(data: C, lowPrecision: false)

  // Profile the latency of matrix multiplication.
  // bestStatistics = (min latency in µs, max throughput in GFLOPS).
  var bestStatistics: SIMD2<Int> = .init(.max, .zero)
  for _ in 0..<15 {
    // Encoding the command several times amortizes launch overhead.
    let duplicatedCommandCount: Int = 20
    // Execute the operation.
    let commandBuffer = context.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(bufferA, offset: 0, index: 0)
    encoder.setBuffer(bufferB, offset: 0, index: 1)
    encoder.setBuffer(bufferC, offset: 0, index: 2)
    for _ in 0..<duplicatedCommandCount {
      encoder.dispatchThreadgroups(
        gridSize, threadsPerThreadgroup: groupSize)
    }
    encoder.endEncoding()
    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
    // Determine the time taken, using the GPU-side timestamps (excludes
    // CPU encoding overhead).
    let start = commandBuffer.gpuStartTime
    let end = commandBuffer.gpuEndTime
    let latency = end - start
    let latencyMicroseconds = Int(latency / 1e-6)
    // Determine the amount of work done: 2·N³ FLOPs per GEMM, times the
    // number of duplicated dispatches.
    var operations = 2 * problemSize * problemSize * problemSize
    operations = operations * duplicatedCommandCount
    let gflops = Int(Double(operations) / Double(latency) / 1e9)
    // Report the results.
    bestStatistics[0] = min(bestStatistics[0], latencyMicroseconds)
    bestStatistics[1] = max(bestStatistics[1], gflops)
  }
}
#endif | |
#if false | |
// Measures MPSMatrixMultiplication throughput on square Float32 matrices,
// recording the minimum latency (microseconds) and maximum throughput
// (GFLOPS) over repeated trials in `bestStatistics`.
//
// Relies on file-scope globals: `problemSize` and matrices `A`, `B`, `C`.
func testMPSMatrixPerformance() {
  // Set up the Metal context.
  let context = MTLContext()

  // Build the three operand matrices; the descriptor is reused with only
  // the data field changing between matrices.
  var storageDescriptor = MPSMatrixStorageDescriptor()
  storageDescriptor.context = context
  storageDescriptor.problemSize = problemSize
  storageDescriptor.dataType = .float32
  storageDescriptor.data = A
  let matrixA = MPSMatrixStorage(descriptor: storageDescriptor)
  storageDescriptor.data = B
  let matrixB = MPSMatrixStorage(descriptor: storageDescriptor)
  storageDescriptor.data = C
  let matrixC = MPSMatrixStorage(descriptor: storageDescriptor)

  // Configure the GEMM kernel for a square problem with zero origins.
  let multiplication = MPSMatrixMultiplication(
    device: context.device,
    resultRows: problemSize,
    resultColumns: problemSize,
    interiorColumns: problemSize)
  let zeroOrigin = MTLOrigin(x: 0, y: 0, z: 0)
  multiplication.leftMatrixOrigin = zeroOrigin
  multiplication.rightMatrixOrigin = zeroOrigin
  multiplication.resultMatrixOrigin = zeroOrigin

  // Profile the multiplication: keep (min latency, max GFLOPS).
  var bestStatistics = SIMD2<Int>(.max, .zero)
  let trialCount = 15
  let duplicatedCommandCount = 20
  for _ in 0..<trialCount {
    // Encode the same GEMM several times in one command buffer to
    // amortize launch overhead.
    let commandBuffer = context.commandQueue.makeCommandBuffer()!
    for _ in 0..<duplicatedCommandCount {
      multiplication.encode(
        commandBuffer: commandBuffer,
        leftMatrix: matrixA.matrix,
        rightMatrix: matrixB.matrix,
        resultMatrix: matrixC.matrix)
    }
    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()

    // Convert GPU-side timestamps into latency and throughput figures.
    let latency = commandBuffer.gpuEndTime - commandBuffer.gpuStartTime
    let latencyMicroseconds = Int(latency / 1e-6)
    // 2·N³ FLOPs per GEMM, times the number of duplicated encodes.
    let operations = 2 * problemSize * problemSize * problemSize
      * duplicatedCommandCount
    let gflops = Int(Double(operations) / Double(latency) / 1e9)

    // Fold this trial into the best statistics seen so far.
    bestStatistics[0] = min(bestStatistics[0], latencyMicroseconds)
    bestStatistics[1] = max(bestStatistics[1], gflops)
  }
}
#endif | |
#if false | |
// Benchmarks an MPSGraph executable that performs two mixed-precision GEMMs
// per encode: Float32 operands cast to BF16 for the first multiplication and
// to FP16 for the second. Records the best latency (microseconds) and
// throughput (GFLOPS) across trials in `bestStatistics`.
//
// Relies on file-scope globals: `problemSize` and matrices `A`, `B`, `C`.
// The accumulated `bestStatistics` is not reported within this excerpt.
func testMPSGraphPerformance() {
  // Initialize the context.
  let context = MTLContext()

  // Initialize the tensors. The descriptor is reused; only the fields that
  // differ between tensors are overwritten before each init.
  var tensorStorageDesc = MPSTensorStorageDescriptor()
  tensorStorageDesc.context = context
  tensorStorageDesc.problemSize = problemSize
  tensorStorageDesc.dataType = .float32
  tensorStorageDesc.data = A
  let tensorStorageA = MPSTensorStorage(descriptor: tensorStorageDesc)
  let tensorStorageA2 = MPSTensorStorage(descriptor: tensorStorageDesc)
  tensorStorageDesc.data = B
  let tensorStorageB = MPSTensorStorage(descriptor: tensorStorageDesc)
  let tensorStorageB2 = MPSTensorStorage(descriptor: tensorStorageDesc)
  // The two outputs are stored in 16-bit formats: BF16 for the first GEMM,
  // FP16 for the second.
  tensorStorageDesc.data = C
  tensorStorageDesc.dataType = .bFloat16
  let tensorStorageC = MPSTensorStorage(descriptor: tensorStorageDesc)
  tensorStorageDesc.dataType = .float16
  let tensorStorageC2 = MPSTensorStorage(descriptor: tensorStorageDesc)

  // Create the graph executable.
  var executable: MPSGraphExecutable
  do {
    let graph = MPSGraph()
    // Program the matrix multiplication through the DSL.
    // First GEMM: FP32 placeholders down-cast to BF16 before multiplying.
    let shape = [NSNumber(value: problemSize), NSNumber(value: problemSize)]
    let tensorA32 = graph.placeholder(shape: shape, name: "Operand A (FP32)")
    let tensorB32 = graph.placeholder(shape: shape, name: "Operand B (FP32)")
    let tensorA16 = graph.cast(
      tensorA32, to: .bFloat16, name: "Operand A (BF16)")
    let tensorB16 = graph.cast(
      tensorB32, to: .bFloat16, name: "Operand B (BF16)")
    let tensorC = graph.matrixMultiplication(
      primary: tensorA16, secondary: tensorB16, name: "Output")
    // Second GEMM: independent FP32 placeholders down-cast to FP16.
    let tensorA232 = graph.placeholder(shape: shape, name: "Operand A (2nd)")
    let tensorB232 = graph.placeholder(shape: shape, name: "Operand B (2nd)")
    let tensorA216 = graph.cast(
      tensorA232, to: .float16, name: "Operand A (2nd, FP16)")
    let tensorB216 = graph.cast(
      tensorB232, to: .float16, name: "Operand B (2nd, FP16)")
    let tensorC2 = graph.matrixMultiplication(
      primary: tensorA216, secondary: tensorB216, name: "Output (2nd)")
    // Define the data type of each input to the graph: all four feeds are
    // FP32 square matrices.
    let shapedType = MPSGraphShapedType(shape: shape, dataType: .float32)
    let feeds: [MPSGraphTensor : MPSGraphShapedType] = [
      tensorA32: shapedType,
      tensorB32: shapedType,
      tensorA232: shapedType,
      tensorB232: shapedType,
    ]
    let targetTensors: [MPSGraphTensor] = [
      tensorC,
      tensorC2,
    ]
    // Set the compilation descriptor; block until compilation finishes so
    // it is not attributed to the timed region.
    let compilationDesc = MPSGraphCompilationDescriptor()
    compilationDesc.optimizationLevel = .level1
    compilationDesc.waitForCompilationCompletion = true
    // Create the graph executable.
    let mpsGraphDevice = MPSGraphDevice(
      mtlDevice: context.device)
    executable = graph.compile(
      with: mpsGraphDevice,
      feeds: feeds,
      targetTensors: targetTensors,
      targetOperations: nil,
      compilationDescriptor: compilationDesc)
    // Specialize the executable for the four concrete input types.
    executable.specialize(
      with: mpsGraphDevice,
      inputTypes: [
        shapedType,
        shapedType,
        shapedType,
        shapedType,
      ],
      compilationDescriptor: compilationDesc)
  }

  // Profile the latency of matrix multiplication. Larger problems get fewer
  // trials and fewer duplicated encodes to bound total runtime.
  var bestStatistics: SIMD2<Int> = .init(.max, .zero)
  let trialCount = (problemSize >= 3072) ? 5 : 15
  let duplicatedCommandCount = (problemSize >= 3072) ? 5 : 20
  for _ in 0..<trialCount {
    // Create the execution descriptor. Completion is awaited on the command
    // buffer below rather than per-encode.
    let executionDesc = MPSGraphExecutableExecutionDescriptor()
    executionDesc.waitUntilCompleted = false
    // Execute the operation. Timing uses wall-clock timestamps around the
    // whole encode/commit/wait sequence, so it includes CPU encoding time —
    // unlike the GPU-timestamp approach in the other benchmarks.
    let start = CACurrentMediaTime()
    let commandBuffer = MPSCommandBuffer(from: context.commandQueue)
    for _ in 0..<duplicatedCommandCount {
      // NOTE(review): the `inputs:` order below assumes the executable's
      // feed order matches placeholder creation order (A, B, A2, B2). The
      // `feeds` dictionary is unordered — verify against
      // `executable.feedTensors` that this ordering holds.
      executable.encode(
        to: commandBuffer,
        inputs: [
          tensorStorageA.tensorData,
          tensorStorageB.tensorData,
          tensorStorageA2.tensorData,
          tensorStorageB2.tensorData,
        ],
        results: [
          tensorStorageC.tensorData,
          tensorStorageC2.tensorData,
        ],
        executionDescriptor: executionDesc)
    }
    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
    let end = CACurrentMediaTime()
    // Determine the time taken.
    let latency = end - start
    let latencyMicroseconds = Int(latency / 1e-6)
    // Determine the amount of work done: 2·N³ FLOPs per GEMM.
    var operations = 2 * problemSize * problemSize * problemSize
    // Two GEMMs are encoded per pass (the BF16 one and the FP16 one).
    operations *= 2
    operations = operations * duplicatedCommandCount
    let gflops = Int(Double(operations) / Double(latency) / 1e9)
    // Report the results.
    bestStatistics[0] = min(bestStatistics[0], latencyMicroseconds)
    bestStatistics[1] = max(bestStatistics[1], gflops)
  }
}
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment