Last active
May 31, 2024 17:29
-
-
Save philipturner/3bda14e876a635e73745c42f2eb240c8 to your computer and use it in GitHub Desktop.
Investigation of Float32 performance before and after dynamic caching on Apple GPUs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// Workspace.swift | |
// M4DeviceTesting | |
// | |
// Created by Philip Turner on 5/24/24. | |
// | |
import Metal | |
// Investigation of Float32 performance before and after dynamic caching on | |
// Apple GPUs. | |
// | |
// ========================================================================== // | |
// Results (Raw Data) | |
// ========================================================================== // | |
// | |
// AMX, Float32 | |
// | |
// M1 Max | |
// - problemSize = 64 | 419 GFLOPS | |
// - problemSize = 96 | 589 GFLOPS | |
// - problemSize = 128 | 724 GFLOPS | |
// - problemSize = 192 | 990 GFLOPS | |
// - problemSize = 256 | 1032 GFLOPS | |
// - problemSize = 384 | 1196 GFLOPS | |
// - problemSize = 512 | 1673 GFLOPS | |
// - problemSize = 640 | 2261 GFLOPS | |
// - problemSize = 768 | 2454 GFLOPS | |
// - problemSize = 896 | 2348 GFLOPS | |
// - problemSize = 1024 | 1516 GFLOPS | |
// - problemSize = 1152 | 2267 GFLOPS | |
// - problemSize = 1280 | 2016 GFLOPS | |
// | |
// M4 | |
// - problemSize = 64 | 524 GFLOPS | |
// - problemSize = 96 | 832 GFLOPS | |
// - problemSize = 128 | 1059 GFLOPS | |
// - problemSize = 192 | 1364 GFLOPS | |
// - problemSize = 256 | 1522 GFLOPS | |
// - problemSize = 384 | 1693 GFLOPS | |
// - problemSize = 512 | 1518 GFLOPS | |
// - problemSize = 640 | 1834 GFLOPS | |
// - problemSize = 768 | 1843 GFLOPS | |
// - problemSize = 896 | 1866 GFLOPS | |
// - problemSize = 1024 | 1578 GFLOPS | |
// - problemSize = 1152 | 1868 GFLOPS | |
// - problemSize = 1280 | 1880 GFLOPS | |
// | |
// MPS, Float32 | |
// | |
// M1 Max | |
// - problemSize = 64 | 23 GFLOPS | |
// - problemSize = 96 | 38 GFLOPS | |
// - problemSize = 128 | 113 GFLOPS | |
// - problemSize = 192 | 267 GFLOPS | |
// - problemSize = 256 | 482 GFLOPS | |
// - problemSize = 384 | 989 GFLOPS | |
// - problemSize = 512 | 2223 GFLOPS | |
// - problemSize = 640 | 4524 GFLOPS | |
// - problemSize = 768 | 5296 GFLOPS | |
// - problemSize = 896 | 6768 GFLOPS | |
// - problemSize = 1024 | 8045 GFLOPS | |
// - problemSize = 1152 | 7338 GFLOPS | |
// - problemSize = 1280 | 7466 GFLOPS | |
// | |
// M4 | |
// - problemSize = 64 | 40 GFLOPS | |
// - problemSize = 96 | 99 GFLOPS | |
// - problemSize = 128 | 203 GFLOPS | |
// - problemSize = 192 | 371 GFLOPS | |
// - problemSize = 256 | 456 GFLOPS | |
// - problemSize = 384 | 581 GFLOPS | |
// - problemSize = 512 | 1828 GFLOPS | |
// - problemSize = 640 | 2146 GFLOPS | |
// - problemSize = 768 | 2447 GFLOPS | |
// - problemSize = 896 | 3091 GFLOPS | |
// - problemSize = 1024 | 3123 GFLOPS | |
// - problemSize = 1152 | 3150 GFLOPS | |
// - problemSize = 1280 | 3130 GFLOPS | |
// | |
// MFA, Float32 | |
// | |
// M1 Max | |
// - problemSize = 64 | 44 GFLOPS (32x32x32, 128 threads/group) | |
// - problemSize = 96 | 109 GFLOPS (32x32x32, 128 threads/group) | |
// - problemSize = 128 | 214 GFLOPS (32x32x32, 128 threads/group) | |
// - problemSize = 192 | 495 GFLOPS (32x32x32, 128 threads/group) | |
// - problemSize = 256 | 883 GFLOPS (32x32x32, 128 threads/group) | |
// - problemSize = 384 | 2912 GFLOPS (32x32x32, 128 threads/group) | |
// - problemSize = 512 | 4091 GFLOPS (32x32x32, 128 threads/group) | |
// - problemSize = 640 | 6440 GFLOPS (32x32x32, 128 threads/group) | |
// - problemSize = 768 | 7017 GFLOPS (48x48x24, 128 threads/group) | |
// - problemSize = 896 | 7136 GFLOPS (48x48x24, 128 threads/group) | |
// - problemSize = 1024 | 6966 GFLOPS (48x48x24, 128 threads/group) | |
// - problemSize = 1152 | 8144 GFLOPS (48x48x24, 128 threads/group) | |
// - problemSize = 1280 | 7813 GFLOPS (48x48x24, 128 threads/group) | |
// | |
// M4 | |
// - problemSize = 64 | 39 GFLOPS (32x32x8, 32 threads/group) | |
// - problemSize = 96 | 32 GFLOPS (32x32x8, 32 threads/group) | |
// - problemSize = 128 | 94 GFLOPS (32x32x8, 32 threads/group) | |
// - problemSize = 192 | 364 GFLOPS (32x32x8, 32 threads/group) | |
// - problemSize = 256 | 654 GFLOPS (32x32x8, 32 threads/group) | |
// - problemSize = 384 | 1270 GFLOPS (32x32x8, 32 threads/group) | |
// - problemSize = 512 | 1626 GFLOPS (32x32x8, 32 threads/group) | |
// - problemSize = 640 | 1947 GFLOPS (32x32x8, 32 threads/group) | |
// - problemSize = 768 | 1955 GFLOPS (32x32x8, 32 threads/group) | |
// - problemSize = 896 | 2034 GFLOPS (32x32x8, 32 threads/group) | |
// - problemSize = 1024 | 2078 GFLOPS (32x32x8, 32 threads/group) | |
// - problemSize = 1152 | 2119 GFLOPS (32x32x8, 32 threads/group) | |
// - problemSize = 1280 | 2129 GFLOPS (32x32x8, 32 threads/group) | |
// | |
// MFA without async copy, Float32 | |
// - All configurations have 32 threads/group. | |
// - This gives a misleading overestimate of M1 performance. Without async
// copies, performance can be highly inconsistent, especially with smaller
// matrix sizes and odd/prime matrix sizes.
// | |
// M1 Max | |
// - problemSize = 64 | 28 GFLOPS (32x32x8) | |
// - problemSize = 96 | 70 GFLOPS (32x32x8) | |
// - problemSize = 128 | 129 GFLOPS (32x32x8) | |
// - problemSize = 192 | 745 GFLOPS (32x32x8) | |
// - problemSize = 256 | 1369 GFLOPS (32x32x8) | |
// - problemSize = 384 | 2907 GFLOPS (32x32x8) | |
// - problemSize = 512 | 5229 GFLOPS (32x32x8) | |
// - problemSize = 640 | 5519 GFLOPS (32x32x8) | |
// - problemSize = 768 | 6170 GFLOPS (48x48x8) | |
// - problemSize = 896 | 6754 GFLOPS (48x48x8) | |
// - problemSize = 1024 | 7523 GFLOPS (48x48x8) | |
// - problemSize = 1152 | 7682 GFLOPS (48x48x8) | |
// - problemSize = 1280 | 7755 GFLOPS (48x48x8) | |
// | |
// Galactic arithmetic intensities, too large for general use: | |
// - problemSize = 2048 | 8819 GFLOPS (48x48x8) | |
// - problemSize = 3072 | 8909 GFLOPS (48x48x8) | |
// | |
// M4 | |
// - problemSize = 64 | 35 GFLOPS (32x32x8) | |
// - problemSize = 96 | 84 GFLOPS (32x32x8) | |
// - problemSize = 128 | 154 GFLOPS (32x32x8) | |
// - problemSize = 192 | 352 GFLOPS (32x32x8) | |
// - problemSize = 256 | 577 GFLOPS (32x32x8) | |
// - problemSize = 384 | 1718 GFLOPS (32x32x8) | |
// - problemSize = 512 | 2440 GFLOPS (32x32x8) | |
// - problemSize = 640 | 3013 GFLOPS (32x32x8) | |
// - problemSize = 768 | 3078 GFLOPS (32x32x8) | |
// - problemSize = 896 | 3074 GFLOPS (32x32x8) | |
// - problemSize = 1024 | 3104 GFLOPS (32x32x8) | |
// - problemSize = 1152 | 3164 GFLOPS (32x32x8) | |
// - problemSize = 1280 | 3168 GFLOPS (32x32x8) | |
// | |
// Fine-tuning the place where the "robust" 32x32 size underperforms MPS: | |
// - problemSize = 128 | 642 GFLOPS (24x24x8) | |
// - problemSize = 128 | 842 GFLOPS (16x16x8) | |
// | |
// ========================================================================== // | |
// Discussion | |
// ========================================================================== // | |
// | |
// For reasonable panel sizes in linear algebra (latency-bound, formally O(n^3) | |
// matrix factorizations), where arithmetic intensity ≤100 operations per | |
// scalar read from memory. | |
// | |
// M1 Max: | |
// - AMX: 724 GFLOPS | |
// - MFA: 214 GFLOPS | |
// - MFA: 129 GFLOPS (without async copy) | |
// - MPS: 113 GFLOPS | |
// | |
// M4: | |
// - AMX: 1059 GFLOPS | |
// - MFA: 842 GFLOPS (without async copy, fine-tuned) | |
// - MFA: 154 GFLOPS (without async copy) | |
// - MPS: 99 GFLOPS | |
// - MFA: 94 GFLOPS | |
// | |
// For reasonable matrices in AI, where arithmetic intensity ≤1000 operations | |
// per scalar read from memory. | |
// | |
// M1 Max: | |
// - MFA: 8144 GFLOPS | |
// - MPS: 8045 GFLOPS | |
// - MFA: 7755 GFLOPS (without async copy) | |
// - AMX: 2454 GFLOPS | |
// | |
// M4: | |
// - MFA: 3168 GFLOPS (without async copy) | |
// - MPS: 3150 GFLOPS | |
// - MFA: 2129 GFLOPS | |
// - AMX: 1880 GFLOPS | |
// | |
// For galactic matrices (every dimension of the matrix multiplication >>1000, | |
// and typically square), which are overemphasized in benchmarks. | |
// | |
// M1 Max: | |
// - MFA: 8909 GFLOPS (without async copy) | |
// - MPS: 8480 GFLOPS | |
// Benchmarks and validates a square Float32 matrix multiplication, C = A * B.
//
// - A is the 2nd-order periodic Laplacian: -2 on the diagonal and 1 on the
//   two (wrapped-around) neighboring columns of every row.
// - B is filled with uniform random numbers in [0, 1).
// - The product is computed by the `sgemm` kernel compiled from the `GEMM`
//   shader source (or by the disabled CPU reference loop), 15 trials of 20
//   back-to-back dispatches each. Every trial prints its GPU latency and the
//   resulting throughput.
// - Afterwards each entry of C is compared against the closed-form result
//   B[m-1][n] - 2 * B[m][n] + B[m+1][n]; mismatches above 1e-5 are printed.
func runApplication() {
  print("Hello, console.")

  let problemSize: Int = 128
  let elementCount = problemSize * problemSize
  var A = [Float](repeating: .zero, count: elementCount)
  var B = [Float](repeating: .zero, count: elementCount)
  var C = [Float](repeating: .zero, count: elementCount)

  // Initialize A as the 2nd-order periodic Laplacian.
  for diagonalID in 0..<problemSize {
    let rowStart = diagonalID * problemSize
    A[rowStart + diagonalID] = -2

    // Adding `problemSize` before the modulo keeps the operand non-negative
    // when wrapping around the left edge.
    let leftColumnID = (diagonalID + problemSize - 1) % problemSize
    A[rowStart + leftColumnID] = 1

    let rightColumnID = (diagonalID + problemSize + 1) % problemSize
    A[rowStart + rightColumnID] = 1
  }

  // Initialize B to random numbers.
  for address in B.indices {
    B[address] = Float.random(in: 0..<1)
  }

  // Multiply A with B.
#if false
  // CPU reference implementation (row-major triple loop).
  for m in 0..<problemSize {
    for n in 0..<problemSize {
      var dotProduct: Float = .zero
      for k in 0..<problemSize {
        // Read from the input matrices.
        let lhsValue = A[m * problemSize + k]
        let rhsValue = B[k * problemSize + n]
        dotProduct += lhsValue * rhsValue
      }

      // Write to the output matrix.
      C[m * problemSize + n] = dotProduct
    }
  }
#else
  do {
    // The kernel reads and writes whole 8x8 tiles without bounds checks (the
    // shader source itself warns that M and N must be divisible by 8), so
    // reject sizes that would make it touch memory outside the buffers.
    precondition(
      problemSize > 0 && problemSize % 8 == 0,
      "problemSize must be a positive multiple of 8.")

    // Initialize the context.
    let context = MTLContext()
    let library = try! context.device.makeLibrary(source: GEMM, options: nil)

    // Set the function constants.
    //
    // Metal reads exactly as many bytes as the declared `MTLDataType`
    // occupies. The backing variables therefore use the matching fixed-width
    // types (UInt32 for `.uint`, UInt16 for `.ushort`) instead of `Int`,
    // whose 8-byte representation only happens to work on little-endian
    // hosts.
    let constants = MTLFunctionConstantValues()
    var M = UInt32(problemSize)
    var N = UInt32(problemSize)
    var K = UInt32(problemSize)
    var transpose: Bool = false
    constants.setConstantValue(&M, type: .uint, index: 0)
    constants.setConstantValue(&N, type: .uint, index: 1)
    constants.setConstantValue(&K, type: .uint, index: 2)
    constants.setConstantValue(&transpose, type: .bool, index: 10)
    constants.setConstantValue(&transpose, type: .bool, index: 11)

    // Block of the output matrix handled by one threadgroup.
    var M_simd: UInt16 = 32
    var N_simd: UInt16 = 32
    constants.setConstantValue(&M_simd, type: .ushort, index: 200)
    constants.setConstantValue(&N_simd, type: .ushort, index: 201)

    let function = try! library.makeFunction(
      name: "sgemm", constantValues: constants)
    let pipeline = try! context.device
      .makeComputePipelineState(function: function)

    // Number of blocks of size `granularity` needed to cover `target`.
    func ceilDivide(target: Int, granularity: UInt16) -> Int {
      (target + Int(granularity) - 1) / Int(granularity)
    }
    let gridSize = MTLSize(
      width: ceilDivide(target: problemSize, granularity: N_simd),
      height: ceilDivide(target: problemSize, granularity: M_simd),
      depth: 1)
    let groupSize = MTLSize(
      width: 32,
      height: 1,
      depth: 1)

    // Create the buffers.
    func createBuffer(data: [Float]) -> MTLBuffer {
      // Allocate enough memory to store everything in Float32.
      let bufferSize = elementCount * MemoryLayout<Float>.stride
      let buffer = context.device.makeBuffer(length: bufferSize)!

      // Copy the data into the buffer.
      let pointer = buffer.contents().assumingMemoryBound(to: Float.self)
      pointer.initialize(from: data, count: elementCount)

      // Return the buffer object.
      return buffer
    }
    let bufferA = createBuffer(data: A)
    let bufferB = createBuffer(data: B)
    let bufferC = createBuffer(data: C)

    // Profile the latency of matrix multiplication.
    for _ in 0..<15 {
      // Several identical dispatches share one command buffer, so the
      // measured interval covers `duplicatedCommandCount` multiplications.
      let duplicatedCommandCount: Int = 20

      // Execute the operation.
      let commandBuffer = context.commandQueue.makeCommandBuffer()!
      let encoder = commandBuffer.makeComputeCommandEncoder()!
      encoder.setComputePipelineState(pipeline)
      encoder.setBuffer(bufferA, offset: 0, index: 0)
      encoder.setBuffer(bufferB, offset: 0, index: 1)
      encoder.setBuffer(bufferC, offset: 0, index: 2)
      for _ in 0..<duplicatedCommandCount {
        encoder.dispatchThreadgroups(
          gridSize, threadsPerThreadgroup: groupSize)
      }
      encoder.endEncoding()
      commandBuffer.commit()
      commandBuffer.waitUntilCompleted()

      // Determine the time taken.
      let start = commandBuffer.gpuStartTime
      let end = commandBuffer.gpuEndTime
      let latency = end - start
      let latencyMicroseconds = Int(latency / 1e-6)

      // Determine the amount of work done: one multiply and one add per
      // (m, n, k) triple, for every duplicated dispatch.
      var operations = 2 * problemSize * problemSize * problemSize
      operations = operations * duplicatedCommandCount
      let gflops = Int(Double(operations) / Double(latency) / 1e9)

      // Report the results.
      print(latencyMicroseconds, "μs", gflops, "GFLOPS")
    }

    // Copy the results to C.
    do {
      let rawPointer = bufferC.contents()
      let castedPointer = rawPointer.assumingMemoryBound(to: Float.self)
      C = Array(UnsafeBufferPointer(start: castedPointer, count: elementCount))
    }
  }
#endif

  // Check the results.
  for m in 0..<problemSize {
    for n in 0..<problemSize {
      // Find the source row IDs.
      let leftRowID = (m + problemSize - 1) % problemSize
      let centerRowID = m
      let rightRowID = (m + problemSize + 1) % problemSize

      // Find the source values.
      let leftSource = B[leftRowID * problemSize + n]
      let centerSource = B[centerRowID * problemSize + n]
      let rightSource = B[rightRowID * problemSize + n]

      // Find the expected and actual values.
      let expected = leftSource - 2 * centerSource + rightSource
      let actual = C[m * problemSize + n]

      // Report the results.
      let error = (expected - actual).magnitude
      if error > 1e-5 {
        print("error: \(error) / ~1.000")
      }
    }
  }
}
// MARK: - Utilities | |
#if false | |
// Describes one general matrix-matrix multiplication (GEMM).
//
// The optional fields start out as `nil` and must all be assigned before the
// descriptor is handed to `GEMM.init(descriptor:)`; the remaining fields have
// usable defaults.
struct GEMMDescriptor {
  // M, N, and K.
  var dimension: SIMD3<Int>? = nil

  // A, LDA, and TRANSA.
  var leftOperand: UnsafePointer<Float>? = nil
  var leftOperandStride: Int? = nil
  var leftTransposeState: Character = "N"

  // B, LDB, and TRANSB.
  var rightOperand: UnsafePointer<Float>? = nil
  var rightOperandStride: Int? = nil
  var rightTransposeState: Character = "N"

  // C and LDC.
  var accumulator: UnsafeMutablePointer<Float>? = nil
  var accumulatorStride: Int? = nil

  // 'alpha' (scales the product) and 'beta' (scales the accumulator).
  var productScale: Float = 1
  var accumulatorScale: Float = 0

  @_transparent
  init() {}
}
// Descriptor-based wrapper over `sgemm_` from Accelerate.
//
// The multiplication runs inside the initializer; the constructed value holds
// no state. Integer arguments are narrowed to `Int32` with
// `truncatingIfNeeded`, so out-of-range values are wrapped, not rejected.
struct GEMM {
  // In typical API usage, one does not access the object's properties.
  @discardableResult
  @_transparent
  init(descriptor: GEMMDescriptor) {
    // Every optional field of the descriptor must have been assigned.
    guard
      let shape = descriptor.dimension,
      let lhsPointer = descriptor.leftOperand,
      let lhsStride = descriptor.leftOperandStride,
      let rhsPointer = descriptor.rightOperand,
      let rhsStride = descriptor.rightOperandStride,
      let outputPointer = descriptor.accumulator,
      let outputStride = descriptor.accumulatorStride
    else {
      fatalError("Descriptor not complete.")
    }

    // `sgemm_` receives every scalar by reference, so each one needs its own
    // mutable storage.
    var transposeA = CChar(descriptor.leftTransposeState.asciiValue!)
    var transposeB = CChar(descriptor.rightTransposeState.asciiValue!)
    var blasM = Int32(truncatingIfNeeded: shape[0])
    var blasN = Int32(truncatingIfNeeded: shape[1])
    var blasK = Int32(truncatingIfNeeded: shape[2])
    var alpha = descriptor.productScale
    var beta = descriptor.accumulatorScale
    var leadingDimensionA = Int32(truncatingIfNeeded: lhsStride)
    var leadingDimensionB = Int32(truncatingIfNeeded: rhsStride)
    var leadingDimensionC = Int32(truncatingIfNeeded: outputStride)

    sgemm_(
      &transposeA,
      &transposeB,
      &blasM,
      &blasN,
      &blasK,
      &alpha,
      lhsPointer, &leadingDimensionA,
      rhsPointer, &leadingDimensionB,
      &beta,
      outputPointer, &leadingDimensionC)
  }
}
// Reference code for benchmarking CPU performance.
//
// Runs the multiplication 10 times through the `GEMM` wrapper and prints the
// latency and throughput of each run. Relies on `A`, `B`, `C`, and
// `problemSize` being visible in the enclosing scope.
func testCPUPerformance() {
  A.withUnsafeBufferPointer { bufferA in
    let pointerA = bufferA.baseAddress!
    B.withUnsafeBufferPointer { bufferB in
      let pointerB = bufferB.baseAddress!
      C.withUnsafeMutableBufferPointer { bufferC in
        let pointerC = bufferC.baseAddress!

        // The arrays are row-major while LAPACK expects column-major data,
        // so the product is requested in transposed form:
        //   C = A B  ->  C^T = B^T A^T
        //
        // A row-major matrix reinterpreted as column-major is already its
        // own transpose, so no explicit transposition is necessary:
        //   B^T -> B
        //   A^T -> A
        //   C^T = B A
        var gemmDesc = GEMMDescriptor()
        gemmDesc.dimension = SIMD3(repeating: problemSize)
        gemmDesc.leftOperand = pointerB
        gemmDesc.leftOperandStride = problemSize
        gemmDesc.rightOperand = pointerA
        gemmDesc.rightOperandStride = problemSize
        gemmDesc.accumulator = pointerC
        gemmDesc.accumulatorStride = problemSize

        // Profile the latency of the matrix multiplication.
        let trialCount = 10
        for _ in 0..<trialCount {
          let start = CACurrentMediaTime()
          GEMM(descriptor: gemmDesc)
          let end = CACurrentMediaTime()

          let latency = end - start
          let latencyMicroseconds = Int(latency / 1e-6)

          // One multiply and one add per (m, n, k) triple.
          let operations = 2 * problemSize * problemSize * problemSize
          let gflops = Int(Double(operations) / Double(latency) / 1e9)
          print(latencyMicroseconds, "μs", gflops, "GFLOPS")
        }
      }
    }
  }
}
#endif | |
// Bundles the Metal device with a command queue created from it.
//
// Initialization traps if the system has no default Metal device or if the
// command queue cannot be created.
struct MTLContext {
  var device: MTLDevice
  var commandQueue: MTLCommandQueue

  init() {
    let defaultDevice = MTLCreateSystemDefaultDevice()!
    self.device = defaultDevice
    self.commandQueue = defaultDevice.makeCommandQueue()!
  }
}
#if false | |
// Inputs required to build an `MPSMatrixStorage`. Every field starts out as
// `nil` and must be assigned before the descriptor is used.
struct MPSMatrixStorageDescriptor {
  // Supplies the Metal device that allocates the buffer.
  var context: MTLContext? = nil

  // Matrix entries copied into the buffer (`problemSize * problemSize` of
  // them are read).
  var data: [Float]? = nil

  // Number of rows, and also of columns, of the square matrix.
  var problemSize: Int? = nil
}
// A square Float32 matrix stored in a Metal buffer, together with the
// `MPSMatrix` view over that buffer.
struct MPSMatrixStorage {
  var buffer: MTLBuffer
  var matrix: MPSMatrix

  // Allocates the buffer, copies `descriptor.data` into it, and wraps it in an
  // `MPSMatrix`. Traps if any descriptor field is missing.
  init(descriptor: MPSMatrixStorageDescriptor) {
    guard
      let context = descriptor.context,
      let data = descriptor.data,
      let problemSize = descriptor.problemSize
    else {
      fatalError("Descriptor was invalid.")
    }

    // Allocate enough memory to store everything in Float32 (4 bytes each).
    let elementCount = problemSize * problemSize
    let storage = context.device.makeBuffer(length: elementCount * 4)!

    // Copy the data into the buffer.
    storage.contents()
      .assumingMemoryBound(to: Float.self)
      .initialize(from: data, count: elementCount)

    // Describe the layout: tightly packed rows of Float32.
    let matrixDesc = MPSMatrixDescriptor(
      rows: problemSize,
      columns: problemSize,
      rowBytes: problemSize * 4, // Float32
      dataType: .float32)

    // Initialize the stored properties.
    buffer = storage
    matrix = MPSMatrix(buffer: storage, descriptor: matrixDesc)
  }
}
// Reference code for benchmarking MPS performance.
//
// Encodes the multiplication 10 times per command buffer, runs 10 such
// command buffers, prints the latency and throughput of each, and finally
// copies the result back into `C`. Relies on `A`, `B`, `C`, and `problemSize`
// being visible in the enclosing scope.
func testMPSPerformance() {
  // Initialize the context.
  let context = MTLContext()

  // Initialize the matrices. The same descriptor is reused; only the data
  // changes between the three operands.
  var matrixStorageDesc = MPSMatrixStorageDescriptor()
  matrixStorageDesc.context = context
  matrixStorageDesc.problemSize = problemSize

  matrixStorageDesc.data = A
  let matrixA = MPSMatrixStorage(descriptor: matrixStorageDesc)

  matrixStorageDesc.data = B
  let matrixB = MPSMatrixStorage(descriptor: matrixStorageDesc)

  matrixStorageDesc.data = C
  let matrixC = MPSMatrixStorage(descriptor: matrixStorageDesc)

  // Initialize the multiplication object.
  let multiplication = MPSMatrixMultiplication(
    device: context.device,
    resultRows: problemSize,
    resultColumns: problemSize,
    interiorColumns: problemSize)
  let zeroOrigin = MTLOrigin(x: 0, y: 0, z: 0)
  multiplication.leftMatrixOrigin = zeroOrigin
  multiplication.rightMatrixOrigin = zeroOrigin
  multiplication.resultMatrixOrigin = zeroOrigin

  // Profile the latency of the matrix multiplication.
  let trialCount = 10
  for _ in 0..<trialCount {
    let duplicatedCommandCount: Int = 10

    // Execute the operation.
    let commandBuffer = context.commandQueue.makeCommandBuffer()!
    for _ in 0..<duplicatedCommandCount {
      multiplication.encode(
        commandBuffer: commandBuffer,
        leftMatrix: matrixA.matrix,
        rightMatrix: matrixB.matrix,
        resultMatrix: matrixC.matrix)
    }
    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()

    // Determine the time taken.
    let latency = commandBuffer.gpuEndTime - commandBuffer.gpuStartTime
    let latencyMicroseconds = Int(latency / 1e-6)

    // Determine the amount of work done.
    let operationsPerCommand = 2 * problemSize * problemSize * problemSize
    let operations = operationsPerCommand * duplicatedCommandCount
    let gflops = Int(Double(operations) / Double(latency) / 1e9)

    // Report the results.
    print(latencyMicroseconds, "μs", gflops, "GFLOPS")
  }

  // Copy the results to C.
  do {
    let castedPointer = matrixC.buffer.contents()
      .assumingMemoryBound(to: Float.self)
    for address in 0..<(problemSize * problemSize) {
      C[address] = castedPointer[address]
    }
  }
}
#endif | |
// Metal Shading Language source for the `metal::simdgroup_matrix_storage<T>`
// helper type. The text is spliced into the `GEMM` shader source through
// string interpolation, so it is compiled at runtime, not by Xcode.
//
// What the embedded header provides (only when the Metal compiler defines
// `__HAVE_SIMDGROUP_MATRIX__`):
// - `thread_elements()`: access to the `vec<T, 2>` that this thread holds.
// - `offset(lane)`: maps a thread's index within its simdgroup to an
//   (x = N, y = M) position, returned as `ushort2(N_in_simd, M_in_simd)`.
// - `apply_offset(...)`: advances a device/threadgroup pointer to a matrix
//   origin, optionally treating the matrix as transposed.
// - `load*` / `store*`: move the thread's two elements between memory and the
//   storage. Per the warning inside the source, these assume the X dimension
//   is divisible by 2.
// - `multiply(a, b, accumulate)`: wraps
//   `__metal_simdgroup_matrix_8x8_multiply_accumulate`, zeroing the thread's
//   elements first when `accumulate` is false.
//
// NOTE(review): the contents of the literal are program text. Any edit,
// including to the comments inside it, changes what is handed to the Metal
// compiler.
let metalSimdgroupMatrixStorage: String = """ | |
// -*- Metal -*- | |
//===-- metal_simdgroup_matrix_storage ------------------------------------===// | |
// Copyright (c) 2023 Philip Turner. See MIT LICENSE | |
//===----------------------------------------------------------------------===// | |
#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE | |
#define __METAL_SIMDGROUP_MATRIX_STORAGE | |
// Contains C++ symbols accessible to a developer through automatic code | |
// completion in Xcode 14.2. Formatted with the same style as the Metal Standard | |
// Library for consistency with other Metal code. | |
#if defined(__HAVE_SIMDGROUP_MATRIX__) | |
#pragma METAL internals : enable | |
namespace metal | |
{ | |
template <typename T> | |
struct simdgroup_matrix_storage { | |
typedef vec<T, 64> storage_type; | |
storage_type t; | |
METAL_FUNC thread vec<T, 2>* thread_elements() thread { | |
return reinterpret_cast<thread vec<T, 2>*>(&t); | |
} | |
METAL_FUNC simdgroup_matrix_storage() thread = default; | |
METAL_FUNC simdgroup_matrix_storage(vec<T, 2> thread_elements) thread { | |
*(this->thread_elements()) = thread_elements; | |
} | |
METAL_FUNC static ushort2 offset(ushort thread_index_in_simdgroup) { | |
// https://patents.google.com/patent/US11256518B2 | |
ushort lane_id = thread_index_in_simdgroup; | |
ushort quad_id = lane_id / 4; | |
constexpr ushort QUADRANT_SPAN_M = 4; | |
constexpr ushort THREADS_PER_QUADRANT = 8; | |
ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M; | |
ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2); | |
ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant; | |
ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4 | |
ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2 | |
ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant; | |
return ushort2(N_in_simd, M_in_simd); | |
} | |
METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) { | |
if (transpose_matrix) { | |
return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y; | |
} else { | |
return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x; | |
} | |
} | |
METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { | |
if (transpose_matrix) { | |
return src + matrix_origin.x * elements_per_row + matrix_origin.y; | |
} else { | |
return src + matrix_origin.y * elements_per_row + matrix_origin.x; | |
} | |
} | |
// WARNING: All load and store functions assume the X dimension is divisible by 2. | |
METAL_FUNC void load(const device T *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { | |
if (transpose_matrix) { | |
*(thread_elements()) = vec<T, 2>(src[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y], src[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y]); | |
} else { | |
*(thread_elements()) = *reinterpret_cast<const device vec<T, 2>*>(src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x); | |
} | |
} | |
METAL_FUNC void load(const threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { | |
if (transpose_matrix) { | |
*(thread_elements()) = vec<T, 2>(src[matrix_origin.x * elements_per_row + matrix_origin.y], src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y]); | |
} else { | |
*(thread_elements()) = *reinterpret_cast<const threadgroup vec<T, 2>*>(src + matrix_origin.y * elements_per_row + matrix_origin.x); | |
} | |
} | |
METAL_FUNC void load_first(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { | |
if (transpose_matrix) { | |
thread_elements()[0][0] = src[matrix_origin.x * elements_per_row + matrix_origin.y]; | |
} else { | |
thread_elements()[0][0] = src[matrix_origin.y * elements_per_row + matrix_origin.x]; | |
} | |
} | |
METAL_FUNC void load_second(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { | |
if (transpose_matrix) { | |
thread_elements()[0][1] = src[matrix_origin.x * elements_per_row + matrix_origin.y]; | |
} else { | |
thread_elements()[0][1] = src[matrix_origin.y * elements_per_row + matrix_origin.x]; | |
} | |
} | |
METAL_FUNC void store(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { | |
if (transpose_matrix) { | |
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0]; | |
dst[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y] = thread_elements()[0][1]; | |
} else { | |
*reinterpret_cast<device vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements()); | |
} | |
} | |
METAL_FUNC void store_first(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { | |
if (transpose_matrix) { | |
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0]; | |
} else { | |
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][0]; | |
} | |
} | |
METAL_FUNC void store_second(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { | |
if (transpose_matrix) { | |
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][1]; | |
} else { | |
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][1]; | |
} | |
} | |
METAL_FUNC void store(threadgroup T *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { | |
if (transpose_matrix) { | |
dst[matrix_origin.x * elements_per_row + matrix_origin.y] = thread_elements()[0][0]; | |
dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = thread_elements()[0][1]; | |
} else { | |
*reinterpret_cast<threadgroup vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements()); | |
} | |
} | |
template <typename U, typename V> | |
METAL_FUNC void multiply(simdgroup_matrix_storage<U> a, simdgroup_matrix_storage<V> b, bool accumulate = true) { | |
if (!accumulate) { | |
*(thread_elements()) = vec<T, 2>(0); | |
} | |
t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type()); | |
} | |
}; | |
} // namespace metal | |
#pragma METAL internals : disable | |
#endif | |
#endif // __METAL_SIMDGROUP_MATRIX_STORAGE | |
""" | |
// Metal Shading Language source for the `sgemm` compute kernel. It is compiled
// at runtime with `makeLibrary(source:options:)` and embeds the
// `metalSimdgroupMatrixStorage` header through string interpolation.
//
// Interface expected from the host, as declared inside the source:
// - Function constants 0, 1, 2 (`uint`): the matrix dimensions M, N, K.
// - Function constants 10, 11 (`bool`): whether A / B are transposed.
// - Function constants 200, 201 (`ushort`): `M_simd` / `N_simd`, the block of
//   C addressed per threadgroup position (`gid.y * M_simd`, `gid.x * N_simd`).
// - Buffers 0, 1, 2: the `float` matrices A, B, and C.
//
// Constraints visible in the source:
// - The kernel carries the warning "M and N must be divisible by 8".
// - The K loop advances in steps of 8 without a remainder path, so K
//   presumably needs to be a multiple of 8 as well — TODO confirm.
// - Lanes are identified only by `thread_index_in_simdgroup`; the host code in
//   `runApplication` dispatches 32 threads per threadgroup.
//
// NOTE(review): inside the source, the comment above `B_sram` reads
// `A_sram[8][N_simd]`; it presumably means `B_sram[8][N_simd]`. It is left
// untouched here because the literal is program text.
let GEMM = """ | |
// | |
// GEMM.metal | |
// MetalFlashAttention | |
// | |
// Created by Philip Turner on 6/23/23. | |
// | |
#include <metal_stdlib> | |
\(metalSimdgroupMatrixStorage) | |
using namespace metal; | |
// MARK: - Function Constants | |
// Dimensions of each matrix. | |
constant uint M [[function_constant(0)]]; | |
constant uint N [[function_constant(1)]]; | |
constant uint K [[function_constant(2)]]; | |
// Whether each matrix is transposed. | |
constant bool A_trans [[function_constant(10)]]; | |
constant bool B_trans [[function_constant(11)]]; | |
constant uint A_leading_dim = A_trans ? M : K; | |
constant uint B_leading_dim = B_trans ? K : N; | |
constant ushort M_simd [[function_constant(200)]]; | |
constant ushort N_simd [[function_constant(201)]]; | |
// Elide work on the edge when matrix dimension < SRAM block dimension. | |
constant ushort M_modulo = (M % M_simd == 0) ? M_simd : (M % M_simd); | |
constant ushort N_modulo = (N % N_simd == 0) ? N_simd : (N % N_simd); | |
constant ushort M_padded = (M < M_simd) ? (M_modulo + 7) / 8 * 8 : M_simd; | |
constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd; | |
constant ushort A_sram_length = M_simd / 8; | |
constant ushort B_sram_length = N_simd / 8; | |
constant ushort A_sram_offset = 0; | |
constant ushort B_sram_offset = A_sram_offset + A_sram_length; | |
constant ushort C_sram_offset = B_sram_offset + B_sram_length; | |
// MARK: - Utilities | |
template <typename T> | |
METAL_FUNC thread simdgroup_matrix_storage<T>* A_sram( | |
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin | |
) { | |
// A_sram[M_simd][8] | |
return sram + A_sram_offset + (matrix_origin.y / 8) * (8 / 8) + (matrix_origin.x / 8); | |
} | |
template <typename T> | |
METAL_FUNC thread simdgroup_matrix_storage<T>* B_sram( | |
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin | |
) { | |
// A_sram[8][N_simd] | |
return sram + B_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8); | |
} | |
template <typename T> | |
METAL_FUNC thread simdgroup_matrix_storage<T>* C_sram( | |
thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin | |
) { | |
// C_sram[M_simd][N_simd] | |
return sram + C_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8); | |
} | |
template <typename T> | |
METAL_FUNC void store_accumulator(thread simdgroup_matrix_storage<T> *sram, | |
device T *C, bool m_is_edge, bool n_is_edge) | |
{ | |
const ushort m_start = (m_is_edge) ? M_modulo : 0; | |
const ushort n_start = (n_is_edge) ? N_modulo : 0; | |
const ushort m_end = (m_is_edge) ? M_simd : M_modulo; | |
const ushort n_end = (n_is_edge) ? N_simd : N_modulo; | |
#pragma clang loop unroll(full) | |
for (ushort m = m_start; m < m_end; m += 8) { | |
#pragma clang loop unroll(full) | |
for (ushort n = n_start; n < n_end; n += 8) { | |
ushort2 origin(n, m); | |
C_sram(sram, origin)->store(C, N, origin); | |
} | |
} | |
} | |
// MARK: - Kernels | |
kernel void sgemm(device float *A [[buffer(0)]], | |
device float *B [[buffer(1)]], | |
device float *C [[buffer(2)]], | |
uint3 gid [[threadgroup_position_in_grid]], | |
ushort lane_id [[thread_index_in_simdgroup]]) | |
{ | |
simdgroup_matrix_storage<float> sram[1024]; | |
ushort2 offset_in_simd = simdgroup_matrix_storage<float>::offset(lane_id); | |
ushort2 A_offset(0, gid.y * M_simd); | |
ushort2 B_offset(gid.x * N_simd, 0); | |
A_offset += offset_in_simd; | |
B_offset += offset_in_simd; | |
#pragma clang loop unroll(full) | |
for (ushort m = 0; m < M_padded; m += 8) { | |
#pragma clang loop unroll(full) | |
for (ushort n = 0; n < N_padded; n += 8) { | |
*C_sram(sram, ushort2(n, m)) = simdgroup_matrix_storage<float>(0); | |
} | |
} | |
for (uint k = 0; k < K; k += 8) { | |
auto A_src = simdgroup_matrix_storage<float>::apply_offset( | |
A, A_leading_dim, uint2(A_offset), A_trans); | |
auto B_src = simdgroup_matrix_storage<float>::apply_offset( | |
B, B_leading_dim, uint2(B_offset), B_trans); | |
#pragma clang loop unroll(full) | |
for (ushort m = 0; m < M_padded; m += 8) { | |
ushort2 origin(0, m); | |
A_sram(sram, origin)->load(A_src, A_leading_dim, origin, A_trans); | |
} | |
#pragma clang loop unroll(full) | |
for (ushort n = 0; n < N_padded; n += 8) { | |
ushort2 origin(n, 0); | |
B_sram(sram, origin)->load(B_src, B_leading_dim, origin, B_trans); | |
} | |
#pragma clang loop unroll(full) | |
for (ushort m = 0; m < M_padded; m += 8) { | |
auto A = A_sram(sram, ushort2(0, m)); | |
#pragma clang loop unroll(full) | |
for (ushort n = 0; n < N_padded; n += 8) { | |
auto B = B_sram(sram, ushort2(n, 0)); | |
auto C = C_sram(sram, ushort2(n, m)); | |
C->multiply(*A, *B); | |
} | |
} | |
A_offset.x += 8; | |
B_offset.y += 8; | |
} | |
// WARNING: M and N must be divisible by 8. | |
{ | |
uint2 matrix_origin = uint2(B_offset.x, A_offset.y); | |
auto C_src = simdgroup_matrix_storage<float>::apply_offset( | |
C, N, matrix_origin); | |
store_accumulator(sram, C_src, false, false); | |
const uint M_edge_floor = M - M % M_simd; | |
const uint N_edge_floor = N - N % N_simd; | |
if (matrix_origin.y < M_edge_floor) { | |
store_accumulator(sram, C_src, true, false); | |
} | |
if (matrix_origin.x < N_edge_floor) { | |
store_accumulator(sram, C_src, false, true); | |
if (matrix_origin.y < M_edge_floor) { | |
store_accumulator(sram, C_src, true, true); | |
} | |
} | |
} | |
} | |
""" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment