Hacking AIR to open up the Apple GPU to general-purpose linear algebra

//
// main.swift
// HackingAIR
//
// Created by Philip Turner on 5/27/24.
//

import Foundation
import Metal

// Hacking AIR to open up the Apple GPU to general-purpose linear algebra.
//
// ========================================================================== //
// Methods
// ========================================================================== //
//
// The shader code 'metalSimdgroupEvent' makes all five async copy instructions
// accessible to the Metal JIT compiler. I cannot stress how important this is
// for M1 and M2 users.
// - Asynchronously copy from device -> threadgroup, 1D.
// - Asynchronously copy from threadgroup -> device, 1D.
// - Asynchronously copy from device -> threadgroup, 2D.
// - Asynchronously copy from threadgroup -> device, 2D.
// - Wait on a simdgroup event.
//
// I have not tested the correctness of the exposed 1D async copy functions,
// only the 2D async copies, which are required for matrix multiplication. A
// commented usage sketch follows.
//
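// As a usage sketch (using the 'simdgroup_event' wrapper defined in
// 'metalSimdgroupEvent' below; the 32x32 tile size and the 'src'/'N' names
// are illustrative), a kernel copies a float tile from device memory into
// threadgroup memory like this:
//
//   threadgroup float dst[32 * 32];
//   simdgroup_event event;
//   event.async_copy(
//     dst, /*dst_elements_per_row=*/32, ushort2(32, 32),
//     src, /*src_elements_per_row=*/N, ushort2(32, 32));
//   simdgroup_event::wait(1, &event);
//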
// Insights about invoking the generation of LLVM bitcode:
// - To add 'nounwind' to a function, prepend the declaration with:
//   __attribute__((__nothrow__))
// - To add 'nocapture' to an argument, add the following after the argument
//   type and name:
//   __attribute__((__noescape__))
// - I have not yet figured out the 'argmemonly' attribute, but my intuition
//   tells me there's a hidden way to activate it with plain C.
// - AIR seems to accept the functions without these annotations being added.
//   Therefore, I omitted them from the final product to make it more legible.
//   An annotated declaration is shown below for reference.
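//
// For illustration, the 1D device -> threadgroup declaration with both
// annotations applied would read (hypothetical; this form is not compiled
// anywhere in this file):
//
//   __attribute__((__nothrow__))
//   thread _simdgroup_event_t*
//   __metal_simdgroup_async_copy_1d(
//     ulong, ulong,
//     threadgroup void * __attribute__((__noescape__)),
//     const device void * __attribute__((__noescape__)),
//     ulong)
//   __asm("air.simdgroup_async_copy_1d.p3i8.p1i8");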
//
// ========================================================================== //
// Results (Correctness)
// ========================================================================== //
//
// Checked against the analytical result of multiplying by a second-order
// periodic Laplacian. Agrees with the analytical result to within Float32
// rounding error.
//
// ========================================================================== //
// Results (Performance)
// ========================================================================== //
//
// M1 Max, macOS Sonoma, Metal JIT Compiler
// - problemSize = 64   |   45 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 96   |  112 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 128  |  217 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 192  |  503 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 256  |  886 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 384  | 3448 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 512  | 5314 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 640  | 6533 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 768  | 7030 GFLOPS (48x48x24, 128 threads/group)
// - problemSize = 896  | 7154 GFLOPS (48x48x24, 128 threads/group)
// - problemSize = 1024 | 7001 GFLOPS (48x48x24, 128 threads/group)
// - problemSize = 1152 | 8145 GFLOPS (48x48x24, 128 threads/group)
// - problemSize = 1280 | 7810 GFLOPS (48x48x24, 128 threads/group)
//
// M4, iOS 17, Metal JIT Compiler
// - problemSize = 64   |   14 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 96   |   80 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 128  |  162 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 192  |  369 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 256  |  620 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 384  | 1256 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 512  | 1746 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 640  | 1960 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 768  | 1955 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 896  | 2040 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 1024 | 2076 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 1152 | 2118 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 1280 | 2130 GFLOPS (32x32x8, 32 threads/group)
//
// Compare to reference data for a pre-compiled shader with async copy.
// https://gist.github.com/philipturner/3bda14e876a635e73745c42f2eb240c8
//
// M1 Max
// - problemSize = 64   |   45 vs   44 GFLOPS
// - problemSize = 96   |  112 vs  109 GFLOPS
// - problemSize = 128  |  217 vs  214 GFLOPS
// - problemSize = 192  |  503 vs  495 GFLOPS
// - problemSize = 256  |  886 vs  883 GFLOPS
// - problemSize = 384  | 3448 vs 2912 GFLOPS
// - problemSize = 512  | 5314 vs 4091 GFLOPS
// - problemSize = 640  | 6533 vs 6440 GFLOPS
// - problemSize = 768  | 7030 vs 7017 GFLOPS
// - problemSize = 896  | 7154 vs 7136 GFLOPS
// - problemSize = 1024 | 7001 vs 6966 GFLOPS
// - problemSize = 1152 | 8145 vs 8144 GFLOPS
// - problemSize = 1280 | 7810 vs 7813 GFLOPS
//
// M4
// - problemSize = 64   |   14 vs   39 GFLOPS
// - problemSize = 96   |   80 vs   32 GFLOPS
// - problemSize = 128  |  162 vs   94 GFLOPS
// - problemSize = 192  |  369 vs  364 GFLOPS
// - problemSize = 256  |  620 vs  654 GFLOPS
// - problemSize = 384  | 1256 vs 1270 GFLOPS
// - problemSize = 512  | 1746 vs 1626 GFLOPS
// - problemSize = 640  | 1960 vs 1947 GFLOPS
// - problemSize = 768  | 1955 vs 1955 GFLOPS
// - problemSize = 896  | 2040 vs 2034 GFLOPS
// - problemSize = 1024 | 2076 vs 2078 GFLOPS
// - problemSize = 1152 | 2118 vs 2119 GFLOPS
// - problemSize = 1280 | 2130 vs 2129 GFLOPS
//
// For M1 Max, performance is identical. This means we are triggering the same
// AIR codegen as Xcode 14.2.
//
// For M4, the first few numbers appear to match MFA without async copy.
// However, the remaining numbers match MFA with async copy. The numbers for
// large matrices are more important, as they reveal the maximum performance
// attainable with a specific design choice. ALU utilization is ~61%, dismal
// compared to the ~91% achieved without async copy. This result means async
// copy instructions are correctly generated on iOS. We will need them for
// the older A14, A15, and A16 chips.
//
// ========================================================================== //
// Discussion
// ========================================================================== //
//
// I now have hope that I can access hardware-accelerated matrix multiplication
// from OpenCL. Therefore, I could target both Apple GPUs and the AMD 7900 XTX
// from the same shader source. It will require a bit of extra time to adapt
// this bitcode injection to the cl2Metal compiler. I'm deferring that
// investigation to a later date.
//
// By itself, this workaround does not solve the problem of people needing to
// rely on MPS for matrix multiplications. There needs to be an accompanying
// reference implementation of a near-optimal matrix multiplication kernel for
// all hardware (M1/M2 and M3/M4). In addition, this code only covers Float32.
// It is the hardest data type to run at maximum ALU% because it requires so
// much L1 bandwidth. If this data type can run at maximum performance,
// extending the code to Float16 should not be difficult.
//
// | |
// For BFloat16, there should be graphs and statistics about the inferior | |
// performance of MSL's built-in type / MLX's decoding algorithm on M1. It | |
// should provide an alternative algorithm for BF16 -> FP32 decoding, and | |
// strongly encourage users to activate it on M1. In addition, users should be | |
// conscious about how BF16 numbers are stored in registers on M1. BF16 numbers | |
// will most likely be stored in FP32 registers to avoid redundantly decoding | |
// every time they are accessed. | |
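//
// A minimal sketch of such a decoding, assuming the alternative simply widens
// the raw bit pattern (a BF16 value is the upper 16 bits of the corresponding
// FP32 value); this helper is illustrative and not part of the kernels below:
//
//   METAL_FUNC float bfloat_to_float(ushort raw) {
//     return as_type<float>(uint(raw) << 16);
//   }
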
// MARK: - Shader Sources | |
let metalSimdgroupEvent: String = """
// -*- Metal -*-
//===-- metal_simdgroup_event ---------------------------------------------===//
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
//===----------------------------------------------------------------------===//

#ifndef __METAL_SIMDGROUP_EVENT
#define __METAL_SIMDGROUP_EVENT

#if !defined(__HAVE_SIMDGROUP_FUTURE__)
// Invoking the generation of LLVM bitcode for async copies.
//
// %struct._simdgroup_event_t = type opaque
//
struct _simdgroup_event_t;

// Invoking the generation of LLVM bitcode for async copies.
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_1d(
  ulong, ulong, threadgroup void *, const device void *, ulong)
  __asm("air.simdgroup_async_copy_1d.p3i8.p1i8");

// Invoking the generation of LLVM bitcode for async copies.
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_1d(
  ulong, ulong, device void *, const threadgroup void *, ulong)
  __asm("air.simdgroup_async_copy_1d.p1i8.p3i8");

// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: argmemonly convergent nounwind
// declare %struct._simdgroup_event_t*
// @air.simdgroup_async_copy_2d.p3i8.p1i8(
//   i64, i64, i8 addrspace(3)* nocapture writeonly,
//   i64, i64, <2 x i64>, i8 addrspace(1)* nocapture readonly,
//   i64, i64, <2 x i64>, <2 x i64>, i32)
//   local_unnamed_addr #4
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_2d(
  ulong, ulong, threadgroup void *,
  ulong, ulong, ulong2, const device void *,
  ulong, ulong, ulong2, long2, int)
  __asm("air.simdgroup_async_copy_2d.p3i8.p1i8");

// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: argmemonly convergent nounwind
// declare %struct._simdgroup_event_t*
// @air.simdgroup_async_copy_2d.p1i8.p3i8(
//   i64, i64, i8 addrspace(1)* nocapture writeonly,
//   i64, i64, <2 x i64>, i8 addrspace(3)* nocapture readonly,
//   i64, i64, <2 x i64>, <2 x i64>, i32)
//   local_unnamed_addr #4
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_2d(
  ulong, ulong, device void *,
  ulong, ulong, ulong2, const threadgroup void *,
  ulong, ulong, ulong2, long2, int)
  __asm("air.simdgroup_async_copy_2d.p1i8.p3i8");

// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: convergent nounwind
// declare void
// @air.wait_simdgroup_events(i32, %struct._simdgroup_event_t** nocapture)
//   local_unnamed_addr #3
//
void __metal_wait_simdgroup_events(
  int, thread _simdgroup_event_t**)
  __asm("air.wait_simdgroup_events");
#endif

#pragma METAL internals : enable
namespace metal
{
#if !defined(__HAVE_SIMDGROUP_FUTURE__)
  enum class simdgroup_async_copy_clamp_mode {
    clamp_to_zero = 0,
    clamp_to_edge = 1
  };
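  // Note: the clamp mode presumably controls how the destination tile is
  // filled where it extends past the source tile: 'clamp_to_zero' writes
  // zeros, while 'clamp_to_edge' repeats the nearest edge element. This is
  // inferred from the names; only 'clamp_to_zero' is exercised in this file.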
#endif

  struct simdgroup_event {
    METAL_FUNC simdgroup_event() thread {}

    template <typename T>
    METAL_FUNC void async_copy(
      threadgroup T *dst, const device T *src, ulong n_elements
    ) thread {
      event = __metal_simdgroup_async_copy_1d(
        sizeof(T), alignof(T),
        reinterpret_cast<threadgroup void *>(dst),
        reinterpret_cast<const device void *>(src), n_elements);
    }

    template <typename T>
    METAL_FUNC void async_copy(
      device T *dst, const threadgroup T *src, ulong n_elements
    ) thread {
      event = __metal_simdgroup_async_copy_1d(
        sizeof(T), alignof(T),
        reinterpret_cast<device void *>(dst),
        reinterpret_cast<const threadgroup void *>(src), n_elements);
    }

    template <typename T>
    METAL_FUNC void async_copy(
      threadgroup T *dst,
      ushort dst_elements_per_row, ushort2 dst_tile_dimensions,
      const device T *src,
      uint src_elements_per_row, ushort2 src_tile_dimensions,
      bool transpose_matrix = false,
      simdgroup_async_copy_clamp_mode clamp_mode =
        simdgroup_async_copy_clamp_mode::clamp_to_zero
    ) thread {
      if (transpose_matrix) {
        src_tile_dimensions = src_tile_dimensions.yx;
        dst_tile_dimensions = dst_tile_dimensions.yx;
      }
      event = __metal_simdgroup_async_copy_2d(
        sizeof(T), alignof(T),
        reinterpret_cast<threadgroup void *>(dst),
        ushort(dst_elements_per_row), 1, ulong2(dst_tile_dimensions),
        reinterpret_cast<const device void *>(src),
        uint(src_elements_per_row), 1, ulong2(src_tile_dimensions),
        long2(0), static_cast<int>(clamp_mode));
    }

    template <typename T>
    METAL_FUNC void async_copy(
      device T *dst,
      uint dst_elements_per_row, ushort2 dst_tile_dimensions,
      const threadgroup T *src,
      ushort src_elements_per_row, ushort2 src_tile_dimensions,
      bool transpose_matrix = false
    ) thread {
      if (transpose_matrix) {
        src_tile_dimensions = src_tile_dimensions.yx;
        dst_tile_dimensions = dst_tile_dimensions.yx;
      }
      event = __metal_simdgroup_async_copy_2d(
        sizeof(T), alignof(T),
        reinterpret_cast<device void *>(dst),
        uint(dst_elements_per_row), 1, ulong2(dst_tile_dimensions),
        reinterpret_cast<const threadgroup void *>(src),
        ushort(src_elements_per_row), 1, ulong2(src_tile_dimensions),
        long2(0), 0);
    }

    METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
#if defined(__HAVE_SIMDGROUP_FUTURE__)
      __metal_wait_simdgroup_events(
        count, reinterpret_cast<thread __metal_simdgroup_event_t*>(events));
#else
      __metal_wait_simdgroup_events(
        count, reinterpret_cast<thread _simdgroup_event_t**>(events));
#endif
    }

  private:
    // Invoking the generation of LLVM bitcode for async copies.
    //
    // %"struct.metal::simdgroup_event" = type { %struct._simdgroup_event_t* }
    //
#if defined(__HAVE_SIMDGROUP_FUTURE__)
    __metal_simdgroup_event_t event;
#else
    thread _simdgroup_event_t* event;
#endif
  };
} // namespace metal
#pragma METAL internals : disable

#endif // __METAL_SIMDGROUP_EVENT
"""
let metalSimdgroupMatrixStorage: String = """
// -*- Metal -*-
//===-- metal_simdgroup_matrix_storage ------------------------------------===//
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
//===----------------------------------------------------------------------===//

#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE
#define __METAL_SIMDGROUP_MATRIX_STORAGE

// Contains C++ symbols accessible to a developer through automatic code
// completion in Xcode 14.2. Formatted with the same style as the Metal
// Standard Library for consistency with other Metal code.
#if defined(__HAVE_SIMDGROUP_MATRIX__)
#pragma METAL internals : enable
namespace metal
{
  template <typename T>
  struct simdgroup_matrix_storage {
    typedef vec<T, 64> storage_type;

    storage_type t;

    METAL_FUNC thread vec<T, 2>* thread_elements() thread {
      return reinterpret_cast<thread vec<T, 2>*>(&t);
    }

    METAL_FUNC simdgroup_matrix_storage() thread = default;

    METAL_FUNC simdgroup_matrix_storage(vec<T, 2> thread_elements) thread {
      *(this->thread_elements()) = thread_elements;
    }

    METAL_FUNC static ushort2 offset(ushort thread_index_in_simdgroup) {
      // https://patents.google.com/patent/US11256518B2
      ushort lane_id = thread_index_in_simdgroup;
      ushort quad_id = lane_id / 4;

      constexpr ushort QUADRANT_SPAN_M = 4;
      constexpr ushort THREADS_PER_QUADRANT = 8;
      ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M;
      ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2);
      ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant;

      ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4
      ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2
      ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant;

      return ushort2(N_in_simd, M_in_simd);
    }
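    // For example: lane 0 -> (0, 0), lane 1 -> (2, 0), lane 2 -> (0, 1),
    // lane 8 -> (4, 0), lane 32 -> (0, 8). Each lane owns the 1x2 vector of
    // elements starting at its (N, M) origin within the 8x8 tile.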
    METAL_FUNC static device T* apply_offset(
      device T *src, uint elements_per_row, uint2 matrix_origin,
      bool transpose_matrix = false
    ) {
      if (transpose_matrix) {
        return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
      } else {
        return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
      }
    }

    METAL_FUNC static threadgroup T* apply_offset(
      threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin,
      bool transpose_matrix = false
    ) {
      if (transpose_matrix) {
        return src + matrix_origin.x * elements_per_row + matrix_origin.y;
      } else {
        return src + matrix_origin.y * elements_per_row + matrix_origin.x;
      }
    }

    // WARNING: All load and store functions assume the X dimension is
    // divisible by 2.
    METAL_FUNC void load(
      const device T *src, uint elements_per_row, ushort2 matrix_origin,
      bool transpose_matrix = false
    ) {
      if (transpose_matrix) {
        *(thread_elements()) = vec<T, 2>(
          src[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y],
          src[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y]);
      } else {
        *(thread_elements()) = *reinterpret_cast<const device vec<T, 2>*>(
          src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x);
      }
    }

    METAL_FUNC void load(
      const threadgroup T *src, ushort elements_per_row,
      ushort2 matrix_origin, bool transpose_matrix = false
    ) {
      if (transpose_matrix) {
        *(thread_elements()) = vec<T, 2>(
          src[matrix_origin.x * elements_per_row + matrix_origin.y],
          src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y]);
      } else {
        *(thread_elements()) = *reinterpret_cast<const threadgroup vec<T, 2>*>(
          src + matrix_origin.y * elements_per_row + matrix_origin.x);
      }
    }

    METAL_FUNC void load_first(
      const device T *src, ushort elements_per_row, ushort2 matrix_origin,
      bool transpose_matrix = false
    ) {
      if (transpose_matrix) {
        thread_elements()[0][0] =
          src[matrix_origin.x * elements_per_row + matrix_origin.y];
      } else {
        thread_elements()[0][0] =
          src[matrix_origin.y * elements_per_row + matrix_origin.x];
      }
    }

    METAL_FUNC void load_second(
      const device T *src, ushort elements_per_row, ushort2 matrix_origin,
      bool transpose_matrix = false
    ) {
      if (transpose_matrix) {
        thread_elements()[0][1] =
          src[matrix_origin.x * elements_per_row + matrix_origin.y];
      } else {
        thread_elements()[0][1] =
          src[matrix_origin.y * elements_per_row + matrix_origin.x];
      }
    }

    METAL_FUNC void store(
      device T *dst, uint elements_per_row, ushort2 matrix_origin,
      bool transpose_matrix = false
    ) {
      if (transpose_matrix) {
        dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] =
          thread_elements()[0][0];
        dst[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y] =
          thread_elements()[0][1];
      } else {
        *reinterpret_cast<device vec<T, 2>*>(
          dst + matrix_origin.y * elements_per_row + matrix_origin.x) =
            *(thread_elements());
      }
    }

    METAL_FUNC void store_first(
      device T *dst, uint elements_per_row, ushort2 matrix_origin,
      bool transpose_matrix = false
    ) {
      if (transpose_matrix) {
        dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] =
          thread_elements()[0][0];
      } else {
        dst[matrix_origin.y * elements_per_row + matrix_origin.x] =
          thread_elements()[0][0];
      }
    }

    METAL_FUNC void store_second(
      device T *dst, uint elements_per_row, ushort2 matrix_origin,
      bool transpose_matrix = false
    ) {
      if (transpose_matrix) {
        dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] =
          thread_elements()[0][1];
      } else {
        dst[matrix_origin.y * elements_per_row + matrix_origin.x] =
          thread_elements()[0][1];
      }
    }

    METAL_FUNC void store(
      threadgroup T *dst, ushort elements_per_row, ushort2 matrix_origin,
      bool transpose_matrix = false
    ) {
      if (transpose_matrix) {
        dst[matrix_origin.x * elements_per_row + matrix_origin.y] =
          thread_elements()[0][0];
        dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] =
          thread_elements()[0][1];
      } else {
        *reinterpret_cast<threadgroup vec<T, 2>*>(
          dst + matrix_origin.y * elements_per_row + matrix_origin.x) =
            *(thread_elements());
      }
    }

    template <typename U, typename V>
    METAL_FUNC void multiply(
      simdgroup_matrix_storage<U> a, simdgroup_matrix_storage<V> b,
      bool accumulate = true
    ) {
      if (!accumulate) {
        *(thread_elements()) = vec<T, 2>(0);
      }
      t = __metal_simdgroup_matrix_8x8_multiply_accumulate(
        a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
    }
  };
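
  // Usage sketch (illustrative, not invoked in this file): multiply two 8x8
  // float tiles, with each lane loading its own two elements at the offset
  // computed by offset() above. 'A_src'/'B_src' and the leading dimensions
  // are assumed to point at valid 8x8 tiles.
  //
  //   simdgroup_matrix_storage<float> A, B, C;
  //   ushort2 origin = simdgroup_matrix_storage<float>::offset(lane_id);
  //   A.load(A_src, A_leading_dim, origin);
  //   B.load(B_src, B_leading_dim, origin);
  //   C.multiply(A, B, /*accumulate=*/false);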
} // namespace metal
#pragma METAL internals : disable
#endif

#endif // __METAL_SIMDGROUP_MATRIX_STORAGE
"""
let GEMM = """
//
// GEMM.metal
// MetalFlashAttention
//
// Created by Philip Turner on 6/23/23.
//

#include <metal_stdlib>
\(metalSimdgroupEvent)
\(metalSimdgroupMatrixStorage)
using namespace metal;

// MARK: - Function Constants

// Dimensions of each matrix.
constant uint M [[function_constant(0)]];
constant uint N [[function_constant(1)]];
constant uint K [[function_constant(2)]];

// Whether each matrix is transposed.
constant bool A_trans [[function_constant(10)]];
constant bool B_trans [[function_constant(11)]];
constant bool D_trans [[function_constant(13)]];
constant uint A_leading_dim = A_trans ? M : K;
constant uint B_leading_dim = B_trans ? K : N;

// Alpha and beta constants from BLAS.
constant float alpha [[function_constant(20)]];
constant float beta [[function_constant(21)]];

constant ushort M_simd [[function_constant(200)]];
constant ushort N_simd [[function_constant(201)]];
constant ushort K_simd [[function_constant(202)]];

// Elide work on the edge when matrix dimension < SRAM block dimension.
constant ushort M_modulo = (M % M_simd == 0) ? M_simd : (M % M_simd);
constant ushort N_modulo = (N % N_simd == 0) ? N_simd : (N % N_simd);
constant ushort M_padded = (M < M_simd) ? (M_modulo + 7) / 8 * 8 : M_simd;
constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd;
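// For example, with M = 20 and M_simd = 32: M_modulo = 20 and M_padded = 24,
// so the per-simd loops below iterate over 24 rows instead of 32.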

constant ushort M_splits [[function_constant(210)]];
constant ushort N_splits [[function_constant(211)]];
constant ushort M_group = M_simd * M_splits;
constant ushort N_group = N_simd * N_splits;
constant ushort A_block_leading_dim = (A_trans ? M_group : K_simd);
constant ushort B_block_leading_dim = (B_trans ? K_simd : N_group);

// There is no padding for M reads/writes.
// There is no padding for N reads/writes.
constant ushort K_simd_unpadded = (K % K_simd == 0) ? K_simd : (K % K_simd);
constant ushort K_simd_padded = (K_simd_unpadded + 7) / 8 * 8;
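// For example, with K = 70 and K_simd = 32: K_simd_unpadded = 6 and
// K_simd_padded = 8, so the final threadgroup block processes one 8-wide
// slice along K instead of four.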

constant ushort A_sram_length = (M_simd / 8) * 1;
constant ushort B_sram_length = 1 * (N_simd / 8);
constant ushort A_block_length = M_group * K_simd;

// Threadgroup block must fit entire C accumulator and partial sums.
constant ushort A_sram_offset = 0;
constant ushort B_sram_offset = A_sram_offset + A_sram_length;
constant ushort C_sram_offset = B_sram_offset + B_sram_length;
constant ushort A_block_offset = 0;
constant ushort B_block_offset = A_block_offset + A_block_length;

// MARK: - Utilities

template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* A_sram(
  thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
  // A_sram[M_simd][8]
  return sram + A_sram_offset +
    (matrix_origin.y / 8) * (8 / 8) + (matrix_origin.x / 8);
}

template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* B_sram(
  thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
  // B_sram[8][N_simd]
  return sram + B_sram_offset +
    (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}

template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* C_sram(
  thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
  // C_sram[M_simd][N_simd]
  return sram + C_sram_offset +
    (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}

template <typename T>
METAL_FUNC void prefetch(threadgroup T *A_block, device T *A,
                         ushort2 A_tile_src, uint2 A_offset,
                         threadgroup T *B_block, device T *B,
                         ushort2 B_tile_src, uint2 B_offset, uint k)
{
  A_tile_src.x = min(uint(K_simd), K - k);
  B_tile_src.y = min(uint(K_simd), K - k);
  auto A_src = simdgroup_matrix_storage<T>::apply_offset(
    A, A_leading_dim, A_offset, A_trans);
  auto B_src = simdgroup_matrix_storage<T>::apply_offset(
    B, B_leading_dim, B_offset, B_trans);

  // Rounded-up ceiling for the threadgroup block.
  const uint K_edge_floor = K - K_simd_unpadded;
  const uint K_edge_ceil = K_edge_floor + K_simd_padded;
  ushort K_padded;
  if (K_edge_floor == K_simd) {
    K_padded = K_simd;
  } else {
    K_padded = min(uint(K_simd), K_edge_ceil - k);
  }
  ushort2 A_tile_dst(K_padded, A_tile_src.y);
  ushort2 B_tile_dst(B_tile_src.x, K_padded);

  simdgroup_event events[2];
  events[0].async_copy(A_block, A_block_leading_dim, A_tile_dst, A_src,
                       A_leading_dim, A_tile_src, A_trans);
  events[1].async_copy(B_block, B_block_leading_dim, B_tile_dst, B_src,
                       B_leading_dim, B_tile_src, B_trans);
  simdgroup_event::wait(2, events);
}

// One iteration of the MACC loop, effectively k=8 iterations.
template <typename T>
METAL_FUNC void multiply_accumulate(thread simdgroup_matrix_storage<T> *sram,
                                    const threadgroup T *A_block,
                                    const threadgroup T *B_block,
                                    bool accumulate = true)
{
#pragma clang loop unroll(full)
  for (ushort m = 0; m < M_padded; m += 8) {
    ushort2 origin(0, m);
    A_sram(sram, origin)->load(A_block, A_block_leading_dim, origin, A_trans);
  }
#pragma clang loop unroll(full)
  for (ushort n = 0; n < N_padded; n += 8) {
    ushort2 origin(n, 0);
    B_sram(sram, origin)->load(B_block, B_block_leading_dim, origin, B_trans);
  }
#pragma clang loop unroll(full)
  for (ushort m = 0; m < M_padded; m += 8) {
    auto A = A_sram(sram, ushort2(0, m));
#pragma clang loop unroll(full)
    for (ushort n = 0; n < N_padded; n += 8) {
      auto B = B_sram(sram, ushort2(n, 0));
      auto C = C_sram(sram, ushort2(n, m));
      C->multiply(*A, *B, accumulate);
    }
  }
}

template <typename T>
METAL_FUNC void partial_store(thread simdgroup_matrix_storage<T> *sram,
                              threadgroup T *C_block, bool is_k_summation)
{
#pragma clang loop unroll(full)
  for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
    for (ushort n = 0; n < N_padded; n += 8) {
      ushort2 origin(n, m);
      if (is_k_summation) {
        C_sram(sram, origin)->store(C_block, N_simd, origin);
      } else {
        C_sram(sram, origin)->store(C_block, N_group, origin);
      }
    }
  }
}

template <typename T>
METAL_FUNC void partial_accumulate(thread simdgroup_matrix_storage<T> *sram,
                                   threadgroup T *C_block, bool is_k_summation)
{
#pragma clang loop unroll(full)
  for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
    for (ushort n = 0; n < N_padded; n += 8) {
      ushort2 origin(n, m);
      auto B = B_sram(sram, ushort2(n, 0));
      if (is_k_summation) {
        B->load(C_block, N_simd, origin);
      } else {
        B->load(C_block, N_group, origin);
      }
    }
#pragma clang loop unroll(full)
    for (ushort n = 0; n < N_padded; n += 8) {
      ushort2 origin(n, m);
      auto B = B_sram(sram, ushort2(n, 0));
      auto C = C_sram(sram, origin);
      if (is_k_summation) {
        C->thread_elements()[0] += B->thread_elements()[0];
      } else {
        float2 C_old = float2(B->thread_elements()[0]);
        float2 C_new = float2(C->thread_elements()[0]);
        C->thread_elements()[0] = vec<T, 2>(fast::fma(C_old, beta, C_new));
      }
    }
  }
}

template <typename T>
METAL_FUNC void async_access_accumulator(threadgroup T *C_block, device T *C,
                                         uint2 C_offset, bool is_store)
{
  ushort2 C_tile(min(uint(N_group), N - C_offset.x),
                 min(uint(M_group), M - C_offset.y));
  auto C_src = simdgroup_matrix_storage<T>::apply_offset(C, N, C_offset);

  simdgroup_event event;
  if (is_store) {
    event.async_copy(C_src, N, C_tile, C_block, N_group, C_tile);
  } else {
    event.async_copy(C_block, N_group, C_tile, C_src, N, C_tile);
    simdgroup_event::wait(1, &event);
  }
}

template <typename T>
METAL_FUNC void store_accumulator(thread simdgroup_matrix_storage<T> *sram,
                                  device T *C, bool m_is_edge, bool n_is_edge)
{
  const ushort m_start = (m_is_edge) ? M_modulo : 0;
  const ushort n_start = (n_is_edge) ? N_modulo : 0;
  const ushort m_end = (m_is_edge) ? M_simd : M_modulo;
  const ushort n_end = (n_is_edge) ? N_simd : N_modulo;

#pragma clang loop unroll(full)
  for (ushort m = m_start; m < m_end; m += 8) {
#pragma clang loop unroll(full)
    for (ushort n = n_start; n < n_end; n += 8) {
      ushort2 origin(n, m);
      C_sram(sram, origin)->store(C, N, origin);
    }
  }
}

// MARK: - Kernels

kernel void sgemm(device float *A [[buffer(0)]],
                  device float *B [[buffer(1)]],
                  device float *C [[buffer(2)]],
                  threadgroup float *threadgroup_block [[threadgroup(0)]],
                  uint3 gid [[threadgroup_position_in_grid]],
                  ushort sidx [[simdgroup_index_in_threadgroup]],
                  ushort lane_id [[thread_index_in_simdgroup]])
{
  simdgroup_matrix_storage<float> sram[1024];
  auto A_block = threadgroup_block + A_block_offset;
  auto B_block = threadgroup_block + B_block_offset;
  ushort2 sid(sidx % N_splits, sidx / N_splits);
  ushort2 offset_in_simd = simdgroup_matrix_storage<float>::offset(lane_id);

  uint2 A_offset(0, gid.y * M_group);
  uint2 B_offset(gid.x * N_group, 0);
  {
    uint C_base_offset_x = B_offset.x + sid.x * N_simd;
    uint C_base_offset_y = A_offset.y + sid.y * M_simd;
    if (C_base_offset_x >= N || C_base_offset_y >= M) {
      return;
    }
  }

  ushort2 offset_in_group(sid.x * N_simd + offset_in_simd.x,
                          sid.y * M_simd + offset_in_simd.y);

  ushort2 A_tile_src;
  ushort2 B_tile_src;
  if (sidx == 0) {
    A_tile_src.y = min(uint(M_group), M - A_offset.y);
    B_tile_src.x = min(uint(N_group), N - B_offset.x);
    prefetch(A_block, A, A_tile_src, A_offset,
             B_block, B, B_tile_src, B_offset, 0);
  }

  if (K > K_simd) {
#pragma clang loop unroll(full)
    for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
      for (ushort n = 0; n < N_padded; n += 8) {
        *C_sram(sram, ushort2(n, m)) = simdgroup_matrix_storage<float>(0);
      }
    }
  }

  for (uint K_floor = 0; K_floor < K; K_floor += K_simd) {
    ushort2 A_block_offset(offset_in_simd.x, offset_in_group.y);
    ushort2 B_block_offset(offset_in_group.x, offset_in_simd.y);
    auto A_block_src = simdgroup_matrix_storage<float>::apply_offset(
      A_block, A_block_leading_dim, A_block_offset, A_trans);
    auto B_block_src = simdgroup_matrix_storage<float>::apply_offset(
      B_block, B_block_leading_dim, B_block_offset, B_trans);
    threadgroup_barrier(mem_flags::mem_threadgroup);

#pragma clang loop unroll(full)
    for (ushort k = 0; k < K_simd_padded; k += 8) {
      bool accumulate = !(K <= K_simd && k == 0);
      multiply_accumulate(sram, A_block_src, B_block_src, accumulate);
      A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
      B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
    }

    if (K_floor + K_simd < K) {
#pragma clang loop unroll(full)
      for (ushort k = K_simd_padded; k < K_simd; k += 8) {
        multiply_accumulate(sram, A_block_src, B_block_src);
        A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
        B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
      }
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (sidx == 0) {
        uint K_next = K_floor + K_simd;
        A_offset.x = K_next;
        B_offset.y = K_next;
        prefetch(A_block, A, A_tile_src, A_offset,
                 B_block, B, B_tile_src, B_offset, K_next);
      }
    }
  }

  if (alpha != 1) {
#pragma clang loop unroll(full)
    for (int m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
      for (int n = 0; n < N_padded; n += 8) {
        C_sram(sram, ushort2(n, m))->thread_elements()[0] *= alpha;
      }
    }
  }

  uint2 C_offset(B_offset.x, A_offset.y);
  ushort2 C_block_offset = offset_in_group.xy;
  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (beta != 0) {
    if (sidx == 0) {
      async_access_accumulator(threadgroup_block, C, C_offset, false);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    auto C_block = simdgroup_matrix_storage<float>::apply_offset(
      threadgroup_block, N_group, C_block_offset);
    partial_accumulate(sram, C_block, false);
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  if ((M % 8 != 0) || (N % 8 != 0)) {
    auto C_block = simdgroup_matrix_storage<float>::apply_offset(
      threadgroup_block, N_group, C_block_offset);
    partial_store(sram, C_block, false);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sidx == 0) {
      async_access_accumulator(threadgroup_block, C, C_offset, true);
    }
  } else {
    uint2 matrix_origin = C_offset + uint2(C_block_offset);
    auto C_src = simdgroup_matrix_storage<float>::apply_offset(
      C, N, matrix_origin);
    store_accumulator(sram, C_src, false, false);

    const uint M_edge_floor = M - M % M_simd;
    const uint N_edge_floor = N - N % N_simd;
    if (matrix_origin.y < M_edge_floor) {
      store_accumulator(sram, C_src, true, false);
    }
    if (matrix_origin.x < N_edge_floor) {
      store_accumulator(sram, C_src, false, true);
      if (matrix_origin.y < M_edge_floor) {
        store_accumulator(sram, C_src, true, true);
      }
    }
  }
}
"""
// MARK: - Utilities | |

struct MTLContext {
  var device: MTLDevice
  var commandQueue: MTLCommandQueue

  init() {
    device = MTLCreateSystemDefaultDevice()!
    commandQueue = device.makeCommandQueue()!
  }
}

#if os(macOS)
public class MTLCompiler {
  public internal(set) var buildProductsDirectory: URL
  var xcodePath: String = "/Applications/Xcode.app"

  public init() {
    // let temporaryDirectory = FileManager.default.temporaryDirectory
    let temporaryDirectory = URL(filePath: "/Users/philipturner/Desktop")
    buildProductsDirectory = temporaryDirectory
      .appendingPathComponent("metal-compiler")
  }

  public func setXcodePath(_ path: String) {
    self.xcodePath = path
  }

  // This has about 200 ms latency, so invoke it sparingly. In the final
  // version of this library, we'll compile all GPU shaders beforehand and
  // cache the Metal resource objects at runtime.
  //
  // Near-term, create a key-value cache of the shader sources to the
  // compiled objects. This could speed up some workflows during development
  // of the library.
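  //
  // A minimal sketch of that cache (hypothetical; not used below):
  //
  //   private var libraryCache: [String: MTLLibrary] = [:]
  //
  //   public func cachedCompile(_ source: String) -> MTLLibrary {
  //     if let library = libraryCache[source] { return library }
  //     let library = compile(source)
  //     libraryCache[source] = library
  //     return library
  //   }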
  public func compile(_ source: String) -> MTLLibrary {
    //
    // build.swift
    // MetalFlashAttention
    //
    // Created by Philip Turner on 6/26/23.
    //

    // MARK: - Parse Arguments

    struct BuildSettings {
      var externalMetallibPath: String? = nil
      var platform: Platform? = nil
      var verbose: Bool = false
      var xcodePath: String = ""

      enum Platform {
        case iOS
        case macOS

        var metalToolsPath: String {
          switch self {
          case .iOS:
            return "ios"
          case .macOS:
            return "macos"
          }
        }

        var deploymentVersionArgument: String {
          switch self {
          case .iOS:
            return "-mios-version-min=16.0.0"
          case .macOS:
            return "-mmacosx-version-min=13.0.0"
          }
        }

        var xcrunSDK: String {
          switch self {
          case .iOS:
            return "iphoneos"
          case .macOS:
            return "macosx"
          }
        }
      }

      func metalToolPath(executable: String) -> String {
        guard let metalToolsPath = platform?.metalToolsPath else {
          fatalError("Must specify platform before locating Metal tools.")
        }
        var output = xcodePath
        output += "/Contents/Developer/Toolchains/XcodeDefault.xctoolchain"
        output += "/usr/metal/\(metalToolsPath)/bin/\(executable)"
        return output
      }

      func xcrunMetalArguments(executable: String) -> [String] {
        guard let xcrunSDK = platform?.xcrunSDK else {
          fatalError("Must specify platform before locating Metal tools.")
        }
        return ["-sdk", xcrunSDK, executable]
      }
    }

    var settings = BuildSettings()
    settings.verbose = true
    settings.platform = .macOS
    settings.xcodePath = self.xcodePath

    // MARK: - Prepare File Directories

    func directoryExists(url: URL) -> Bool {
      var isDirectory: ObjCBool = false
      let succeeded = FileManager.default.fileExists(
        atPath: url.path, isDirectory: &isDirectory)
      return succeeded && isDirectory.boolValue
    }

    func fileExists(url: URL) -> Bool {
      var isDirectory: ObjCBool = false
      let succeeded = FileManager.default.fileExists(
        atPath: url.path, isDirectory: &isDirectory)
      return succeeded && !isDirectory.boolValue
    }

    func assertDirectoryExists(url: URL, line: UInt = #line) {
      guard directoryExists(url: url) else {
        fatalError("""
          Line \(line):
          Directory not found at '\(url.path)'.
          """)
      }
    }

    func assertFileExists(url: URL, line: UInt = #line) {
      guard fileExists(url: url) else {
        fatalError("""
          Line \(line):
          File not found at '\(url.path)'.
          """)
      }
    }

    func touchDirectory(url: URL) {
      if !directoryExists(url: url) {
        try! FileManager.default.createDirectory(
          at: url, withIntermediateDirectories: false)
      }
      assertDirectoryExists(url: url)
    }

    let workDir = self.buildProductsDirectory
    touchDirectory(url: workDir)

    let buildDir = workDir.appending(component: "build")
    let libDir = buildDir.appending(component: "lib")
    let srcDir = buildDir.appending(component: "src")
    touchDirectory(url: buildDir)
    touchDirectory(url: libDir)
    touchDirectory(url: srcDir)

    func createFile(source: String, name: String) {
      guard let data = source.data(using: .utf8) else {
        fatalError("Could not encode source string as UTF-8.")
      }
      let destinationURL = srcDir.appendingPathComponent(name)
      try! data.write(to: destinationURL)
    }
    createFile(
      source: metalSimdgroupEvent,
      name: "metal_simdgroup_event")
    createFile(
      source: metalSimdgroupMatrixStorage,
      name: "metal_simdgroup_matrix_storage")
    createFile(
      source: source,
      name: "File.metal")

    let metalURL = srcDir.appendingPathComponent("File.metal")
    let airURL = buildDir.appendingPathComponent("File.air")
    let airPath = airURL.relativePath
    try? FileManager.default.removeItem(atPath: airPath)
    guard FileManager.default.createFile(
      atPath: airPath, contents: nil) else {
      fatalError("Could not create destination path '\(airPath)'.")
    }

    // MARK: - Compile AIR File

    // Arguments to invoke the Metal compiler with.
    var arguments: [String] = []
    arguments.append(settings.platform!.deploymentVersionArgument)
    arguments.append("-c")

    // Suppress compiler warnings unless the user enters '--verbose'.
    if settings.verbose {
      arguments.append("-Wno-unused-function")
      arguments.append("-Wno-unused-variable")
    } else {
      arguments.append("-w")
    }

    let process = Process()
    let toolPath = settings.metalToolPath(executable: "metal")
    process.executableURL = URL(filePath: toolPath)
    process.arguments = arguments + [metalURL.path, "-o", airPath]
    try! process.run()
    process.waitUntilExit()
    if process.terminationStatus != 0 {
      fatalError("Could not compile source.")
    }

    // MARK: - Compile Metal Library

    arguments = []

    // Package the metallib using the up-to-date Xcode version.
    func runProcess() {
      let process = try! Process.run(
        URL(fileURLWithPath: "/usr/bin/xcrun"),
        arguments: arguments)
      process.waitUntilExit()
    }

    let metallibName = "File.metallib"
    let metallibURL = libDir.appending(component: metallibName)
    arguments = []
    arguments += settings.xcrunMetalArguments(executable: "metal")
    arguments.append(settings.platform!.deploymentVersionArgument)
    arguments.append(airPath)
    arguments.append("-o")
    arguments.append(metallibURL.path)
    runProcess()

    // MARK: - Instantiate Resource Object

    let device = MTLCreateSystemDefaultDevice()!
    let library = try! device.makeLibrary(URL: metallibURL)
    // try! FileManager.default.removeItem(at: workDir)
    return library
  }
}
#endif
// MARK: - Script | |

func runApplication() {
  print("Hello, console.")

  let problemSize: Int = 64
  var A = [Float](repeating: .zero, count: problemSize * problemSize)
  var B = [Float](repeating: .zero, count: problemSize * problemSize)
  var C = [Float](repeating: .zero, count: problemSize * problemSize)

  // Initialize A as the 2nd-order periodic Laplacian.
  for diagonalID in 0..<problemSize {
    let diagonalAddress = diagonalID * problemSize + diagonalID
    A[diagonalAddress] = -2

    let leftColumnID = (diagonalID + problemSize - 1) % problemSize
    let leftSubDiagonalAddress = diagonalID * problemSize + leftColumnID
    A[leftSubDiagonalAddress] = 1

    let rightColumnID = (diagonalID + problemSize + 1) % problemSize
    let rightSubDiagonalAddress = diagonalID * problemSize + rightColumnID
    A[rightSubDiagonalAddress] = 1
  }
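  // For example, with problemSize = 4, row 0 of A would be [-2, 1, 0, 1]:
  // -2 on the diagonal, 1 on each neighboring column, and wraparound placing
  // the left neighbor in the last column.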

  // Initialize B to random numbers.
  for rowID in 0..<problemSize {
    for columnID in 0..<problemSize {
      let address = rowID * problemSize + columnID
      let entry = Float.random(in: 0..<1)
      B[address] = entry
    }
  }

  // Multiply A with B.
#if false
  for m in 0..<problemSize {
    for n in 0..<problemSize {
      var dotProduct: Float = .zero
      for k in 0..<problemSize {
        // Read from the input matrices.
        let lhsAddress = m * problemSize + k
        let rhsAddress = k * problemSize + n
        let lhsValue = A[lhsAddress]
        let rhsValue = B[rhsAddress]
        dotProduct += lhsValue * rhsValue
      }

      // Write to the output matrix.
      let address = m * problemSize + n
      C[address] = dotProduct
    }
  }
#else
  do {
    // Initialize the context.
    let context = MTLContext()
    // let compiler = MTLCompiler()
    // let library = compiler.compile(GEMM)
    let library = try! context.device.makeLibrary(source: GEMM, options: nil)

    // Set the function constants.
    let constants = MTLFunctionConstantValues()
    var M: Int = problemSize
    var N: Int = problemSize
    var K: Int = problemSize
    var transpose: Bool = false
    var alpha: Float = 1
    var beta: Float = 0
    constants.setConstantValue(&M, type: .uint, index: 0)
    constants.setConstantValue(&N, type: .uint, index: 1)
    constants.setConstantValue(&K, type: .uint, index: 2)
    constants.setConstantValue(&transpose, type: .bool, index: 10)
    constants.setConstantValue(&transpose, type: .bool, index: 11)
    constants.setConstantValue(&alpha, type: .float, index: 20)
    constants.setConstantValue(&beta, type: .float, index: 21)

    var M_simd: UInt16 = 32
    var N_simd: UInt16 = 32
    var K_simd: UInt16 = 8
    var M_splits: UInt16 = 1
    var N_splits: UInt16 = 1
    constants.setConstantValue(&M_simd, type: .ushort, index: 200)
    constants.setConstantValue(&N_simd, type: .ushort, index: 201)
    constants.setConstantValue(&K_simd, type: .ushort, index: 202)
    constants.setConstantValue(&M_splits, type: .ushort, index: 210)
    constants.setConstantValue(&N_splits, type: .ushort, index: 211)

    let function = try! library.makeFunction(
      name: "sgemm", constantValues: constants)
    let pipeline = try! context.device
      .makeComputePipelineState(function: function)

    // Allocate threadgroup blocks.
    let M_group = M_simd * M_splits
    let N_group = N_simd * N_splits
    let A_block_length = M_group * K_simd
    let B_block_length = K_simd * N_group
    var blockElements = A_block_length + B_block_length
    // Match the kernel's edge condition: the staging path runs when either
    // dimension is not divisible by 8, so size the block for either case.
    if (M % 8 != 0) || (N % 8 != 0) {
      let C_block_length = M_group * N_group
      blockElements = max(C_block_length, blockElements)
    }
    let blockBytes = blockElements * UInt16(4)

    func ceilDivide(target: Int, granularity: UInt16) -> Int {
      (target + Int(granularity) - 1) / Int(granularity)
    }
    let gridSize = MTLSize(
      width: ceilDivide(target: N, granularity: N_group),
      height: ceilDivide(target: M, granularity: M_group),
      depth: 1)
    let groupSize = MTLSize(
      width: Int(32 * M_splits * N_splits),
      height: 1,
      depth: 1)

    // Create the buffers.
    func createBuffer(data: [Float]) -> MTLBuffer {
      // Allocate enough memory to store everything in Float32.
      let bufferSize = problemSize * problemSize * 4
      let buffer = context.device.makeBuffer(length: bufferSize)!

      // Copy the data into the buffer.
      let pointer = buffer.contents().assumingMemoryBound(to: Float.self)
      pointer.initialize(from: data, count: problemSize * problemSize)

      // Return the buffer object.
      return buffer
    }
    let bufferA = createBuffer(data: A)
    let bufferB = createBuffer(data: B)
    let bufferC = createBuffer(data: C)

    // Profile the latency of matrix multiplication.
    for _ in 0..<15 {
      let duplicatedCommandCount: Int = 20

      // Encode the GPU command.
      let commandBuffer = context.commandQueue.makeCommandBuffer()!
      let encoder = commandBuffer.makeComputeCommandEncoder()!
      encoder.setComputePipelineState(pipeline)
      encoder.setThreadgroupMemoryLength(Int(blockBytes), index: 0)
      encoder.setBuffer(bufferA, offset: 0, index: 0)
      encoder.setBuffer(bufferB, offset: 0, index: 1)
      encoder.setBuffer(bufferC, offset: 0, index: 2)
      for _ in 0..<duplicatedCommandCount {
        encoder.dispatchThreadgroups(
          gridSize, threadsPerThreadgroup: groupSize)
      }
      encoder.endEncoding()
      commandBuffer.commit()
      commandBuffer.waitUntilCompleted()

      // Determine the time taken.
      let start = commandBuffer.gpuStartTime
      let end = commandBuffer.gpuEndTime
      let latency = end - start
      let latencyMicroseconds = Int(latency / 1e-6)

      // Determine the amount of work done.
      var operations = 2 * problemSize * problemSize * problemSize
      operations = operations * duplicatedCommandCount
      let gflops = Int(Double(operations) / Double(latency) / 1e9)

      // Report the results.
      print(latencyMicroseconds, "μs", gflops, "GFLOPS")
    }

    // Copy the results to C.
    do {
      let rawPointer = bufferC.contents()
      let castedPointer = rawPointer.assumingMemoryBound(to: Float.self)
      for rowID in 0..<problemSize {
        for columnID in 0..<problemSize {
          let address = rowID * problemSize + columnID
          let entry = castedPointer[address]
          C[address] = entry
        }
      }
    }
  }
#endif

  // Check the results.
  for m in 0..<problemSize {
    for n in 0..<problemSize {
      // Find the source row IDs.
      let leftRowID = (m + problemSize - 1) % problemSize
      let centerRowID = m
      let rightRowID = (m + problemSize + 1) % problemSize

      // Find the source values.
      let leftSource = B[leftRowID * problemSize + n]
      let centerSource = B[centerRowID * problemSize + n]
      let rightSource = B[rightRowID * problemSize + n]

      // Find the expected and actual values.
      let expected = leftSource - 2 * centerSource + rightSource
      let actual = C[m * problemSize + n]

      // Report the results.
      let error = (expected - actual).magnitude
      if error > 1e-5 {
        print("error: \(error) / ~1.000")
      }
    }
  }
}

runApplication()