//
// main.swift
// M4LowPrecisionMath
//
// Created by Philip Turner on 5/28/24.
//
import Metal
// Investigating the performance of low- and mixed-precision computations after
// dynamic caching.
//
// ========================================================================== //
// Introduction
// ========================================================================== //
//
// Mixed precision is a method for reducing the memory bandwidth of Float32
// workloads. On some application-specific chips, there are dedicated FP16,
// BF16, or FP8 units with higher throughput than FP32. These units typically
// execute only FMA instructions or block matrix multiplications; everything
// else falls back to higher precision. For example, a transcendental
// activation function in a neural network would temporarily expand the
// 16-bit numbers to Float32.
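//
// As a minimal CPU-side sketch of that pattern (illustrative only, not part
// of the benchmark): operands are stored as Float16, while the accumulator is
// kept in Float32.
//
//     func dotMixedPrecision(_ a: [Float16], _ b: [Float16]) -> Float {
//       precondition(a.count == b.count)
//       var accumulator: Float = 0
//       for i in a.indices {
//         // FP16 storage, FP32 multiply-accumulate.
//         accumulator += Float(a[i]) * Float(b[i])
//       }
//       return accumulator
//     }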
//
// The AMX coprocessor can execute 16-bit instructions at twice the rate of
// 32-bit instructions. This low-precision capability would be extremely
// useful, reaching 50-100% of the GPU's performance. However, it is not
// accessible through any public API. In addition, low-precision AI
// inference is easy to code on the GPU. For all practical purposes, the CPU
// is restricted to 32-bit types.
//
// This investigation should answer some questions about mixed precision on
// Apple GPUs.
// - Does the optimal block size change when the precision decreases from
// Float32 to Float16?
// - How well does the MSL built-in type for BFloat16 perform, compared to
// Float16?
// - How fast are the MPSMatrix and MPSGraph implementations of low-precision
// matrix multiplication?
// - Can MPS achieve double the performance by running two different
// matrix multiplications simultaneously, each with a different precision?
// - Can a kernel be engineered to exploit the FP16 and FP32 units
// simultaneously?
//
// ========================================================================== //
// Methods
// ========================================================================== //
//
// No async copies are used to stage operands in threadgroup memory.
// 32 threads/threadgroup (one simdgroup per threadgroup).
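//
// For reference, the dispatch configuration used by the script below
// (profileProblemSize); one 32-thread simdgroup covers an M_simd x N_simd
// block of the output matrix:
//
//     let gridSize = MTLSize(
//       width: ceilDivide(target: N, granularity: N_simd),
//       height: ceilDivide(target: M, granularity: M_simd),
//       depth: 1)
//     let groupSize = MTLSize(width: 32, height: 1, depth: 1)
//     encoder.dispatchThreadgroups(gridSize, threadsPerThreadgroup: groupSize)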
//
// ========================================================================== //
// Results (Raw Data)
// - Does the optimal block size change when the precision decreases from
// Float32 to Float16?
// - How well does the MSL built-in type for BFloat16 perform, compared to
// Float16?
// ========================================================================== //
//
// A, B, and C are Float32.
//
// M1 Max
// - problemSize = 256 | 913 GFLOPS (32x32x8)
// - problemSize = 384 | 2931 GFLOPS (32x32x8)
// - problemSize = 512 | 5342 GFLOPS (32x32x8)
// - problemSize = 640 | 5463 GFLOPS (32x32x8)
// - problemSize = 768 | 6160 GFLOPS (48x48x8)
// - problemSize = 896 | 6643 GFLOPS (48x48x8)
// - problemSize = 1024 | 7596 GFLOPS (48x48x8)
// - problemSize = 1152 | 7676 GFLOPS (48x48x8)
// - problemSize = 1280 | 7712 GFLOPS (48x48x8)
// - problemSize = 1408 | 7747 GFLOPS (48x48x8)
// - problemSize = 1536 | 8392 GFLOPS (48x48x8)
//
// M4
// - problemSize = 256 | 1195 GFLOPS (32x32x8)
// - problemSize = 384 | 1729 GFLOPS (32x32x8)
// - problemSize = 512 | 2549 GFLOPS (32x32x8)
// - problemSize = 640 | 2983 GFLOPS (32x32x8)
// - problemSize = 768 | 3036 GFLOPS (32x32x8)
// - problemSize = 896 | 3044 GFLOPS (32x32x8)
// - problemSize = 1024 | 3074 GFLOPS (32x32x8)
// - problemSize = 1152 | 3123 GFLOPS (32x32x8)
// - problemSize = 1280 | 3134 GFLOPS (32x32x8)
// - problemSize = 1408 | 3167 GFLOPS (32x32x8)
// - problemSize = 1536 | 3174 GFLOPS (32x32x8)
//
// A, B, and C are Float16.
//
// M1 Max
// - problemSize = 256 | 1390 GFLOPS (32x32x8)
// - problemSize = 384 | 3047 GFLOPS (32x32x8)
// - problemSize = 512 | 5578 GFLOPS (32x32x8)
// - problemSize = 640 | 6566 GFLOPS (32x32x8)
// - problemSize = 768 | 6022 GFLOPS (32x32x8)
// - problemSize = 896 | 7255 GFLOPS (48x48x8)
// - problemSize = 1024 | 7893 GFLOPS (48x48x8)
// - problemSize = 1152 | 8905 GFLOPS (48x48x8)
// - problemSize = 1280 | 8639 GFLOPS (48x48x8)
// - problemSize = 1408 | 8099 GFLOPS (48x48x8)
// - problemSize = 1536 | 9310 GFLOPS (48x48x8)
//
// M4
// - problemSize = 256 | 857 GFLOPS (32x32x8)
// - problemSize = 384 | 1912 GFLOPS (32x32x8)
// - problemSize = 512 | 2688 GFLOPS (32x32x8)
// - problemSize = 640 | 3306 GFLOPS (32x32x8)
// - problemSize = 768 | 3347 GFLOPS (32x32x8)
// - problemSize = 896 | 3383 GFLOPS (32x32x8)
// - problemSize = 1024 | 3428 GFLOPS (32x32x8)
// - problemSize = 1152 | 3471 GFLOPS (32x32x8)
// - problemSize = 1280 | 3491 GFLOPS (32x32x8)
// - problemSize = 1408 | 3517 GFLOPS (32x32x8)
// - problemSize = 1536 | 3532 GFLOPS (32x32x8)
//
// A, B, and C are BFloat16, through the MSL keyword 'bfloat'.
//
// M1 Max
// - problemSize = 256 | 604 GFLOPS (32x32x8)
// - problemSize = 384 | 1992 GFLOPS (32x32x8)
// - problemSize = 512 | 2822 GFLOPS (32x32x8)
// - problemSize = 640 | 3311 GFLOPS (32x32x8)
// - problemSize = 768 | 3879 GFLOPS (32x32x8)
// - problemSize = 896 | 3776 GFLOPS (32x32x8)
// - problemSize = 1024 | 4127 GFLOPS (32x32x8)
// - problemSize = 1152 | 3822 GFLOPS (32x32x8)
// - problemSize = 1280 | 4141 GFLOPS (32x32x8)
// - problemSize = 1408 | 4130 GFLOPS (32x32x8)
// - problemSize = 1536 | 4258 GFLOPS (32x32x8)
//
// M4
// - problemSize = 256 | 1171 GFLOPS (32x32x8)
// - problemSize = 384 | 1709 GFLOPS (32x32x8)
// - problemSize = 512 | 2736 GFLOPS (32x32x8)
// - problemSize = 640 | 2887 GFLOPS (32x32x8)
// - problemSize = 768 | 2945 GFLOPS (32x32x8)
// - problemSize = 896 | 2960 GFLOPS (32x32x8)
// - problemSize = 1024 | 3006 GFLOPS (32x32x8)
// - problemSize = 1152 | 3035 GFLOPS (32x32x8)
// - problemSize = 1280 | 3054 GFLOPS (32x32x8)
// - problemSize = 1408 | 3081 GFLOPS (32x32x8)
// - problemSize = 1536 | 3086 GFLOPS (32x32x8)
//
// A and B are BFloat16, C is Float32.
//
// M1 Max
// - problemSize = 256 | 675 GFLOPS (32x32x8)
// - problemSize = 384 | 2719 GFLOPS (32x32x8)
// - problemSize = 512 | 4840 GFLOPS (32x32x8)
// - problemSize = 640 | 5954 GFLOPS (32x32x8)
// - problemSize = 768 | 6260 GFLOPS (32x32x8)
// - problemSize = 896 | 6805 GFLOPS (48x48x8)
// - problemSize = 1024 | 6590 GFLOPS (48x48x8)
// - problemSize = 1152 | 7707 GFLOPS (48x48x8)
// - problemSize = 1280 | 7469 GFLOPS (48x48x8)
// - problemSize = 1408 | 7296 GFLOPS (48x48x8)
// - problemSize = 1536 | 8247 GFLOPS (48x48x8)
//
// M4
// - problemSize = 256 | 1268 GFLOPS (32x32x8)
// - problemSize = 384 | 1874 GFLOPS (32x32x8)
// - problemSize = 512 | 2986 GFLOPS (32x32x8)
// - problemSize = 640 | 3254 GFLOPS (32x32x8)
// - problemSize = 768 | 3310 GFLOPS (32x32x8)
// - problemSize = 896 | 3333 GFLOPS (32x32x8)
// - problemSize = 1024 | 3348 GFLOPS (32x32x8)
// - problemSize = 1152 | 3416 GFLOPS (32x32x8)
// - problemSize = 1280 | 3438 GFLOPS (32x32x8)
// - problemSize = 1408 | 3463 GFLOPS (32x32x8)
// - problemSize = 1536 | 3477 GFLOPS (32x32x8)
//
// A and B are BFloat16, C is Float16.
//
// M1 Max
// - problemSize = 256 | 1120 GFLOPS (32x32x8)
// - problemSize = 384 | 2387 GFLOPS (32x32x8)
// - problemSize = 512 | 4629 GFLOPS (32x32x8)
// - problemSize = 640 | 5736 GFLOPS (32x32x8)
// - problemSize = 768 | 6455 GFLOPS (32x32x8)
// - problemSize = 896 | 6515 GFLOPS (48x48x8)
// - problemSize = 1024 | 6643 GFLOPS (48x48x8)
// - problemSize = 1152 | 7459 GFLOPS (48x48x8)
// - problemSize = 1280 | 6646 GFLOPS (48x48x8)
// - problemSize = 1408 | 6756 GFLOPS (48x48x8)
// - problemSize = 1536 | 7596 GFLOPS (48x48x8)
//
// M4
// - problemSize = 256 | 1176 GFLOPS (32x32x8)
// - problemSize = 384 | 1895 GFLOPS (32x32x8)
// - problemSize = 512 | 2723 GFLOPS (32x32x8)
// - problemSize = 640 | 2917 GFLOPS (32x32x8)
// - problemSize = 768 | 2950 GFLOPS (32x32x8)
// - problemSize = 896 | 2969 GFLOPS (32x32x8)
// - problemSize = 1024 | 3002 GFLOPS (32x32x8)
// - problemSize = 1152 | 3033 GFLOPS (32x32x8)
// - problemSize = 1280 | 3054 GFLOPS (32x32x8)
// - problemSize = 1408 | 3079 GFLOPS (32x32x8)
// - problemSize = 1536 | 3088 GFLOPS (32x32x8)
//
// A and B are BFloat16 through MSL keyword.
// C is encoded from Float32 in registers to BFloat16 in RAM.
//
// M1 Max
// - problemSize = 256 | 825 GFLOPS (32x32x8)
// - problemSize = 384 | 2726 GFLOPS (32x32x8)
// - problemSize = 512 | 4884 GFLOPS (32x32x8)
// - problemSize = 640 | 5860 GFLOPS (32x32x8)
// - problemSize = 768 | 6253 GFLOPS (32x32x8)
// - problemSize = 896 | 6823 GFLOPS (48x48x8)
// - problemSize = 1024 | 6613 GFLOPS (48x48x8)
// - problemSize = 1152 | 7715 GFLOPS (48x48x8)
// - problemSize = 1280 | 7458 GFLOPS (48x48x8)
// - problemSize = 1408 | 7287 GFLOPS (48x48x8)
// - problemSize = 1536 | 8263 GFLOPS (48x48x8)
//
// M4
// - problemSize = 256 | 1274 GFLOPS (32x32x8)
// - problemSize = 384 | 1885 GFLOPS (32x32x8)
// - problemSize = 512 | 3004 GFLOPS (32x32x8)
// - problemSize = 640 | 3252 GFLOPS (32x32x8)
// - problemSize = 768 | 3290 GFLOPS (32x32x8)
// - problemSize = 896 | 3314 GFLOPS (32x32x8)
// - problemSize = 1024 | 3338 GFLOPS (32x32x8)
// - problemSize = 1152 | 3387 GFLOPS (32x32x8)
// - problemSize = 1280 | 3414 GFLOPS (32x32x8)
// - problemSize = 1408 | 3438 GFLOPS (32x32x8)
// - problemSize = 1536 | 3452 GFLOPS (32x32x8)
//
// A and B are decoded from BFloat16 in RAM to Float32 in registers.
// C is encoded from Float32 in registers to BFloat16 in RAM.
//
// M1 Max
// - problemSize = 256 | 862 GFLOPS (32x32x8)
// - problemSize = 384 | 2807 GFLOPS (32x32x8)
// - problemSize = 512 | 4082 GFLOPS (32x32x8)
// - problemSize = 640 | 6434 GFLOPS (32x32x8)
// - problemSize = 768 | 6206 GFLOPS (32x32x8)
// - problemSize = 896 | 7056 GFLOPS (48x48x8)
// - problemSize = 1024 | 7487 GFLOPS (48x48x8)
// - problemSize = 1152 | 7769 GFLOPS (48x48x8)
// - problemSize = 1280 | 7686 GFLOPS (48x48x8)
// - problemSize = 1408 | 7760 GFLOPS (48x48x8)
// - problemSize = 1536 | 8684 GFLOPS (48x48x8)
//
// M4
// - problemSize = 256 | 1240 GFLOPS (32x32x8)
// - problemSize = 384 | 1878 GFLOPS (32x32x8)
// - problemSize = 512 | 2695 GFLOPS (32x32x8)
// - problemSize = 640 | 2910 GFLOPS (32x32x8)
// - problemSize = 768 | 2939 GFLOPS (32x32x8)
// - problemSize = 896 | 2939 GFLOPS (32x32x8)
// - problemSize = 1024 | 2984 GFLOPS (32x32x8)
// - problemSize = 1152 | 3006 GFLOPS (32x32x8)
// - problemSize = 1280 | 3022 GFLOPS (32x32x8)
// - problemSize = 1408 | 3047 GFLOPS (32x32x8)
// - problemSize = 1536 | 3060 GFLOPS (32x32x8)
//
// ========================================================================== //
// Results (Summary)
// - Does the optimal block size change when the precision decreases from
// Float32 to Float16?
// - How well does the MSL built-in type for BFloat16 perform, compared to
// Float16?
// ========================================================================== //
//
// For practical arithmetic intensities (N < 1500), the optimal block size
// is always 48x48 on M1 and 32x32 on M4. This rule holds regardless of the
// precision, or whether async copies are being used.
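//
// A sketch of how a dispatcher could act on that finding. Using
// MTLGPUFamily.apple9 to separate the M3/M4 generation from M1/M2 is an
// assumption for illustration; this benchmark only compared M1 Max and M4.
//
//     func blockSize(for device: MTLDevice) -> (mSimd: UInt16, nSimd: UInt16) {
//       // 32x32 was optimal on M4, 48x48 on M1 Max, for N < 1500.
//       device.supportsFamily(.apple9) ? (32, 32) : (48, 48)
//     }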
//
// The statistics were augmented by a few more combinations of input/output
// precision. For brevity, the previous section does not include the raw data
// for these combinations.
//
// Accumulate in FP16.
//
// M1 Max
// - A = FP16, B = FP16, C = FP16 | 9310 GFLOPS
// - A = BF16, B = BF16, C = FP16 | 7596 GFLOPS
// - A = FP32, B = FP32, C = FP16 | 7869 GFLOPS
//
// M4
// - A = FP16, B = FP16, C = FP16 | 3532 GFLOPS
// - A = BF16, B = BF16, C = FP16 | 3088 GFLOPS
// - A = FP32, B = FP32, C = FP16 | 2914 GFLOPS
//
// Accumulate in BF16.
//
// M1 Max
// - A = BF16, B = BF16, C = BF16 | 4258 GFLOPS (32x32x8)
// - A = BF16, B = BF16, C = eBF16 | 8263 GFLOPS
// - A = eBF16, B = eBF16, C = eBF16 | 8684 GFLOPS
//
// M4
// - A = BF16, B = BF16, C = BF16 | 3086 GFLOPS
// - A = BF16, B = BF16, C = eBF16 | 3452 GFLOPS
// - A = eBF16, B = eBF16, C = eBF16 | 3060 GFLOPS
//
// Accumulate in FP32.
//
// M1 Max
// - A = FP16, B = FP16, C = FP32 | 8952 GFLOPS
// - A = BF16, B = BF16, C = FP32 | 8247 GFLOPS
// - A = eBF16, B = eBF16, C = FP32 | 8675 GFLOPS
// - A = FP32, B = FP32, C = FP32 | 8392 GFLOPS
//
// M4
// - A = FP16, B = FP16, C = FP32 | 3477 GFLOPS
// - A = BF16, B = BF16, C = FP32 | 3477 GFLOPS
// - A = eBF16, B = eBF16, C = FP32 | 3080 GFLOPS
// - A = FP32, B = FP32, C = FP32 | 3174 GFLOPS
//
// ========================================================================== //
// Results (Raw Data)
// - How fast are the MPSMatrix and MPSGraph implementations of low-precision
// matrix multiplication?
// - Can MPS achieve double the performance by running two different
// matrix multiplications simultaneously, each with a different precision?
// ========================================================================== //
//
// For brevity, this section only shows kernels where all 3 operands have the
// same precision.
//
// MPSMatrix, Float32.
//
// M1 Max
// - problemSize = 256 | 484 GFLOPS
// - problemSize = 384 | 1587 GFLOPS
// - problemSize = 512 | 3475 GFLOPS
// - problemSize = 640 | 5830 GFLOPS
// - problemSize = 768 | 5936 GFLOPS
// - problemSize = 896 | 6738 GFLOPS
// - problemSize = 1024 | 8050 GFLOPS
// - problemSize = 1152 | 7345 GFLOPS
// - problemSize = 1280 | 7444 GFLOPS
// - problemSize = 1408 | 7758 GFLOPS
// - problemSize = 1536 | 8055 GFLOPS
//
// M4
// - problemSize = 256 | 460 GFLOPS
// - problemSize = 384 | 1036 GFLOPS
// - problemSize = 512 | 1569 GFLOPS
// - problemSize = 640 | 3028 GFLOPS
// - problemSize = 768 | 3087 GFLOPS
// - problemSize = 896 | 3086 GFLOPS
// - problemSize = 1024 | 3125 GFLOPS
// - problemSize = 1152 | 3152 GFLOPS
// - problemSize = 1280 | 3134 GFLOPS
// - problemSize = 1408 | 3150 GFLOPS
// - problemSize = 1536 | 3129 GFLOPS
//
// MPSMatrix, Float16.
//
// M1 Max
// - problemSize = 256 | 509 GFLOPS
// - problemSize = 384 | 2285 GFLOPS
// - problemSize = 512 | 3260 GFLOPS
// - problemSize = 640 | 5398 GFLOPS
// - problemSize = 768 | 5647 GFLOPS
// - problemSize = 896 | 6231 GFLOPS
// - problemSize = 1024 | 7364 GFLOPS
// - problemSize = 1152 | 6774 GFLOPS
// - problemSize = 1280 | 6983 GFLOPS
// - problemSize = 1408 | 7122 GFLOPS
// - problemSize = 1536 | 7487 GFLOPS
//
// M4
// - problemSize = 256 | 601 GFLOPS
// - problemSize = 384 | 1175 GFLOPS
// - problemSize = 512 | 2030 GFLOPS
// - problemSize = 640 | 3349 GFLOPS
// - problemSize = 768 | 3400 GFLOPS
// - problemSize = 896 | 3444 GFLOPS
// - problemSize = 1024 | 3479 GFLOPS
// - problemSize = 1152 | 3489 GFLOPS
// - problemSize = 1280 | 3504 GFLOPS
// - problemSize = 1408 | 3503 GFLOPS
// - problemSize = 1536 | 3496 GFLOPS
//
// MPSMatrix does not support BFloat16.
//
// MPSGraph, Float32.
//
// M1 Max
// - problemSize = 256 | 439 GFLOPS
// - problemSize = 384 | 878 GFLOPS
// - problemSize = 512 | 2727 GFLOPS
// - problemSize = 640 | 4549 GFLOPS
// - problemSize = 768 | 5521 GFLOPS
// - problemSize = 896 | 6316 GFLOPS
// - problemSize = 1024 | 7000 GFLOPS
// - problemSize = 1152 | 7320 GFLOPS
// - problemSize = 1280 | 7596 GFLOPS
// - problemSize = 1408 | 7597 GFLOPS
// - problemSize = 1536 | 8274 GFLOPS
//
// M4
// - problemSize = 256 | 590 GFLOPS
// - problemSize = 384 | 1105 GFLOPS
// - problemSize = 512 | 2051 GFLOPS
// - problemSize = 640 | 2798 GFLOPS
// - problemSize = 768 | 2937 GFLOPS
// - problemSize = 896 | 2986 GFLOPS
// - problemSize = 1024 | 3048 GFLOPS
// - problemSize = 1152 | 3100 GFLOPS
// - problemSize = 1280 | 3083 GFLOPS
// - problemSize = 1408 | 3108 GFLOPS
// - problemSize = 1536 | 3100 GFLOPS
//
// MPSGraph, Float16.
//
// M1 Max
// - problemSize = 256 | 399 GFLOPS
// - problemSize = 384 | 1254 GFLOPS
// - problemSize = 512 | 2951 GFLOPS
// - problemSize = 640 | 4630 GFLOPS
// - problemSize = 768 | 5858 GFLOPS
// - problemSize = 896 | 6606 GFLOPS
// - problemSize = 1024 | 7051 GFLOPS
// - problemSize = 1152 | 7542 GFLOPS
// - problemSize = 1280 | 7942 GFLOPS
// - problemSize = 1408 | 8099 GFLOPS
// - problemSize = 1536 | 8418 GFLOPS
//
// M4, Level 0 (MPSGraph optimization level 0)
// - problemSize = 256 | 710 GFLOPS
// - problemSize = 384 | 1462 GFLOPS
// - problemSize = 512 | 2337 GFLOPS
// - problemSize = 640 | 2578 GFLOPS
// - problemSize = 768 | 2708 GFLOPS
// - problemSize = 896 | 2736 GFLOPS
// - problemSize = 1024 | 2820 GFLOPS
// - problemSize = 1152 | 2897 GFLOPS
// - problemSize = 1280 | 2954 GFLOPS
// - problemSize = 1408 | 3006 GFLOPS
// - problemSize = 1536 | 3041 GFLOPS
//
// M4, Level 1 (MPSGraph optimization level 1)
// - problemSize = 256 | 735 GFLOPS
// - problemSize = 384 | 1488 GFLOPS
// - problemSize = 512 | 2332 GFLOPS
// - problemSize = 640 | 2586 GFLOPS
// - problemSize = 768 | 2753 GFLOPS
// - problemSize = 896 | 2734 GFLOPS
// - problemSize = 1024 | 2826 GFLOPS
// - problemSize = 1152 | 3843 GFLOPS
// - problemSize = 1280 | 4404 GFLOPS
// - problemSize = 1408 | 4786 GFLOPS
// - problemSize = 1536 | 5204 GFLOPS
// - problemSize = 2048 | 5916 GFLOPS
// - problemSize = 3072 | 5921 GFLOPS
// - problemSize = 3584 | 6573 GFLOPS
// - problemSize = 4096 | 6631 GFLOPS
// - problemSize = 4608 | 3883 GFLOPS
// - problemSize = 5120 | 3755 GFLOPS
//
// MPSGraph, BFloat16.
//
// M1 Max
// - problemSize = 256 | 325 GFLOPS
// - problemSize = 384 | 1132 GFLOPS
// - problemSize = 512 | 1987 GFLOPS
// - problemSize = 640 | 2352 GFLOPS
// - problemSize = 768 | 3311 GFLOPS
// - problemSize = 896 | 2996 GFLOPS
// - problemSize = 1024 | 3325 GFLOPS
// - problemSize = 1152 | 3691 GFLOPS
// - problemSize = 1280 | 3908 GFLOPS
// - problemSize = 1408 | 3937 GFLOPS
// - problemSize = 1536 | 4047 GFLOPS
//
// M4
// - problemSize = 256 | 774 GFLOPS
// - problemSize = 384 | 1419 GFLOPS
// - problemSize = 512 | 2302 GFLOPS
// - problemSize = 640 | 2588 GFLOPS
// - problemSize = 768 | 2726 GFLOPS
// - problemSize = 896 | 2755 GFLOPS
// - problemSize = 1024 | 2836 GFLOPS
// - problemSize = 1152 | 2903 GFLOPS
// - problemSize = 1280 | 2950 GFLOPS
// - problemSize = 1408 | 2999 GFLOPS
// - problemSize = 1536 | 3041 GFLOPS
//
// MPSGraph, simultaneous Float16 and Float32.
//
// M1 Max
// - problemSize = 256 | 630 GFLOPS
// - problemSize = 384 | 1971 GFLOPS
// - problemSize = 512 | 4434 GFLOPS
// - problemSize = 640 | 5283 GFLOPS
// - problemSize = 768 | 6228 GFLOPS
// - problemSize = 896 | 6859 GFLOPS
// - problemSize = 1024 | 7330 GFLOPS
// - problemSize = 1152 | 7611 GFLOPS
// - problemSize = 1280 | 7936 GFLOPS
// - problemSize = 1408 | 7890 GFLOPS
// - problemSize = 1536 | 8331 GFLOPS
//
// M4, Level 0
// - problemSize = 256 | 911 GFLOPS
// - problemSize = 384 | 1820 GFLOPS
// - problemSize = 512 | 2501 GFLOPS
// - problemSize = 640 | 2629 GFLOPS
// - problemSize = 768 | 2721 GFLOPS
// - problemSize = 896 | 2856 GFLOPS
// - problemSize = 1024 | 2940 GFLOPS
// - problemSize = 1152 | 2994 GFLOPS
// - problemSize = 1280 | 3031 GFLOPS
// - problemSize = 1408 | 3065 GFLOPS
// - problemSize = 1536 | 3074 GFLOPS
//
// M4, Level 1
// - problemSize = 256 | 952 GFLOPS
// - problemSize = 384 | 1935 GFLOPS
// - problemSize = 512 | 2515 GFLOPS
// - problemSize = 640 | 2636 GFLOPS
// - problemSize = 768 | 2749 GFLOPS
// - problemSize = 896 | 2859 GFLOPS
// - problemSize = 1024 | 2956 GFLOPS
// - problemSize = 1152 | 3461 GFLOPS
// - problemSize = 1280 | 3584 GFLOPS
// - problemSize = 1408 | 3759 GFLOPS
// - problemSize = 1536 | 3834 GFLOPS
// - problemSize = 2048 | 4100 GFLOPS
// - problemSize = 3072 | 4114 GFLOPS
// - problemSize = 3584 | 4238 GFLOPS
// - problemSize = 4096 | 4261 GFLOPS
// - problemSize = 4608 | 3374 GFLOPS
// - problemSize = 5120 | 3276 GFLOPS
//
// MPSGraph, simultaneous Float16 and BFloat16.
//
// M1 Max
// - problemSize = 256 | 536 GFLOPS
// - problemSize = 384 | 1755 GFLOPS
// - problemSize = 512 | 2812 GFLOPS
// - problemSize = 640 | 3322 GFLOPS
// - problemSize = 768 | 4420 GFLOPS
// - problemSize = 896 | 4215 GFLOPS
// - problemSize = 1024 | 4609 GFLOPS
// - problemSize = 1152 | 5028 GFLOPS
// - problemSize = 1280 | 5294 GFLOPS
// - problemSize = 1408 | 5338 GFLOPS
// - problemSize = 1536 | 5472 GFLOPS
//
// M4, Level 0
// - problemSize = 256 | 933 GFLOPS
// - problemSize = 384 | 1899 GFLOPS
// - problemSize = 512 | 2429 GFLOPS
// - problemSize = 640 | 2540 GFLOPS
// - problemSize = 768 | 2639 GFLOPS
// - problemSize = 896 | 2734 GFLOPS
// - problemSize = 1024 | 2834 GFLOPS
// - problemSize = 1152 | 2909 GFLOPS
// - problemSize = 1280 | 2963 GFLOPS
// - problemSize = 1408 | 3013 GFLOPS
// - problemSize = 1536 | 3053 GFLOPS
//
// M4, Level 1
// - problemSize = 256 | 949 GFLOPS
// - problemSize = 384 | 1976 GFLOPS
// - problemSize = 512 | 2453 GFLOPS
// - problemSize = 640 | 2552 GFLOPS
// - problemSize = 768 | 2629 GFLOPS
// - problemSize = 896 | 2736 GFLOPS
// - problemSize = 1024 | 2853 GFLOPS
// - problemSize = 1152 | 3288 GFLOPS
// - problemSize = 1280 | 3484 GFLOPS
// - problemSize = 1408 | 3658 GFLOPS
// - problemSize = 1536 | 3647 GFLOPS
// - problemSize = 2048 | 4126 GFLOPS
// - problemSize = 3072 | 4252 GFLOPS
// - problemSize = 3584 | 4429 GFLOPS
// - problemSize = 4096 | 4476 GFLOPS
// - problemSize = 4608 | 3624 GFLOPS
// - problemSize = 5120 | 3568 GFLOPS
//
// MPSGraph, simultaneous BFloat16 and Float32.
//
// M1 Max
// - problemSize = 256 | 547 GFLOPS
// - problemSize = 384 | 1764 GFLOPS
// - problemSize = 512 | 2966 GFLOPS
// - problemSize = 640 | 3346 GFLOPS
// - problemSize = 768 | 4448 GFLOPS
// - problemSize = 896 | 4204 GFLOPS
// - problemSize = 1024 | 4613 GFLOPS
// - problemSize = 1152 | 4999 GFLOPS
// - problemSize = 1280 | 5222 GFLOPS
// - problemSize = 1408 | 5223 GFLOPS
// - problemSize = 1536 | 5438 GFLOPS
//
// M4, Level 0
// - problemSize = 256 | 936 GFLOPS
// - problemSize = 384 | 1886 GFLOPS
// - problemSize = 512 | 2499 GFLOPS
// - problemSize = 640 | 2632 GFLOPS
// - problemSize = 768 | 2749 GFLOPS
// - problemSize = 896 | 2852 GFLOPS
// - problemSize = 1024 | 2942 GFLOPS
// - problemSize = 1152 | 2999 GFLOPS
// - problemSize = 1280 | 3018 GFLOPS
// - problemSize = 1408 | 3062 GFLOPS
// - problemSize = 1536 | 3076 GFLOPS
//
// M4, Level 1
// - problemSize = 256 | 1017 GFLOPS
// - problemSize = 384 | 1905 GFLOPS
// - problemSize = 512 | 2534 GFLOPS
// - problemSize = 640 | 2642 GFLOPS
// - problemSize = 768 | 2742 GFLOPS
// - problemSize = 896 | 2861 GFLOPS
// - problemSize = 1024 | 2955 GFLOPS
// - problemSize = 1152 | 3003 GFLOPS
// - problemSize = 1280 | 3021 GFLOPS
// - problemSize = 1408 | 3059 GFLOPS
// - problemSize = 1536 | 3075 GFLOPS
//
// ========================================================================== //
// Results (Raw Data)
// - Can a kernel be engineered to exploit the FP16 and FP32 units
// simultaneously?
// ========================================================================== //
//
// I tried running a kernel with two execution paths: one was a standard FP32
// GEMM, the other a standard FP16 GEMM. This did not provide any speedup on
// M4. I tried again with various mixtures of precisions, including BF16. Still
// no speedup. I am starting to doubt Apple's claim about 2x performance for
// workloads with both FP16 and FP32.
//
// As a final attempt, I wrote a kernel where half of the accumulator is FP16
// and the other half is FP32. All operands are FP16 in memory. The FP32 half
// of the accumulator decompresses some FP16 values in memory to FP32 before
// multiplication.
//
// M1 Max
// - problemSize = 256 | 816 GFLOPS
// - problemSize = 384 | 2798 GFLOPS
// - problemSize = 512 | 5289 GFLOPS
// - problemSize = 640 | 6374 GFLOPS
// - problemSize = 768 | 6201 GFLOPS
// - problemSize = 896 | 6494 GFLOPS
// - problemSize = 1024 | 7790 GFLOPS
// - problemSize = 1152 | 5764 GFLOPS
// - problemSize = 1280 | 6813 GFLOPS
// - problemSize = 1408 | 7409 GFLOPS
// - problemSize = 1536 | 7225 GFLOPS
//
// M4
// - problemSize = 256 | 1224 GFLOPS
// - problemSize = 384 | 1766 GFLOPS
// - problemSize = 512 | 2617 GFLOPS
// - problemSize = 640 | 3057 GFLOPS
// - problemSize = 768 | 3093 GFLOPS
// - problemSize = 896 | 3135 GFLOPS
// - problemSize = 1024 | 3141 GFLOPS
// - problemSize = 1152 | 3188 GFLOPS
// - problemSize = 1280 | 3203 GFLOPS
// - problemSize = 1408 | 3229 GFLOPS
// - problemSize = 1536 | 3243 GFLOPS
//
// This still does not achieve the 2x performance boost I had theorized. One
// very last stretch: store the LHS as both FP16 and FP32 in memory. When data
// needs to be accumulated in FP32, the kernel reads the values as FP32 with
// no decompression. That should ensure the compiler doesn't convert desired
// FP32 instructions into undesired FP16 instructions. This is the variant
// implemented by the 'hgemm' kernel embedded below, which binds the Float32
// copy of A at buffer index 3.
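//
// The script below realizes this by uploading the same matrix A twice, once
// compressed to Float16 and once kept as Float32, then binding both copies:
//
//     bufferDesc.data = A
//     bufferDesc.dataType = MTLDataType.half
//     let bufferA = createMTLBuffer(descriptor: bufferDesc)
//     bufferDesc.dataType = MTLDataType.float
//     let bufferA32 = createMTLBuffer(descriptor: bufferDesc)
//     ...
//     encoder.setBuffer(bufferA, offset: 0, index: 0)
//     encoder.setBuffer(bufferA32, offset: 0, index: 3)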
//
// M1 Max
// - problemSize = 256 | 863 GFLOPS
// - problemSize = 384 | 2812 GFLOPS
// - problemSize = 512 | 5146 GFLOPS
// - problemSize = 640 | 6263 GFLOPS
// - problemSize = 768 | 6879 GFLOPS
// - problemSize = 896 | 6488 GFLOPS
// - problemSize = 1024 | 7870 GFLOPS
// - problemSize = 1152 | 5959 GFLOPS
// - problemSize = 1280 | 6883 GFLOPS
// - problemSize = 1408 | 7155 GFLOPS
// - problemSize = 1536 | 7195 GFLOPS
//
// M4
// - problemSize = 256 | 792 GFLOPS
// - problemSize = 384 | 1319 GFLOPS
// - problemSize = 512 | 2440 GFLOPS
// - problemSize = 640 | 3019 GFLOPS
// - problemSize = 768 | 3054 GFLOPS
// - problemSize = 896 | 3060 GFLOPS
// - problemSize = 1024 | 3096 GFLOPS
// - problemSize = 1152 | 3148 GFLOPS
// - problemSize = 1280 | 3150 GFLOPS
// - problemSize = 1408 | 3184 GFLOPS
// - problemSize = 1536 | 3189 GFLOPS
//
// ========================================================================== //
// Discussion
// ========================================================================== //
//
// On both architectures, FP16 achieves better performance than FP32/BF16 no
// matter the calculation. FP32 can get very close, provided the operands are
// streamed from L1 as 16-bit types to minimize the bandwidth bottleneck (see
// 'load_half' in the embedded metal_simdgroup_matrix_storage header, which
// reads half2 from memory and widens it to the register precision).
// Lower-precision data types do not improve performance much by increasing
// occupancy; otherwise, we would see 16-bit types being optimal at larger
// block sizes than 32-bit types.
//
// MPS is also under-optimized for 16-bit data types on M3/M4. MPSGraph does
// not exceed 3000 GFLOPS on the M4 for either FP16 or BF16. It does not exceed
// 3100 GFLOPS for FP32. MPSMatrix does have a performance improvement for
// half, but it crashes on bfloat.
//
// GFLOPS | FP16 | BF16 | FP32 |
// ------------- | ----- | ----- | ----- |
// MPSGraph | 3000 | 3000 | 3100 |
// MPSMatrix | 3500 | error | 3100 |
// Custom Kernel | 3500 | 3500* | 3100 |
//
// *Only achieved when the accumulator is FP32. Otherwise, BF16 performance is
// 3000 GFLOPS.
//
// ===== The attempt to achieve 7000 GFLOPS on M4 =====
//
// This did not work. The M3/M4 architecture is basically the same as M1/M2, in
// terms of ALU throughput. The BF16 support is mediocre, only implemented as
// an optimization to BF16 -> FP32 decoding. Furthermore, BF16 only reaches
// maximum performance when the accumulator is FP32. I can get similar
// performance by just emulating BF16 (within a factor of 1.1x).
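//
// "Emulating BF16" (the eBF16 entries above) means storing only the upper 16
// bits of the Float32 bit pattern and doing all arithmetic in FP32 registers.
// The encode/decode used by the script below boils down to:
//
//     func encodeBF16(_ value: Float) -> UInt16 {
//       // Truncate: keep the sign, exponent, and top 7 mantissa bits.
//       UInt16(value.bitPattern >> 16)
//     }
//     func decodeBF16(_ bits: UInt16) -> Float {
//       Float(bitPattern: UInt32(bits) << 16)
//     }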
//
// I did achieve 6500 GFLOPS at MPSGraph optimization level 1, but only when
// all of the data was Float16. That comes from the well-known delegation of
// galactic FP16 matrix multiplies to the neural engine, not from the GPU
// itself reaching ~7000 GFLOPS.
//
// ===== MPS BFloat16 performance is dismal on M1 =====
//
// GFLOPS | FP16 | BF16 | FP32 |
// ------------- | ----- | ----- | ----- |
// MPSGraph | 8100 | 4000 | 8300 |
// MPSMatrix | 7500 | error | 8100 |
// Custom Kernel | 9300 | 8700* | 8400 |
//
// *Only achieved when the A/B inputs are decoded into Float32 registers
// before multiplying. This is technically BF15, as rounding error makes the
// final bit end up as garbage.
func runApplication() {
print("Hello, console.")
profileProblemSize(64)
#if false
profileProblemSize(256)
profileProblemSize(384)
profileProblemSize(512)
profileProblemSize(640)
profileProblemSize(768)
profileProblemSize(896)
profileProblemSize(1024)
profileProblemSize(1152)
profileProblemSize(1280)
profileProblemSize(1408)
profileProblemSize(1536)
#endif
#if false
profileProblemSize(2048)
profileProblemSize(3072)
profileProblemSize(3072 + 512)
profileProblemSize(4096)
profileProblemSize(4096 + 512)
profileProblemSize(5120)
#endif
}
// MARK: - Shader Sources
let metalSimdgroupEvent: String = """
// -*- Metal -*-
//===-- metal_simdgroup_event ---------------------------------------------===//
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
//===----------------------------------------------------------------------===//
#ifndef __METAL_SIMDGROUP_EVENT
#define __METAL_SIMDGROUP_EVENT
#if !defined(__HAVE_SIMDGROUP_FUTURE__)
// Invoking the generation of LLVM bitcode for async copies.
//
// %struct._simdgroup_event_t = type opaque
//
struct _simdgroup_event_t;
// Invoking the generation of LLVM bitcode for async copies.
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_1d(
ulong, ulong, threadgroup void *, const device void *, ulong)
__asm("air.simdgroup_async_copy_1d.p3i8.p1i8");
// Invoking the generation of LLVM bitcode for async copies.
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_1d(
ulong, ulong, device void *, const threadgroup void *, ulong)
__asm("air.simdgroup_async_copy_1d.p1i8.p3i8");
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: argmemonly convergent nounwind
// declare %struct._simdgroup_event_t*
// @air.simdgroup_async_copy_2d.p3i8.p1i8(
// i64, i64, i8 addrspace(3)* nocapture writeonly,
// i64, i64, <2 x i64>, i8 addrspace(1)* nocapture readonly,
// i64, i64, <2 x i64>, <2 x i64>, i32)
// local_unnamed_addr #4
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_2d(
ulong, ulong, threadgroup void *,
ulong, ulong, ulong2, const device void *,
ulong, ulong, ulong2, long2, int)
__asm("air.simdgroup_async_copy_2d.p3i8.p1i8");
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: argmemonly convergent nounwind
// declare %struct._simdgroup_event_t*
// @air.simdgroup_async_copy_2d.p1i8.p3i8(
// i64, i64, i8 addrspace(1)* nocapture writeonly,
// i64, i64, <2 x i64>, i8 addrspace(3)* nocapture readonly,
// i64, i64, <2 x i64>, <2 x i64>, i32)
// local_unnamed_addr #4
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_2d(
ulong, ulong, device void *,
ulong, ulong, ulong2, const threadgroup void *,
ulong, ulong, ulong2, long2, int)
__asm("air.simdgroup_async_copy_2d.p1i8.p3i8");
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: convergent nounwind
// declare void
// @air.wait_simdgroup_events(i32, %struct._simdgroup_event_t** nocapture)
// local_unnamed_addr #3
//
void __metal_wait_simdgroup_events(
int, thread _simdgroup_event_t**)
__asm("air.wait_simdgroup_events");
#endif
#pragma METAL internals : enable
namespace metal
{
#if !defined(__HAVE_SIMDGROUP_FUTURE__)
enum class simdgroup_async_copy_clamp_mode {
clamp_to_zero = 0,
clamp_to_edge = 1
};
#endif
struct simdgroup_event {
METAL_FUNC simdgroup_event() thread {}
template <typename T>
METAL_FUNC void async_copy(threadgroup T *dst, const device T *src, ulong n_elements) thread {
event = __metal_simdgroup_async_copy_1d(sizeof(T), alignof(T), reinterpret_cast<threadgroup void *>(dst), reinterpret_cast<const device void *>(src), n_elements);
}
template <typename T>
METAL_FUNC void async_copy(device T *dst, const threadgroup T *src, ulong n_elements) thread {
event = __metal_simdgroup_async_copy_1d(sizeof(T), alignof(T), reinterpret_cast<device void *>(dst), reinterpret_cast<const threadgroup void *>(src), n_elements);
}
template <typename T>
METAL_FUNC void async_copy(threadgroup T *dst, ushort dst_elements_per_row, ushort2 dst_tile_dimensions, const device T *src, uint src_elements_per_row, ushort2 src_tile_dimensions, bool transpose_matrix = false, simdgroup_async_copy_clamp_mode clamp_mode = simdgroup_async_copy_clamp_mode::clamp_to_zero) thread {
if (transpose_matrix) {
src_tile_dimensions = src_tile_dimensions.yx;
dst_tile_dimensions = dst_tile_dimensions.yx;
}
event = __metal_simdgroup_async_copy_2d(sizeof(T), alignof(T), reinterpret_cast<threadgroup void *>(dst), ushort(dst_elements_per_row), 1, ulong2(dst_tile_dimensions), reinterpret_cast<const device void *>(src), uint(src_elements_per_row), 1, ulong2(src_tile_dimensions), long2(0), static_cast<int>(clamp_mode));
}
template <typename T>
METAL_FUNC void async_copy(device T *dst, uint dst_elements_per_row, ushort2 dst_tile_dimensions, const threadgroup T *src, ushort src_elements_per_row, ushort2 src_tile_dimensions, bool transpose_matrix = false) thread {
if (transpose_matrix) {
src_tile_dimensions = src_tile_dimensions.yx;
dst_tile_dimensions = dst_tile_dimensions.yx;
}
event = __metal_simdgroup_async_copy_2d(sizeof(T), alignof(T), reinterpret_cast<device void *>(dst), uint(dst_elements_per_row), 1, ulong2(dst_tile_dimensions), reinterpret_cast<const threadgroup void *>(src), ushort(src_elements_per_row), 1, ulong2(src_tile_dimensions), long2(0), 0);
}
METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
#if defined(__HAVE_SIMDGROUP_FUTURE__)
__metal_wait_simdgroup_events(count, reinterpret_cast<thread __metal_simdgroup_event_t*>(events));
#else
__metal_wait_simdgroup_events(count, reinterpret_cast<thread _simdgroup_event_t**>(events));
#endif
}
private:
// Invoking the generation of LLVM bitcode for async copies.
//
// %"struct.metal::simdgroup_event" = type { %struct._simdgroup_event_t* }
//
#if defined(__HAVE_SIMDGROUP_FUTURE__)
__metal_simdgroup_event_t event;
#else
thread _simdgroup_event_t* event;
#endif
};
} // namespace metal
#pragma METAL internals : disable
#endif // __METAL_SIMDGROUP_EVENT
"""
let metalSimdgroupMatrixStorage: String = """
// -*- Metal -*-
//===-- metal_simdgroup_matrix_storage ------------------------------------===//
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
//===----------------------------------------------------------------------===//
#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE
#define __METAL_SIMDGROUP_MATRIX_STORAGE
// Contains C++ symbols accessible to a developer through automatic code
// completion in Xcode 14.2. Formatted with the same style as the Metal Standard
// Library for consistency with other Metal code.
#if defined(__HAVE_SIMDGROUP_MATRIX__)
#pragma METAL internals : enable
namespace metal
{
template <typename T>
struct simdgroup_matrix_storage {
typedef vec<T, 64> storage_type;
storage_type t;
METAL_FUNC thread vec<T, 2>* thread_elements() thread {
return reinterpret_cast<thread vec<T, 2>*>(&t);
}
METAL_FUNC simdgroup_matrix_storage() thread = default;
METAL_FUNC simdgroup_matrix_storage(vec<T, 2> thread_elements) thread {
*(this->thread_elements()) = thread_elements;
}
METAL_FUNC static ushort2 offset(ushort thread_index_in_simdgroup) {
// https://patents.google.com/patent/US11256518B2
ushort lane_id = thread_index_in_simdgroup;
ushort quad_id = lane_id / 4;
constexpr ushort QUADRANT_SPAN_M = 4;
constexpr ushort THREADS_PER_QUADRANT = 8;
ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M;
ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2);
ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant;
ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4
ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2
ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant;
return ushort2(N_in_simd, M_in_simd);
}
METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
} else {
return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
}
}
METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
return src + matrix_origin.x * elements_per_row + matrix_origin.y;
} else {
return src + matrix_origin.y * elements_per_row + matrix_origin.x;
}
}
// WARNING: All load and store functions assume the X dimension is divisible by 2.
METAL_FUNC void load(const device T *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
*(thread_elements()) = vec<T, 2>(src[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y], src[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y]);
} else {
*(thread_elements()) = *reinterpret_cast<const device vec<T, 2>*>(src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x);
}
}
METAL_FUNC void load(const threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
*(thread_elements()) = vec<T, 2>(src[matrix_origin.x * elements_per_row + matrix_origin.y], src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y]);
} else {
*(thread_elements()) = *reinterpret_cast<const threadgroup vec<T, 2>*>(src + matrix_origin.y * elements_per_row + matrix_origin.x);
}
}
METAL_FUNC void load_first(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
thread_elements()[0][0] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
} else {
thread_elements()[0][0] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
}
}
METAL_FUNC void load_second(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
thread_elements()[0][1] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
} else {
thread_elements()[0][1] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
}
}
METAL_FUNC void load_bfloat(const device bfloat *src, uint elements_per_row, ushort2 matrix_origin) {
auto src_ptr = src + (matrix_origin.y * elements_per_row + matrix_origin.x);
reinterpret_cast<thread ushort4*>(thread_elements())->yw =
*reinterpret_cast<const device ushort2*>(src_ptr);
}
METAL_FUNC void load_half(const device half *src, uint elements_per_row, ushort2 matrix_origin) {
auto src_ptr = src + (matrix_origin.y * elements_per_row + matrix_origin.x);
*thread_elements() = vec<T, 2>(
*reinterpret_cast<const device half2*>(src_ptr));
}
METAL_FUNC void store(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
dst[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
} else {
*reinterpret_cast<device vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
}
}
METAL_FUNC void store_first(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
} else {
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][0];
}
}
METAL_FUNC void store_second(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
} else {
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][1];
}
}
METAL_FUNC void store_bfloat(device bfloat *dst, uint elements_per_row, ushort2 matrix_origin) {
auto dst_ptr = dst + (matrix_origin.y * elements_per_row + matrix_origin.x);
*reinterpret_cast<device ushort2*>(dst_ptr) =
reinterpret_cast<thread ushort4*>(thread_elements())->yw;
}
METAL_FUNC void store_half(device half *dst, uint elements_per_row, ushort2 matrix_origin) {
auto dst_ptr = dst + (matrix_origin.y * elements_per_row + matrix_origin.x);
*reinterpret_cast<device half2*>(dst_ptr) = half2(*thread_elements());
}
METAL_FUNC void store(threadgroup T *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[matrix_origin.x * elements_per_row + matrix_origin.y] = thread_elements()[0][0];
dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = thread_elements()[0][1];
} else {
*reinterpret_cast<threadgroup vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
}
}
template <typename U, typename V>
METAL_FUNC void multiply(simdgroup_matrix_storage<U> a, simdgroup_matrix_storage<V> b, bool accumulate = true) {
if (!accumulate) {
*(thread_elements()) = vec<T, 2>(0);
}
t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
}
};
} // namespace metal
#pragma METAL internals : disable
#endif
#endif // __METAL_SIMDGROUP_MATRIX_STORAGE
"""
let GEMM = """
//
// GEMM.metal
// MetalFlashAttention
//
// Created by Philip Turner on 6/23/23.
//
#include <metal_stdlib>
\(metalSimdgroupMatrixStorage)
using namespace metal;
// MARK: - Function Constants
// Dimensions of each matrix.
constant uint M [[function_constant(0)]];
constant uint N [[function_constant(1)]];
constant uint K [[function_constant(2)]];
// Whether each matrix is transposed.
constant bool A_trans [[function_constant(10)]];
constant bool B_trans [[function_constant(11)]];
constant uint A_leading_dim = A_trans ? M : K;
constant uint B_leading_dim = B_trans ? K : N;
constant ushort M_simd [[function_constant(200)]];
constant ushort N_simd [[function_constant(201)]];
// Elide work on the edge when matrix dimension < SRAM block dimension.
constant ushort M_modulo = (M % M_simd == 0) ? M_simd : (M % M_simd);
constant ushort N_modulo = (N % N_simd == 0) ? N_simd : (N % N_simd);
constant ushort M_padded = (M < M_simd) ? (M_modulo + 7) / 8 * 8 : M_simd;
constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd;
// MARK: - Utilities
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* A_sram(
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
// A_sram[M_simd][8]
return sram + (matrix_origin.y / 8) * (8 / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* B_sram(
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
// B_sram[8][N_simd]
return sram + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* C_sram(
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
// C_sram[M_simd][N_simd]
return sram + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC void store_accumulator(thread simdgroup_matrix_storage<T> *sram,
device half *C, bool m_is_edge, bool n_is_edge)
{
const ushort m_start = (m_is_edge) ? M_modulo : 0;
const ushort n_start = (n_is_edge) ? N_modulo : 0;
const ushort m_end = (m_is_edge) ? M_simd : M_modulo;
const ushort n_end = (n_is_edge) ? N_simd : N_modulo;
#pragma clang loop unroll(full)
for (ushort m = m_start; m < m_end / 2; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = n_start; n < n_end; n += 8) {
ushort2 origin(n, m);
C_sram(sram, origin)->store_half(C, N, origin);
}
}
}
template <typename T>
METAL_FUNC void store_accumulator_2(thread simdgroup_matrix_storage<T> *sram,
device half *C, bool m_is_edge, bool n_is_edge)
{
const ushort m_start = (m_is_edge) ? M_modulo : 0;
const ushort n_start = (n_is_edge) ? N_modulo : 0;
const ushort m_end = (m_is_edge) ? M_simd : M_modulo;
const ushort n_end = (n_is_edge) ? N_simd : N_modulo;
#pragma clang loop unroll(full)
for (ushort m = m_end / 2; m < m_end; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = n_start; n < n_end; n += 8) {
ushort2 origin(n, m);
C_sram(sram, origin)->store(C, N, origin);
}
}
}
// MARK: - Kernels
kernel void hgemm(device half *A [[buffer(0)]],
device half *B [[buffer(1)]],
device half *C [[buffer(2)]],
device float *A32 [[buffer(3)]],
uint3 gid [[threadgroup_position_in_grid]],
ushort lane_id [[thread_index_in_simdgroup]])
{
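// Split-accumulator scheme: the first half of the M block accumulates in FP32
// (loading A from the Float32 copy at buffer 3), the second half accumulates
// in FP16 (loading A from the Float16 copy at buffer 0). B is loaded once as
// half2 and widened to FP32 registers.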
simdgroup_matrix_storage<float> A_sram_allocation[1024];
simdgroup_matrix_storage<float> B_sram_allocation[1024];
simdgroup_matrix_storage<float> C_sram_allocation[1024];
simdgroup_matrix_storage<half> A_sram_allocation_2[1024];
simdgroup_matrix_storage<half> C_sram_allocation_2[1024];
ushort2 offset_in_simd = simdgroup_matrix_storage<half>::offset(lane_id);
ushort2 A_offset(0, gid.y * M_simd);
ushort2 B_offset(gid.x * N_simd, 0);
A_offset += offset_in_simd;
B_offset += offset_in_simd;
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded / 2; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto C = C_sram(C_sram_allocation, ushort2(n, m));
*C = simdgroup_matrix_storage<float>(0);
}
}
#pragma clang loop unroll(full)
for (ushort m = M_padded / 2; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto C = C_sram(C_sram_allocation_2, ushort2(n, m));
*C = simdgroup_matrix_storage<half>(0);
}
}
for (uint k = 0; k < K; k += 8) {
auto A32_src = simdgroup_matrix_storage<float>::apply_offset(
A32, A_leading_dim, uint2(A_offset), A_trans);
auto A_src = simdgroup_matrix_storage<half>::apply_offset(
A, A_leading_dim, uint2(A_offset), A_trans);
auto B_src = simdgroup_matrix_storage<half>::apply_offset(
B, B_leading_dim, uint2(B_offset), B_trans);
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded / 2; m += 8) {
ushort2 origin(0, m);
auto A = A_sram(A_sram_allocation, origin);
A->load(A32_src, A_leading_dim, origin);
}
#pragma clang loop unroll(full)
for (ushort m = M_padded / 2; m < M_padded; m += 8) {
ushort2 origin(0, m);
auto A = A_sram(A_sram_allocation_2, origin);
A->load(A_src, A_leading_dim, origin);
}
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, 0);
auto B = B_sram(B_sram_allocation, origin);
B->load_half(B_src, B_leading_dim, origin);
}
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded / 2; m += 8) {
auto A = A_sram(A_sram_allocation, ushort2(0, m));
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto B = B_sram(B_sram_allocation, ushort2(n, 0));
auto C = C_sram(C_sram_allocation, ushort2(n, m));
C->multiply(*A, *B);
}
}
#pragma clang loop unroll(full)
for (ushort m = M_padded / 2; m < M_padded; m += 8) {
auto A = A_sram(A_sram_allocation_2, ushort2(0, m));
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto B = B_sram(B_sram_allocation, ushort2(n, 0));
auto C = C_sram(C_sram_allocation_2, ushort2(n, m));
C->multiply(*A, *B);
}
}
A_offset.x += 8;
B_offset.y += 8;
}
// WARNING: M and N must be divisible by 8.
{
uint2 matrix_origin = uint2(B_offset.x, A_offset.y);
auto C_src = simdgroup_matrix_storage<half>::apply_offset(
C, N, matrix_origin);
store_accumulator(C_sram_allocation, C_src, false, false);
store_accumulator_2(C_sram_allocation_2, C_src, false, false);
const uint M_edge_floor = M - M % M_simd;
const uint N_edge_floor = N - N % N_simd;
if (matrix_origin.y < M_edge_floor) {
store_accumulator(C_sram_allocation, C_src, true, false);
store_accumulator_2(C_sram_allocation_2, C_src, true, false);
}
if (matrix_origin.x < N_edge_floor) {
store_accumulator(C_sram_allocation, C_src, false, true);
store_accumulator_2(C_sram_allocation_2, C_src, false, true);
if (matrix_origin.y < M_edge_floor) {
store_accumulator(C_sram_allocation, C_src, true, true);
store_accumulator_2(C_sram_allocation_2, C_src, true, true);
}
}
}
}
"""
// MARK: - Script
func profileProblemSize(_ problemSize: Int) {
var A = [Float](repeating: .zero, count: problemSize * problemSize)
var B = [Float](repeating: .zero, count: problemSize * problemSize)
var C = [Float](repeating: .zero, count: problemSize * problemSize)
// Initialize A as the 2nd-order periodic Laplacian.
for diagonalID in 0..<problemSize {
let diagonalAddress = diagonalID * problemSize + diagonalID
A[diagonalAddress] = -2
let leftColumnID = (diagonalID + problemSize - 1) % problemSize
let leftSubDiagonalAddress = diagonalID * problemSize + leftColumnID
A[leftSubDiagonalAddress] = 1
let rightColumnID = (diagonalID + problemSize + 1) % problemSize
let rightSubDiagonalAddress = diagonalID * problemSize + rightColumnID
A[rightSubDiagonalAddress] = 1
}
// Initialize B to random numbers.
for rowID in 0..<problemSize {
for columnID in 0..<problemSize {
let address = rowID * problemSize + columnID
let entry = Float.random(in: 0..<1)
B[address] = entry
}
}
do {
// Initialize the context.
let context = MTLContext()
let library = try! context.device.makeLibrary(source: GEMM, options: nil)
// Set the function constants.
let constants = MTLFunctionConstantValues()
var M: Int = problemSize
var N: Int = problemSize
var K: Int = problemSize
var transpose: Bool = false
constants.setConstantValue(&M, type: .uint, index: 0)
constants.setConstantValue(&N, type: .uint, index: 1)
constants.setConstantValue(&K, type: .uint, index: 2)
constants.setConstantValue(&transpose, type: .bool, index: 10)
constants.setConstantValue(&transpose, type: .bool, index: 11)
var M_simd: UInt16 = 32
var N_simd: UInt16 = 32
constants.setConstantValue(&M_simd, type: .ushort, index: 200)
constants.setConstantValue(&N_simd, type: .ushort, index: 201)
let function = try! library.makeFunction(
name: "hgemm", constantValues: constants)
let pipeline = try! context.device
.makeComputePipelineState(function: function)
func ceilDivide(target: Int, granularity: UInt16) -> Int {
(target + Int(granularity) - 1) / Int(granularity)
}
let gridSize = MTLSize(
width: ceilDivide(target: N, granularity: N_simd),
height: ceilDivide(target: M, granularity: M_simd),
depth: 1)
let groupSize = MTLSize(
width: 32,
height: 1,
depth: 1)
// Create the buffers.
var bufferDesc = MTLBufferDescriptor()
bufferDesc.context = context
bufferDesc.problemSize = problemSize
bufferDesc.data = A
bufferDesc.dataType = MTLDataType.half
let bufferA = createMTLBuffer(descriptor: bufferDesc)
bufferDesc.data = B
bufferDesc.dataType = MTLDataType.half
let bufferB = createMTLBuffer(descriptor: bufferDesc)
bufferDesc.data = C
bufferDesc.dataType = MTLDataType.half
let bufferC = createMTLBuffer(descriptor: bufferDesc)
// Create the second buffer for A.
bufferDesc.data = A
bufferDesc.dataType = MTLDataType.float
let bufferA32 = createMTLBuffer(descriptor: bufferDesc)
// Profile the latency of matrix multiplication.
var bestStatistics: SIMD2<Int> = .init(.max, .zero)
for _ in 0..<15 {
let duplicatedCommandCount: Int = 20
// Execute the operation.
let commandBuffer = context.commandQueue.makeCommandBuffer()!
let encoder = commandBuffer.makeComputeCommandEncoder()!
encoder.setComputePipelineState(pipeline)
encoder.setBuffer(bufferA, offset: 0, index: 0)
encoder.setBuffer(bufferB, offset: 0, index: 1)
encoder.setBuffer(bufferC, offset: 0, index: 2)
encoder.setBuffer(bufferA32, offset: 0, index: 3)
for _ in 0..<duplicatedCommandCount {
encoder.dispatchThreadgroups(
gridSize, threadsPerThreadgroup: groupSize)
}
encoder.endEncoding()
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
// Determine the time taken.
let start = commandBuffer.gpuStartTime
let end = commandBuffer.gpuEndTime
let latency = end - start
let latencyMicroseconds = Int(latency / 1e-6)
// Determine the amount of work done.
var operations = 2 * problemSize * problemSize * problemSize
operations = operations * duplicatedCommandCount
let gflops = Int(Double(operations) / Double(latency) / 1e9)
// Report the results.
bestStatistics[0] = min(bestStatistics[0], latencyMicroseconds)
bestStatistics[1] = max(bestStatistics[1], gflops)
}
// Report the results.
print("//", terminator: " ")
print("- problemSize =", problemSize, terminator: " ")
print("|", bestStatistics[1], "GFLOPS", terminator: " ")
print()
// Copy the results to C.
do {
let rawPointer = bufferC.contents()
let castedPointer = rawPointer.assumingMemoryBound(to: Float16.self)
for rowID in 0..<problemSize {
for columnID in 0..<problemSize {
let address = rowID * problemSize + columnID
let entry = castedPointer[address]
// Truncate the output's precision to 16-bit, while keeping the
// memory format of Float32.
// let bitPattern = UInt32(entry) << 16
// C[address] = Float(bitPattern: bitPattern)
// Read an accumulator stored as Float16 or Float32.
C[address] = Float(entry)
}
}
}
}
// Check the results.
for m in 0..<problemSize {
for n in 0..<problemSize {
// Find the source row IDs.
let leftRowID = (m + problemSize - 1) % problemSize
let centerRowID = m
let rightRowID = (m + problemSize + 1) % problemSize
// Find the source values.
let leftSource = B[leftRowID * problemSize + n]
let centerSource = B[centerRowID * problemSize + n]
let rightSource = B[rightRowID * problemSize + n]
// Find the expected and actual values.
let expected = leftSource - 2 * centerSource + rightSource
let actual = C[m * problemSize + n]
// Report the results.
let error = (expected - actual).magnitude
if error > 5e-2 {
print("error: \(error) / ~1.000")
}
}
}
}
// MARK: - Utilities
struct MTLContext {
var device: MTLDevice
var commandQueue: MTLCommandQueue
init() {
device = MTLCreateSystemDefaultDevice()!
commandQueue = device.makeCommandQueue()!
}
}
struct MTLBufferDescriptor {
var context: MTLContext?
var data: [Float]?
var problemSize: Int?
var dataType: MTLDataType?
}
func createMTLBuffer(
descriptor: MTLBufferDescriptor
) -> MTLBuffer {
guard let context = descriptor.context,
let data = descriptor.data,
let problemSize = descriptor.problemSize,
let dataType = descriptor.dataType else {
fatalError("Descriptor was invalid.")
}
// Check the size of the input array.
guard data.count == problemSize * problemSize else {
fatalError("Input was not a square matrix.")
}
// Allocate enough memory to store everything in Float32.
let buffer = context.device.makeBuffer(length: data.count * 4)!
// Copy the data into the buffer.
switch dataType {
case .half:
// Compress the precision of the data.
var compressedData = [Float16](repeating: .zero, count: data.count)
for elementID in data.indices {
let entry = data[elementID]
compressedData[elementID] = Float16(entry)
}
// Copy the data.
let pointer = buffer.contents().assumingMemoryBound(to: Float16.self)
pointer.initialize(from: compressedData, count: compressedData.count)
case .bfloat:
// Compress the precision of the data.
var compressedData = [UInt16](repeating: .zero, count: data.count)
for elementID in data.indices {
let entry = data[elementID]
// Truncate the input's precision to 16-bit, while keeping the memory
// format of Float32.
let bitPattern = entry.bitPattern
compressedData[elementID] = UInt16(bitPattern >> 16)
}
// Copy the data.
let pointer = buffer.contents().assumingMemoryBound(to: UInt16.self)
pointer.initialize(from: compressedData, count: compressedData.count)
case .float:
// Copy the data.
let pointer = buffer.contents().assumingMemoryBound(to: Float.self)
pointer.initialize(from: data, count: data.count)
default:
fatalError("Unrecognized data type.")
}
// Return the buffer.
return buffer
}
#if false
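// NOTE: This disabled reference block assumes `import MetalPerformanceShaders`
// and `import MetalPerformanceShadersGraph`, plus an `MPSMatrixStorageDescriptor`
// type and a `createMTLBuffer(context:data:problemSize:dataType:)` overload,
// none of which are defined in this file.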
struct MPSMatrixStorage {
var buffer: MTLBuffer
var matrix: MPSMatrix
init(descriptor: MPSMatrixStorageDescriptor) {
guard let context = descriptor.context,
let data = descriptor.data,
let problemSize = descriptor.problemSize,
let dataType = descriptor.dataType else {
fatalError("Descriptor was invalid.")
}
// Create the buffer
buffer = createMTLBuffer(
context: context,
data: data,
problemSize: problemSize,
dataType: dataType)
// Set the descriptor properties.
let elementStride = (dataType == .float32) ? 4 : 2
let matrixDesc = MPSMatrixDescriptor(
rows: problemSize,
columns: problemSize,
rowBytes: problemSize * elementStride,
dataType: dataType)
// Create the matrix object.
matrix = MPSMatrix(buffer: buffer, descriptor: matrixDesc)
}
}
struct MPSTensorStorageDescriptor {
var context: MTLContext?
var data: [Float]?
var problemSize: Int?
var dataType: MPSDataType?
}
struct MPSTensorStorage {
var buffer: MTLBuffer
var tensorData: MPSGraphTensorData
init(descriptor: MPSTensorStorageDescriptor) {
guard let context = descriptor.context,
let data = descriptor.data,
let problemSize = descriptor.problemSize,
let dataType = descriptor.dataType else {
fatalError("Descriptor was invalid.")
}
// Create the buffer
buffer = createMTLBuffer(
context: context,
data: data,
problemSize: problemSize,
dataType: dataType)
// Create the tensor data.
let elementStride = (dataType == .float32) ? 4 : 2
tensorData = MPSGraphTensorData(
buffer,
shape: [NSNumber(value: problemSize), NSNumber(value: problemSize)],
dataType: dataType,
rowBytes: problemSize * elementStride)
}
}
#endif
// MARK: - Reference Code
#if false
func testCustomShaderPerformance() {
// Initialize the context.
let context = MTLContext()
let library = try! context.device.makeLibrary(source: GEMM, options: nil)
// Set the function constants.
let constants = MTLFunctionConstantValues()
var M: Int = problemSize
var N: Int = problemSize
var K: Int = problemSize
var transpose: Bool = false
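  // These indices must match the function constants declared in the GEMM
  // source above: 0-2 carry the matrix dimensions, 10-11 the transpose flags
  // for the two operands, and 200-201 the per-simdgroup block dimensions.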
constants.setConstantValue(&M, type: .uint, index: 0)
constants.setConstantValue(&N, type: .uint, index: 1)
constants.setConstantValue(&K, type: .uint, index: 2)
constants.setConstantValue(&transpose, type: .bool, index: 10)
constants.setConstantValue(&transpose, type: .bool, index: 11)
var M_simd: UInt16 = 32
var N_simd: UInt16 = 32
constants.setConstantValue(&M_simd, type: .ushort, index: 200)
constants.setConstantValue(&N_simd, type: .ushort, index: 201)
let function = try! library.makeFunction(
name: "hgemm", constantValues: constants)
let pipeline = try! context.device
.makeComputePipelineState(function: function)
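  // Dispatch one simdgroup (32 threads) per threadgroup, with one threadgroup
  // per M_simd x N_simd block of the output matrix.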
func ceilDivide(target: Int, granularity: UInt16) -> Int {
(target + Int(granularity) - 1) / Int(granularity)
}
let gridSize = MTLSize(
width: ceilDivide(target: N, granularity: N_simd),
height: ceilDivide(target: M, granularity: M_simd),
depth: 1)
let groupSize = MTLSize(
width: 32,
height: 1,
depth: 1)
// Create the buffers.
func createBuffer(data: [Float], lowPrecision: Bool) -> MTLBuffer {
// Check the size of the input array.
guard data.count == problemSize * problemSize else {
fatalError("Input was not a square matrix.")
}
// Allocate enough memory to store everything in Float32.
let buffer = context.device.makeBuffer(length: data.count * 4)!
// Branch on whether the data is low-precision.
if lowPrecision {
// Compress the precision of the data.
var compressedData = [UInt16](repeating: .zero, count: data.count)
for elementID in data.indices {
let entry = data[elementID]
// Truncate the input's precision to 16-bit, while keeping the memory
// format of Float32.
// let bitPattern = entry.bitPattern
// compressedData[elementID] = UInt16(bitPattern >> 16)
// Write an input stored as Float16.
let bitPattern = Float16(entry).bitPattern
compressedData[elementID] = bitPattern
}
// Copy the data.
let pointer = buffer.contents().assumingMemoryBound(to: UInt16.self)
pointer.initialize(from: compressedData, count: compressedData.count)
} else {
// Copy the data.
let pointer = buffer.contents().assumingMemoryBound(to: Float.self)
pointer.initialize(from: data, count: data.count)
}
// Return the buffer object.
return buffer
}
let bufferA = createBuffer(data: A, lowPrecision: false)
let bufferB = createBuffer(data: B, lowPrecision: false)
let bufferC = createBuffer(data: C, lowPrecision: false)
// Profile the latency of matrix multiplication.
var bestStatistics: SIMD2<Int> = .init(.max, .zero)
for _ in 0..<15 {
let duplicatedCommandCount: Int = 20
// Execute the operation.
let commandBuffer = context.commandQueue.makeCommandBuffer()!
let encoder = commandBuffer.makeComputeCommandEncoder()!
encoder.setComputePipelineState(pipeline)
encoder.setBuffer(bufferA, offset: 0, index: 0)
encoder.setBuffer(bufferB, offset: 0, index: 1)
encoder.setBuffer(bufferC, offset: 0, index: 2)
for _ in 0..<duplicatedCommandCount {
encoder.dispatchThreadgroups(
gridSize, threadsPerThreadgroup: groupSize)
}
encoder.endEncoding()
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
// Determine the time taken.
let start = commandBuffer.gpuStartTime
let end = commandBuffer.gpuEndTime
let latency = end - start
let latencyMicroseconds = Int(latency / 1e-6)
// Determine the amount of work done.
var operations = 2 * problemSize * problemSize * problemSize
operations = operations * duplicatedCommandCount
let gflops = Int(Double(operations) / Double(latency) / 1e9)
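    // For example, with problemSize = 1024 and 20 duplicated dispatches,
    // operations is about 4.3e10; a hypothetical total latency of 5 ms would
    // report roughly 8590 GFLOPS.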
// Report the results.
bestStatistics[0] = min(bestStatistics[0], latencyMicroseconds)
bestStatistics[1] = max(bestStatistics[1], gflops)
}
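  // Report the best-case latency (microseconds) and throughput (GFLOPS).
  print("custom shader:", bestStatistics[0], "us,", bestStatistics[1], "GFLOPS")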
}
#endif
#if false
func testMPSMatrixPerformance() {
// Initialize the context.
let context = MTLContext()
// Initialize the matrices.
var matrixStorageDesc = MPSMatrixStorageDescriptor()
matrixStorageDesc.context = context
matrixStorageDesc.problemSize = problemSize
matrixStorageDesc.dataType = .float32
matrixStorageDesc.data = A
let matrixA = MPSMatrixStorage(descriptor: matrixStorageDesc)
matrixStorageDesc.data = B
let matrixB = MPSMatrixStorage(descriptor: matrixStorageDesc)
matrixStorageDesc.data = C
let matrixC = MPSMatrixStorage(descriptor: matrixStorageDesc)
// Initialize the multiplication object.
let multiplication = MPSMatrixMultiplication(
device: context.device,
resultRows: problemSize,
resultColumns: problemSize,
interiorColumns: problemSize)
multiplication.leftMatrixOrigin = MTLOrigin(x: 0, y: 0, z: 0)
multiplication.rightMatrixOrigin = MTLOrigin(x: 0, y: 0, z: 0)
multiplication.resultMatrixOrigin = MTLOrigin(x: 0, y: 0, z: 0)
// Profile the latency of the matrix multiplication.
var bestStatistics: SIMD2<Int> = .init(.max, .zero)
for _ in 0..<15 {
let duplicatedCommandCount: Int = 20
// Execute the operation.
let commandBuffer = context.commandQueue.makeCommandBuffer()!
for _ in 0..<duplicatedCommandCount {
multiplication.encode(
commandBuffer: commandBuffer,
leftMatrix: matrixA.matrix,
rightMatrix: matrixB.matrix,
resultMatrix: matrixC.matrix)
}
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
// Determine the time taken.
let start = commandBuffer.gpuStartTime
let end = commandBuffer.gpuEndTime
let latency = end - start
let latencyMicroseconds = Int(latency / 1e-6)
// Determine the amount of work done.
var operations = 2 * problemSize * problemSize * problemSize
operations = operations * duplicatedCommandCount
let gflops = Int(Double(operations) / Double(latency) / 1e9)
// Report the results.
bestStatistics[0] = min(bestStatistics[0], latencyMicroseconds)
bestStatistics[1] = max(bestStatistics[1], gflops)
}
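  // Report the best-case latency (microseconds) and throughput (GFLOPS).
  print("MPSMatrix:", bestStatistics[0], "us,", bestStatistics[1], "GFLOPS")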
}
#endif
#if false
func testMPSGraphPerformance() {
// Initialize the context.
let context = MTLContext()
// Initialize the tensors.
var tensorStorageDesc = MPSTensorStorageDescriptor()
tensorStorageDesc.context = context
tensorStorageDesc.problemSize = problemSize
tensorStorageDesc.dataType = .float32
tensorStorageDesc.data = A
let tensorStorageA = MPSTensorStorage(descriptor: tensorStorageDesc)
let tensorStorageA2 = MPSTensorStorage(descriptor: tensorStorageDesc)
tensorStorageDesc.data = B
let tensorStorageB = MPSTensorStorage(descriptor: tensorStorageDesc)
let tensorStorageB2 = MPSTensorStorage(descriptor: tensorStorageDesc)
tensorStorageDesc.data = C
tensorStorageDesc.dataType = .bFloat16
let tensorStorageC = MPSTensorStorage(descriptor: tensorStorageDesc)
tensorStorageDesc.dataType = .float16
let tensorStorageC2 = MPSTensorStorage(descriptor: tensorStorageDesc)
// Create the graph executable.
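  // The graph below contains two independent GEMMs: one casts the Float32
  // operands to BF16, the other to FP16, so a single executable launches both
  // precisions within the same command buffer.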
var executable: MPSGraphExecutable
do {
let graph = MPSGraph()
// Program the matrix multiplication through the DSL.
let shape = [NSNumber(value: problemSize), NSNumber(value: problemSize)]
let tensorA32 = graph.placeholder(shape: shape, name: "Operand A (FP32)")
let tensorB32 = graph.placeholder(shape: shape, name: "Operand B (FP32)")
let tensorA16 = graph.cast(
tensorA32, to: .bFloat16, name: "Operand A (BF16)")
let tensorB16 = graph.cast(
tensorB32, to: .bFloat16, name: "Operand B (BF16)")
let tensorC = graph.matrixMultiplication(
primary: tensorA16, secondary: tensorB16, name: "Output")
// Program the second matrix multiplication through the DSL.
let tensorA232 = graph.placeholder(shape: shape, name: "Operand A (2nd)")
let tensorB232 = graph.placeholder(shape: shape, name: "Operand B (2nd)")
let tensorA216 = graph.cast(
tensorA232, to: .float16, name: "Operand A (2nd, FP16)")
let tensorB216 = graph.cast(
tensorB232, to: .float16, name: "Operand B (2nd, FP16)")
let tensorC2 = graph.matrixMultiplication(
primary: tensorA216, secondary: tensorB216, name: "Output (2nd)")
// Define the data type of each input to the graph.
let shapedType = MPSGraphShapedType(shape: shape, dataType: .float32)
let feeds: [MPSGraphTensor : MPSGraphShapedType] = [
tensorA32: shapedType,
tensorB32: shapedType,
tensorA232: shapedType,
tensorB232: shapedType,
]
let targetTensors: [MPSGraphTensor] = [
tensorC,
tensorC2,
]
// Set the compilation descriptor.
let compilationDesc = MPSGraphCompilationDescriptor()
compilationDesc.optimizationLevel = .level1
compilationDesc.waitForCompilationCompletion = true
// Create the graph executable.
let mpsGraphDevice = MPSGraphDevice(
mtlDevice: context.device)
executable = graph.compile(
with: mpsGraphDevice,
feeds: feeds,
targetTensors: targetTensors,
targetOperations: nil,
compilationDescriptor: compilationDesc)
// Specialize the executable.
executable.specialize(
with: mpsGraphDevice,
inputTypes: [
shapedType,
shapedType,
shapedType,
shapedType,
],
compilationDescriptor: compilationDesc)
}
// Profile the latency of matrix multiplication.
var bestStatistics: SIMD2<Int> = .init(.max, .zero)
let trialCount = (problemSize >= 3072) ? 5 : 15
let duplicatedCommandCount = (problemSize >= 3072) ? 5 : 20
for _ in 0..<trialCount {
// Create the execution descriptor.
let executionDesc = MPSGraphExecutableExecutionDescriptor()
executionDesc.waitUntilCompleted = false
// Execute the operation.
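    // CPU wall-clock time is measured here (instead of the GPU timestamps
    // used in the other tests), since MPSGraph may distribute work across
    // internally managed command buffers.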
let start = CACurrentMediaTime()
let commandBuffer = MPSCommandBuffer(from: context.commandQueue)
for _ in 0..<duplicatedCommandCount {
executable.encode(
to: commandBuffer,
inputs: [
tensorStorageA.tensorData,
tensorStorageB.tensorData,
tensorStorageA2.tensorData,
tensorStorageB2.tensorData,
],
results: [
tensorStorageC.tensorData,
tensorStorageC2.tensorData,
],
executionDescriptor: executionDesc)
}
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
let end = CACurrentMediaTime()
// Determine the time taken.
let latency = end - start
let latencyMicroseconds = Int(latency / 1e-6)
// Determine the amount of work done.
var operations = 2 * problemSize * problemSize * problemSize
    operations *= 2 // two GEMMs per encode: one BF16, one FP16
operations = operations * duplicatedCommandCount
let gflops = Int(Double(operations) / Double(latency) / 1e9)
// Report the results.
bestStatistics[0] = min(bestStatistics[0], latencyMicroseconds)
bestStatistics[1] = max(bestStatistics[1], gflops)
}
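  // Report the best-case latency (microseconds) and throughput (GFLOPS).
  print("MPSGraph:", bestStatistics[0], "us,", bestStatistics[1], "GFLOPS")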
}
#endif