//
// Workspace.swift
// M4DeviceTesting
//
// Created by Philip Turner on 5/24/24.
//
import Metal
// Investigation of Float32 performance before and after dynamic caching on
// Apple GPUs.
//
// ========================================================================== //
// Results (Raw Data)
// ========================================================================== //
//
// AMX, Float32
//
// M1 Max
// - problemSize = 64 | 419 GFLOPS
// - problemSize = 96 | 589 GFLOPS
// - problemSize = 128 | 724 GFLOPS
// - problemSize = 192 | 990 GFLOPS
// - problemSize = 256 | 1032 GFLOPS
// - problemSize = 384 | 1196 GFLOPS
// - problemSize = 512 | 1673 GFLOPS
// - problemSize = 640 | 2261 GFLOPS
// - problemSize = 768 | 2454 GFLOPS
// - problemSize = 896 | 2348 GFLOPS
// - problemSize = 1024 | 1516 GFLOPS
// - problemSize = 1152 | 2267 GFLOPS
// - problemSize = 1280 | 2016 GFLOPS
//
// M4
// - problemSize = 64 | 524 GFLOPS
// - problemSize = 96 | 832 GFLOPS
// - problemSize = 128 | 1059 GFLOPS
// - problemSize = 192 | 1364 GFLOPS
// - problemSize = 256 | 1522 GFLOPS
// - problemSize = 384 | 1693 GFLOPS
// - problemSize = 512 | 1518 GFLOPS
// - problemSize = 640 | 1834 GFLOPS
// - problemSize = 768 | 1843 GFLOPS
// - problemSize = 896 | 1866 GFLOPS
// - problemSize = 1024 | 1578 GFLOPS
// - problemSize = 1152 | 1868 GFLOPS
// - problemSize = 1280 | 1880 GFLOPS
//
// MPS, Float32
//
// M1 Max
// - problemSize = 64 | 23 GFLOPS
// - problemSize = 96 | 38 GFLOPS
// - problemSize = 128 | 113 GFLOPS
// - problemSize = 192 | 267 GFLOPS
// - problemSize = 256 | 482 GFLOPS
// - problemSize = 384 | 989 GFLOPS
// - problemSize = 512 | 2223 GFLOPS
// - problemSize = 640 | 4524 GFLOPS
// - problemSize = 768 | 5296 GFLOPS
// - problemSize = 896 | 6768 GFLOPS
// - problemSize = 1024 | 8045 GFLOPS
// - problemSize = 1152 | 7338 GFLOPS
// - problemSize = 1280 | 7466 GFLOPS
//
// M4
// - problemSize = 64 | 40 GFLOPS
// - problemSize = 96 | 99 GFLOPS
// - problemSize = 128 | 203 GFLOPS
// - problemSize = 192 | 371 GFLOPS
// - problemSize = 256 | 456 GFLOPS
// - problemSize = 384 | 581 GFLOPS
// - problemSize = 512 | 1828 GFLOPS
// - problemSize = 640 | 2146 GFLOPS
// - problemSize = 768 | 2447 GFLOPS
// - problemSize = 896 | 3091 GFLOPS
// - problemSize = 1024 | 3123 GFLOPS
// - problemSize = 1152 | 3150 GFLOPS
// - problemSize = 1280 | 3130 GFLOPS
//
// MFA, Float32
//
// M1 Max
// - problemSize = 64 | 44 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 96 | 109 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 128 | 214 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 192 | 495 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 256 | 883 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 384 | 2912 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 512 | 4091 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 640 | 6440 GFLOPS (32x32x32, 128 threads/group)
// - problemSize = 768 | 7017 GFLOPS (48x48x24, 128 threads/group)
// - problemSize = 896 | 7136 GFLOPS (48x48x24, 128 threads/group)
// - problemSize = 1024 | 6966 GFLOPS (48x48x24, 128 threads/group)
// - problemSize = 1152 | 8144 GFLOPS (48x48x24, 128 threads/group)
// - problemSize = 1280 | 7813 GFLOPS (48x48x24, 128 threads/group)
//
// M4
// - problemSize = 64 | 39 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 96 | 32 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 128 | 94 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 192 | 364 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 256 | 654 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 384 | 1270 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 512 | 1626 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 640 | 1947 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 768 | 1955 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 896 | 2034 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 1024 | 2078 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 1152 | 2119 GFLOPS (32x32x8, 32 threads/group)
// - problemSize = 1280 | 2129 GFLOPS (32x32x8, 32 threads/group)
//
// MFA without async copy, Float32
// - All configurations have 32 threads/group.
// - This gives a misleading overestimate of M1 performance. Without async
//   copies, performance can be highly inconsistent, especially with smaller
//   matrix sizes and odd/prime matrix sizes.
//
// M1 Max
// - problemSize = 64 | 28 GFLOPS (32x32x8)
// - problemSize = 96 | 70 GFLOPS (32x32x8)
// - problemSize = 128 | 129 GFLOPS (32x32x8)
// - problemSize = 192 | 745 GFLOPS (32x32x8)
// - problemSize = 256 | 1369 GFLOPS (32x32x8)
// - problemSize = 384 | 2907 GFLOPS (32x32x8)
// - problemSize = 512 | 5229 GFLOPS (32x32x8)
// - problemSize = 640 | 5519 GFLOPS (32x32x8)
// - problemSize = 768 | 6170 GFLOPS (48x48x8)
// - problemSize = 896 | 6754 GFLOPS (48x48x8)
// - problemSize = 1024 | 7523 GFLOPS (48x48x8)
// - problemSize = 1152 | 7682 GFLOPS (48x48x8)
// - problemSize = 1280 | 7755 GFLOPS (48x48x8)
//
// Galactic arithmetic intensities, too large for general use:
// - problemSize = 2048 | 8819 GFLOPS (48x48x8)
// - problemSize = 3072 | 8909 GFLOPS (48x48x8)
//
// M4
// - problemSize = 64 | 35 GFLOPS (32x32x8)
// - problemSize = 96 | 84 GFLOPS (32x32x8)
// - problemSize = 128 | 154 GFLOPS (32x32x8)
// - problemSize = 192 | 352 GFLOPS (32x32x8)
// - problemSize = 256 | 577 GFLOPS (32x32x8)
// - problemSize = 384 | 1718 GFLOPS (32x32x8)
// - problemSize = 512 | 2440 GFLOPS (32x32x8)
// - problemSize = 640 | 3013 GFLOPS (32x32x8)
// - problemSize = 768 | 3078 GFLOPS (32x32x8)
// - problemSize = 896 | 3074 GFLOPS (32x32x8)
// - problemSize = 1024 | 3104 GFLOPS (32x32x8)
// - problemSize = 1152 | 3164 GFLOPS (32x32x8)
// - problemSize = 1280 | 3168 GFLOPS (32x32x8)
//
// Fine-tuning the block size at the problem size where the "robust" 32x32
// block underperforms MPS:
// - problemSize = 128 | 642 GFLOPS (24x24x8)
// - problemSize = 128 | 842 GFLOPS (16x16x8)
//
// ========================================================================== //
// Discussion
// ========================================================================== //
//
// For reasonable panel sizes in linear algebra (latency-bound, formally O(n^3)
// matrix factorizations), where the arithmetic intensity is at most ~100
// operations per scalar read from memory.
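// As a rough sanity check (assuming square n x n operands and no reuse beyond
// a single pass over the data): a dense multiplication performs 2n^3 FLOPs
// while touching 3n^2 scalars (read A, read B, write C), so the arithmetic
// intensity is roughly 2n/3 FLOPs per scalar. An intensity of ~100 therefore
// corresponds to matrix dimensions on the order of 150.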
//
// M1 Max:
// - AMX: 724 GFLOPS
// - MFA: 214 GFLOPS
// - MFA: 129 GFLOPS (without async copy)
// - MPS: 113 GFLOPS
//
// M4:
// - AMX: 1059 GFLOPS
// - MFA: 842 GFLOPS (without async copy, fine-tuned)
// - MFA: 154 GFLOPS (without async copy)
// - MPS: 99 GFLOPS
// - MFA: 94 GFLOPS
//
// For reasonable matrices in AI, where the arithmetic intensity is at most
// ~1000 operations per scalar read from memory.
//
// M1 Max:
// - MFA: 8144 GFLOPS
// - MPS: 8045 GFLOPS
// - MFA: 7755 GFLOPS (without async copy)
// - AMX: 2454 GFLOPS
//
// M4:
// - MFA: 3168 GFLOPS (without async copy)
// - MPS: 3150 GFLOPS
// - MFA: 2129 GFLOPS
// - AMX: 1880 GFLOPS
//
// For galactic matrices (every dimension of the matrix multiplication >>1000,
// and typically square), which are overemphasized in benchmarks.
//
// M1 Max:
// - MFA: 8909 GFLOPS (without async copy)
// - MPS: 8480 GFLOPS
func runApplication() {
print("Hello, console.")
let problemSize: Int = 128
var A = [Float](repeating: .zero, count: problemSize * problemSize)
var B = [Float](repeating: .zero, count: problemSize * problemSize)
var C = [Float](repeating: .zero, count: problemSize * problemSize)
// Initialize A as the 2nd-order periodic Laplacian.
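// Each row has -2 on the diagonal and +1 in the two neighboring columns,
// wrapping around at the matrix boundary (periodic boundary conditions).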
for diagonalID in 0..<problemSize {
let diagonalAddress = diagonalID * problemSize + diagonalID
A[diagonalAddress] = -2
let leftColumnID = (diagonalID + problemSize - 1) % problemSize
let leftSubDiagonalAddress = diagonalID * problemSize + leftColumnID
A[leftSubDiagonalAddress] = 1
let rightColumnID = (diagonalID + problemSize + 1) % problemSize
let rightSubDiagonalAddress = diagonalID * problemSize + rightColumnID
A[rightSubDiagonalAddress] = 1
}
// Initialize B to random numbers.
for rowID in 0..<problemSize {
for columnID in 0..<problemSize {
let address = rowID * problemSize + columnID
let entry = Float.random(in: 0..<1)
B[address] = entry
}
}
// Multiply A with B.
#if false
for m in 0..<problemSize {
for n in 0..<problemSize {
var dotProduct: Float = .zero
for k in 0..<problemSize {
// Read from the input matrices.
let lhsAddress = m * problemSize + k
let rhsAddress = k * problemSize + n
let lhsValue = A[lhsAddress]
let rhsValue = B[rhsAddress]
dotProduct += lhsValue * rhsValue
}
// Write to the output matrix.
let address = m * problemSize + n
C[address] = dotProduct
}
}
#else
do {
// Initialize the context.
let context = MTLContext()
let library = try! context.device.makeLibrary(source: GEMM, options: nil)
// Set the function constants.
let constants = MTLFunctionConstantValues()
var M: Int = problemSize
var N: Int = problemSize
var K: Int = problemSize
var transpose: Bool = false
constants.setConstantValue(&M, type: .uint, index: 0)
constants.setConstantValue(&N, type: .uint, index: 1)
constants.setConstantValue(&K, type: .uint, index: 2)
constants.setConstantValue(&transpose, type: .bool, index: 10)
constants.setConstantValue(&transpose, type: .bool, index: 11)
var M_simd: UInt16 = 32
var N_simd: UInt16 = 32
constants.setConstantValue(&M_simd, type: .ushort, index: 200)
constants.setConstantValue(&N_simd, type: .ushort, index: 201)
let function = try! library.makeFunction(
name: "sgemm", constantValues: constants)
let pipeline = try! context.device
.makeComputePipelineState(function: function)
func ceilDivide(target: Int, granularity: UInt16) -> Int {
(target + Int(granularity) - 1) / Int(granularity)
}
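// Each threadgroup contains a single simdgroup (32 threads) and produces one
// M_simd x N_simd tile of C, so the grid spans ceil(N / N_simd) by
// ceil(M / M_simd) threadgroups.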
let gridSize = MTLSize(
width: ceilDivide(target: N, granularity: N_simd),
height: ceilDivide(target: M, granularity: M_simd),
depth: 1)
let groupSize = MTLSize(
width: 32,
height: 1,
depth: 1)
// Create the buffers.
func createBuffer(data: [Float]) -> MTLBuffer {
// Allocate enough memory to store everything in Float32.
let bufferSize = problemSize * problemSize * 4
let buffer = context.device.makeBuffer(length: bufferSize)!
// Copy the data into the buffer.
let pointer = buffer.contents().assumingMemoryBound(to: Float.self)
pointer.initialize(from: data, count: problemSize * problemSize)
// Return the buffer object.
return buffer
}
let bufferA = createBuffer(data: A)
let bufferB = createBuffer(data: B)
let bufferC = createBuffer(data: C)
// Profile the latency of matrix multiplication.
for _ in 0..<15 {
let duplicatedCommandCount: Int = 20
// Execute the operation.
let commandBuffer = context.commandQueue.makeCommandBuffer()!
let encoder = commandBuffer.makeComputeCommandEncoder()!
encoder.setComputePipelineState(pipeline)
encoder.setBuffer(bufferA, offset: 0, index: 0)
encoder.setBuffer(bufferB, offset: 0, index: 1)
encoder.setBuffer(bufferC, offset: 0, index: 2)
for _ in 0..<duplicatedCommandCount {
encoder.dispatchThreadgroups(
gridSize, threadsPerThreadgroup: groupSize)
}
encoder.endEncoding()
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
// Determine the time taken.
let start = commandBuffer.gpuStartTime
let end = commandBuffer.gpuEndTime
let latency = end - start
let latencyMicroseconds = Int(latency / 1e-6)
// Determine the amount of work done.
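// Each of the n^2 output entries requires n multiply-adds, i.e. 2n^3 FLOPs.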
var operations = 2 * problemSize * problemSize * problemSize
operations = operations * duplicatedCommandCount
let gflops = Int(Double(operations) / Double(latency) / 1e9)
// Report the results.
print(latencyMicroseconds, "μs", gflops, "GFLOPS")
}
// Copy the results to C.
do {
let rawPointer = bufferC.contents()
let castedPointer = rawPointer.assumingMemoryBound(to: Float.self)
for rowID in 0..<problemSize {
for columnID in 0..<problemSize {
let address = rowID * problemSize + columnID
let entry = castedPointer[address]
C[address] = entry
}
}
}
}
#endif
// Check the results.
for m in 0..<problemSize {
for n in 0..<problemSize {
// Find the source row IDs.
let leftRowID = (m + problemSize - 1) % problemSize
let centerRowID = m
let rightRowID = (m + problemSize + 1) % problemSize
// Find the source values.
let leftSource = B[leftRowID * problemSize + n]
let centerSource = B[centerRowID * problemSize + n]
let rightSource = B[rightRowID * problemSize + n]
// Find the expected and actual values.
let expected = leftSource - 2 * centerSource + rightSource
let actual = C[m * problemSize + n]
// Report the results.
let error = (expected - actual).magnitude
if error > 1e-5 {
print("error: \(error) / ~1.000")
}
}
}
}
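// The entry point is not shown in this gist. A minimal sketch for invoking the
// benchmark from a command-line target (an assumption; the original
// M4DeviceTesting project may call runApplication() from elsewhere):
//
//   @main
//   struct BenchmarkEntryPoint {
//     static func main() {
//       runApplication()
//     }
//   }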
// MARK: - Utilities
#if false
// A configuration for a general matrix-matrix multiplication.
struct GEMMDescriptor {
// The M, N, and K arguments.
var dimension: SIMD3<Int>?
// The A, LDA, and TRANSA arguments.
var leftOperand: UnsafePointer<Float>?
var leftOperandStride: Int?
var leftTransposeState: Character = "N"
// The B, LDB, and TRANSB arguments.
var rightOperand: UnsafePointer<Float>?
var rightOperandStride: Int?
var rightTransposeState: Character = "N"
// The C and LDC arguments.
var accumulator: UnsafeMutablePointer<Float>?
var accumulatorStride: Int?
// The 'alpha' and 'beta' arguments.
var productScale: Float = 1
var accumulatorScale: Float = 0
@_transparent
init() { }
}
// Descriptor-based wrapper over `sgemm_` from Accelerate.
struct GEMM {
// In typical API usage, one does not access the object's properties.
@discardableResult
@_transparent
init(descriptor: GEMMDescriptor) {
guard let dimension = descriptor.dimension,
let leftOperand = descriptor.leftOperand,
let leftOperandStride = descriptor.leftOperandStride,
let rightOperand = descriptor.rightOperand,
let rightOperandStride = descriptor.rightOperandStride,
let accumulator = descriptor.accumulator,
let accumulatorStride = descriptor.accumulatorStride else {
fatalError("Descriptor not complete.")
}
var TRANSA = CChar(descriptor.leftTransposeState.asciiValue!)
var TRANSB = CChar(descriptor.rightTransposeState.asciiValue!)
var M = Int32(truncatingIfNeeded: dimension[0])
var N = Int32(truncatingIfNeeded: dimension[1])
var K = Int32(truncatingIfNeeded: dimension[2])
var ALPHA = descriptor.productScale
var LDA = Int32(truncatingIfNeeded: leftOperandStride)
var BETA = descriptor.accumulatorScale
var LDB = Int32(truncatingIfNeeded: rightOperandStride)
var LDC = Int32(truncatingIfNeeded: accumulatorStride)
sgemm_(
&TRANSA,
&TRANSB,
&M,
&N,
&K,
&ALPHA,
leftOperand, &LDA,
rightOperand, &LDB,
&BETA,
accumulator, &LDC)
}
}
// Reference code for benchmarking CPU (AMX) performance. To enable this block,
// add `import Accelerate` and `import QuartzCore`, and call it where `A`, `B`,
// `C`, and `problemSize` are in scope (for example, nested inside
// `runApplication`).
func testCPUPerformance() {
A.withUnsafeBufferPointer {
let A = $0.baseAddress!
B.withUnsafeBufferPointer {
let B = $0.baseAddress!
C.withUnsafeMutableBufferPointer {
let C = $0.baseAddress!
// BLAS's sgemm_ assumes column-major storage, so we set up a transposed
// matrix multiplication:
//   C = A B  ->  C^T = B^T A^T
// Our row-major buffers, reinterpreted as column-major, already appear
// transposed to BLAS (the buffer for A reads as A^T, and likewise for B).
// Therefore, passing B as the left operand and A as the right operand
// computes C^T in column-major order, which is C in row-major order.
var gemmDesc = GEMMDescriptor()
gemmDesc.dimension = SIMD3(repeating: problemSize)
gemmDesc.leftOperand = B
gemmDesc.leftOperandStride = problemSize
gemmDesc.rightOperand = A
gemmDesc.rightOperandStride = problemSize
gemmDesc.accumulator = C
gemmDesc.accumulatorStride = problemSize
// Profile the latency of the matrix multiplication.
for _ in 0..<10 {
let start = CACurrentMediaTime()
GEMM(descriptor: gemmDesc)
let end = CACurrentMediaTime()
let latency = end - start
let latencyMicroseconds = Int(latency / 1e-6)
let operations = 2 * problemSize * problemSize * problemSize
let gflops = Int(Double(operations) / Double(latency) / 1e9)
print(latencyMicroseconds, "μs", gflops, "GFLOPS")
}
}
}
}
}
#endif
struct MTLContext {
var device: MTLDevice
var commandQueue: MTLCommandQueue
init() {
device = MTLCreateSystemDefaultDevice()!
commandQueue = device.makeCommandQueue()!
}
}
#if false
struct MPSMatrixStorageDescriptor {
var context: MTLContext?
var data: [Float]?
var problemSize: Int?
}
struct MPSMatrixStorage {
var buffer: MTLBuffer
var matrix: MPSMatrix
init(descriptor: MPSMatrixStorageDescriptor) {
guard let context = descriptor.context,
let data = descriptor.data,
let problemSize = descriptor.problemSize else {
fatalError("Descriptor was invalid.")
}
// Allocate enough memory to store everything in Float32.
let bufferSize = problemSize * problemSize * 4
buffer = context.device.makeBuffer(length: bufferSize)!
// Copy the data into the buffer.
let pointer = buffer.contents().assumingMemoryBound(to: Float.self)
pointer.initialize(from: data, count: problemSize * problemSize)
// Set the descriptor properties.
let matrixDesc = MPSMatrixDescriptor(
rows: problemSize,
columns: problemSize,
rowBytes: problemSize * 4, // Float32
dataType: .float32)
// Initialize the matrix object.
matrix = MPSMatrix(buffer: buffer, descriptor: matrixDesc)
}
}
// Reference code for benchmarking MPS performance. To enable this block, add
// `import MetalPerformanceShaders` and call it where `A`, `B`, `C`, and
// `problemSize` are in scope.
func testMPSPerformance() {
// Initialize the context.
let context = MTLContext()
// Initialize the matrices.
var matrixStorageDesc = MPSMatrixStorageDescriptor()
matrixStorageDesc.context = context
matrixStorageDesc.problemSize = problemSize
matrixStorageDesc.data = A
let matrixA = MPSMatrixStorage(descriptor: matrixStorageDesc)
matrixStorageDesc.data = B
let matrixB = MPSMatrixStorage(descriptor: matrixStorageDesc)
matrixStorageDesc.data = C
let matrixC = MPSMatrixStorage(descriptor: matrixStorageDesc)
// Initialize the multiplication object.
let multiplication = MPSMatrixMultiplication(
device: context.device,
resultRows: problemSize,
resultColumns: problemSize,
interiorColumns: problemSize)
multiplication.leftMatrixOrigin = MTLOrigin(x: 0, y: 0, z: 0)
multiplication.rightMatrixOrigin = MTLOrigin(x: 0, y: 0, z: 0)
multiplication.resultMatrixOrigin = MTLOrigin(x: 0, y: 0, z: 0)
// Profile the latency of the matrix multiplication.
for _ in 0..<10 {
let duplicatedCommandCount: Int = 10
// Execute the operation.
let commandBuffer = context.commandQueue.makeCommandBuffer()!
for _ in 0..<duplicatedCommandCount {
multiplication.encode(
commandBuffer: commandBuffer,
leftMatrix: matrixA.matrix,
rightMatrix: matrixB.matrix,
resultMatrix: matrixC.matrix)
}
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
// Determine the time taken.
let start = commandBuffer.gpuStartTime
let end = commandBuffer.gpuEndTime
let latency = end - start
let latencyMicroseconds = Int(latency / 1e-6)
// Determine the amount of work done.
var operations = 2 * problemSize * problemSize * problemSize
operations = operations * duplicatedCommandCount
let gflops = Int(Double(operations) / Double(latency) / 1e9)
// Report the results.
print(latencyMicroseconds, "μs", gflops, "GFLOPS")
}
// Copy the results to C.
do {
let rawPointer = matrixC.buffer.contents()
let castedPointer = rawPointer.assumingMemoryBound(to: Float.self)
for rowID in 0..<problemSize {
for columnID in 0..<problemSize {
let address = rowID * problemSize + columnID
let entry = castedPointer[address]
C[address] = entry
}
}
}
}
#endif
let metalSimdgroupMatrixStorage: String = """
// -*- Metal -*-
//===-- metal_simdgroup_matrix_storage ------------------------------------===//
// Copyright (c) 2023 Philip Turner. See MIT LICENSE
//===----------------------------------------------------------------------===//
#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE
#define __METAL_SIMDGROUP_MATRIX_STORAGE
// Contains C++ symbols accessible to a developer through automatic code
// completion in Xcode 14.2. Formatted with the same style as the Metal Standard
// Library for consistency with other Metal code.
#if defined(__HAVE_SIMDGROUP_MATRIX__)
#pragma METAL internals : enable
namespace metal
{
template <typename T>
struct simdgroup_matrix_storage {
typedef vec<T, 64> storage_type;
storage_type t;
METAL_FUNC thread vec<T, 2>* thread_elements() thread {
return reinterpret_cast<thread vec<T, 2>*>(&t);
}
METAL_FUNC simdgroup_matrix_storage() thread = default;
METAL_FUNC simdgroup_matrix_storage(vec<T, 2> thread_elements) thread {
*(this->thread_elements()) = thread_elements;
}
METAL_FUNC static ushort2 offset(ushort thread_index_in_simdgroup) {
// https://patents.google.com/patent/US11256518B2
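// Each lane of the simdgroup owns a 1x2 slice of the 8x8 tile; this maps the
// lane index to the (column, row) origin of that slice.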
ushort lane_id = thread_index_in_simdgroup;
ushort quad_id = lane_id / 4;
constexpr ushort QUADRANT_SPAN_M = 4;
constexpr ushort THREADS_PER_QUADRANT = 8;
ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M;
ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2);
ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant;
ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4
ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2
ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant;
return ushort2(N_in_simd, M_in_simd);
}
METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
} else {
return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
}
}
METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
return src + matrix_origin.x * elements_per_row + matrix_origin.y;
} else {
return src + matrix_origin.y * elements_per_row + matrix_origin.x;
}
}
// WARNING: All load and store functions assume the X dimension is divisible by 2.
METAL_FUNC void load(const device T *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
*(thread_elements()) = vec<T, 2>(src[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y], src[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y]);
} else {
*(thread_elements()) = *reinterpret_cast<const device vec<T, 2>*>(src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x);
}
}
METAL_FUNC void load(const threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
*(thread_elements()) = vec<T, 2>(src[matrix_origin.x * elements_per_row + matrix_origin.y], src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y]);
} else {
*(thread_elements()) = *reinterpret_cast<const threadgroup vec<T, 2>*>(src + matrix_origin.y * elements_per_row + matrix_origin.x);
}
}
METAL_FUNC void load_first(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
thread_elements()[0][0] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
} else {
thread_elements()[0][0] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
}
}
METAL_FUNC void load_second(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
thread_elements()[0][1] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
} else {
thread_elements()[0][1] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
}
}
METAL_FUNC void store(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
dst[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
} else {
*reinterpret_cast<device vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
}
}
METAL_FUNC void store_first(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
} else {
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][0];
}
}
METAL_FUNC void store_second(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
} else {
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][1];
}
}
METAL_FUNC void store(threadgroup T *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[matrix_origin.x * elements_per_row + matrix_origin.y] = thread_elements()[0][0];
dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = thread_elements()[0][1];
} else {
*reinterpret_cast<threadgroup vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
}
}
template <typename U, typename V>
METAL_FUNC void multiply(simdgroup_matrix_storage<U> a, simdgroup_matrix_storage<V> b, bool accumulate = true) {
if (!accumulate) {
*(thread_elements()) = vec<T, 2>(0);
}
t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
}
};
} // namespace metal
#pragma METAL internals : disable
#endif
#endif // __METAL_SIMDGROUP_MATRIX_STORAGE
"""
let GEMM = """
//
// GEMM.metal
// MetalFlashAttention
//
// Created by Philip Turner on 6/23/23.
//
#include <metal_stdlib>
\(metalSimdgroupMatrixStorage)
using namespace metal;
// MARK: - Function Constants
// Dimensions of each matrix.
constant uint M [[function_constant(0)]];
constant uint N [[function_constant(1)]];
constant uint K [[function_constant(2)]];
// Whether each matrix is transposed.
constant bool A_trans [[function_constant(10)]];
constant bool B_trans [[function_constant(11)]];
constant uint A_leading_dim = A_trans ? M : K;
constant uint B_leading_dim = B_trans ? K : N;
constant ushort M_simd [[function_constant(200)]];
constant ushort N_simd [[function_constant(201)]];
// Elide work on the edge when matrix dimension < SRAM block dimension.
constant ushort M_modulo = (M % M_simd == 0) ? M_simd : (M % M_simd);
constant ushort N_modulo = (N % N_simd == 0) ? N_simd : (N % N_simd);
constant ushort M_padded = (M < M_simd) ? (M_modulo + 7) / 8 * 8 : M_simd;
constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd;
constant ushort A_sram_length = M_simd / 8;
constant ushort B_sram_length = N_simd / 8;
constant ushort A_sram_offset = 0;
constant ushort B_sram_offset = A_sram_offset + A_sram_length;
constant ushort C_sram_offset = B_sram_offset + B_sram_length;
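// The per-thread `sram` array is partitioned into three register-resident
// regions of 8x8 tiles: an M_simd x 8 slice of A, an 8 x N_simd slice of B,
// and the M_simd x N_simd accumulator for C.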
// MARK: - Utilities
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* A_sram(
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
// A_sram[M_simd][8]
return sram + A_sram_offset + (matrix_origin.y / 8) * (8 / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* B_sram(
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
// B_sram[8][N_simd]
return sram + B_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* C_sram(
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin
) {
// C_sram[M_simd][N_simd]
return sram + C_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC void store_accumulator(thread simdgroup_matrix_storage<T> *sram,
device T *C, bool m_is_edge, bool n_is_edge)
{
const ushort m_start = (m_is_edge) ? M_modulo : 0;
const ushort n_start = (n_is_edge) ? N_modulo : 0;
const ushort m_end = (m_is_edge) ? M_simd : M_modulo;
const ushort n_end = (n_is_edge) ? N_simd : N_modulo;
#pragma clang loop unroll(full)
for (ushort m = m_start; m < m_end; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = n_start; n < n_end; n += 8) {
ushort2 origin(n, m);
C_sram(sram, origin)->store(C, N, origin);
}
}
}
// MARK: - Kernels
kernel void sgemm(device float *A [[buffer(0)]],
device float *B [[buffer(1)]],
device float *C [[buffer(2)]],
uint3 gid [[threadgroup_position_in_grid]],
ushort lane_id [[thread_index_in_simdgroup]])
{
simdgroup_matrix_storage<float> sram[1024];
ushort2 offset_in_simd = simdgroup_matrix_storage<float>::offset(lane_id);
ushort2 A_offset(0, gid.y * M_simd);
ushort2 B_offset(gid.x * N_simd, 0);
A_offset += offset_in_simd;
B_offset += offset_in_simd;
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
*C_sram(sram, ushort2(n, m)) = simdgroup_matrix_storage<float>(0);
}
}
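// March along K in steps of 8: load an M_simd x 8 slice of A and an
// 8 x N_simd slice of B into registers, then accumulate their product into
// the C tiles.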
for (uint k = 0; k < K; k += 8) {
auto A_src = simdgroup_matrix_storage<float>::apply_offset(
A, A_leading_dim, uint2(A_offset), A_trans);
auto B_src = simdgroup_matrix_storage<float>::apply_offset(
B, B_leading_dim, uint2(B_offset), B_trans);
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
ushort2 origin(0, m);
A_sram(sram, origin)->load(A_src, A_leading_dim, origin, A_trans);
}
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, 0);
B_sram(sram, origin)->load(B_src, B_leading_dim, origin, B_trans);
}
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
auto A = A_sram(sram, ushort2(0, m));
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto B = B_sram(sram, ushort2(n, 0));
auto C = C_sram(sram, ushort2(n, m));
C->multiply(*A, *B);
}
}
A_offset.x += 8;
B_offset.y += 8;
}
// WARNING: M and N must be divisible by 8.
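// Every threadgroup first stores the sub-tile that is guaranteed to be in
// bounds. Threadgroups not in the last (possibly partial) row band of tiles
// also store the remaining rows; likewise for the column band, and for the
// corner when both conditions hold.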
{
uint2 matrix_origin = uint2(B_offset.x, A_offset.y);
auto C_src = simdgroup_matrix_storage<float>::apply_offset(
C, N, matrix_origin);
store_accumulator(sram, C_src, false, false);
const uint M_edge_floor = M - M % M_simd;
const uint N_edge_floor = N - N % N_simd;
if (matrix_origin.y < M_edge_floor) {
store_accumulator(sram, C_src, true, false);
}
if (matrix_origin.x < N_edge_floor) {
store_accumulator(sram, C_src, false, true);
if (matrix_origin.y < M_edge_floor) {
store_accumulator(sram, C_src, true, true);
}
}
}
}
"""