//
// ContentView.swift
// SIMDFuturesA15
//
// Created by Philip Turner on 6/9/23.
//
import SwiftUI
import Foundation
import QuartzCore
import MetalPerformanceShadersGraph
import Accelerate
struct ContentView: View {
  var body: some View {
    VStack {
      Image(systemName: "globe")
        .imageScale(.large)
        .foregroundStyle(.tint)
      Text({ () -> String in
        testMFALibrary()
        exit(0)
        return "Hello, world!"
      }())
    }
    .padding()
  }
}
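
// The Text closure above runs the benchmark while SwiftUI constructs the view,
// then terminates the process with exit(0); the returned "Hello, world!"
// string is never actually displayed.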

// Load and profile the MFA library, and compare it against MPSGraph on square
// matrices from 256 up to 4096 per side, at each precision. Report the number
// of GFLOPS achieved.
var matrixDim: Int = 2048
var numTrials: Int = 5
var numRepetitionsANE: Int = 5
let verbose = false
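
// Throughput below is computed as GFLOPS = 2 * M * N * K / (seconds * 1e9).
// For example, a 2048x2048x2048 GEMM that finishes in 2 ms reports
// 2 * 2048^3 / (0.002 * 1e9) ≈ 8590 GFLOPS.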
func testMFALibrary() {
  matrixDim = 256
  print()
  print("GEMM dimensions: \(matrixDim)x\(matrixDim)x\(matrixDim)")
  srand48(0)
  profile(useHalf: true, useANE: true, forceRunSync: true)

  matrixDim = 512
  print()
  print("GEMM dimensions: \(matrixDim)x\(matrixDim)x\(matrixDim)")
  srand48(0)
  profile(useHalf: true, useANE: true, forceRunSync: true)

  matrixDim = 768
  print()
  print("GEMM dimensions: \(matrixDim)x\(matrixDim)x\(matrixDim)")
  srand48(0)
  profile(useHalf: true, useANE: true, forceRunSync: true)

  matrixDim = 1024
  print()
  print("GEMM dimensions: \(matrixDim)x\(matrixDim)x\(matrixDim)")
  srand48(0)
  profile(useHalf: false, useANE: false, forceRunSync: false)
  srand48(0)
  profile(useHalf: true, useANE: false, forceRunSync: false)
  srand48(0)
  profile(useHalf: true, useANE: true, forceRunSync: true)

  matrixDim = 1024 + 256
  print()
  print("GEMM dimensions: \(matrixDim)x\(matrixDim)x\(matrixDim)")
  srand48(0)
  profile(useHalf: true, useANE: true, forceRunSync: true)

  matrixDim = 1024 + 512
  print()
  print("GEMM dimensions: \(matrixDim)x\(matrixDim)x\(matrixDim)")
  srand48(0)
  profile(useHalf: true, useANE: true, forceRunSync: true)

  matrixDim = 1024 + 512 + 256
  print()
  print("GEMM dimensions: \(matrixDim)x\(matrixDim)x\(matrixDim)")
  srand48(0)
  profile(useHalf: true, useANE: true, forceRunSync: true)

  matrixDim = 2048
  print()
  print("GEMM dimensions: \(matrixDim)x\(matrixDim)x\(matrixDim)")
  srand48(0)
  profile(useHalf: false, useANE: false, forceRunSync: false)
  srand48(0)
  profile(useHalf: true, useANE: false, forceRunSync: false)
  srand48(0)
  profile(useHalf: true, useANE: true, forceRunSync: true)

  matrixDim = 4096
  numTrials = 3
  numRepetitionsANE = 2
  print()
  print("GEMM dimensions: \(matrixDim)x\(matrixDim)x\(matrixDim)")
  srand48(0)
  profile(useHalf: true, useANE: true, forceRunSync: true)
}
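
// A minimal sketch (never called) of how the sweep above could be expressed as
// a data-driven loop. `SweepConfig` and `runSweep` are hypothetical helpers,
// not part of the original benchmark; note that the 4096 case above also
// lowers `numTrials` and `numRepetitionsANE`, which this sketch omits.
struct SweepConfig {
  var dim: Int
  var runs: [(useHalf: Bool, useANE: Bool, forceRunSync: Bool)]
}

func runSweep(_ configs: [SweepConfig]) {
  for config in configs {
    matrixDim = config.dim
    print()
    print("GEMM dimensions: \(matrixDim)x\(matrixDim)x\(matrixDim)")
    for run in config.runs {
      // Reseed so every backend sees identical random matrices.
      srand48(0)
      profile(
        useHalf: run.useHalf, useANE: run.useANE,
        forceRunSync: run.forceRunSync)
    }
  }
}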

func profile(useHalf: Bool, useANE: Bool, forceRunSync: Bool) {
  enum Variant {
    case blocked // fallback to ensure this works
    case streamk
  }
  let variant = Variant.streamk

  let device = MTLCreateSystemDefaultDevice()!
  let commandQueue = device.makeCommandQueue()!
  let libraryURL = Bundle.main.url(
    forResource: "libMetalFlashAttention", withExtension: "metallib")!
  let library = try! device.makeLibrary(URL: libraryURL)

  // MARK: - Universal Constants

  let constants = MTLFunctionConstantValues()
  var device_architecture: UInt16 = 1
  var device_generation: UInt16 = 8
  var device_cores: UInt16 = 5
  var atomic_algorithm: UInt16 = 1 // Nanite
  constants.setConstantValue(&device_architecture, type: .ushort, index: 1)
  constants.setConstantValue(&device_generation, type: .ushort, index: 2)
  constants.setConstantValue(&device_cores, type: .ushort, index: 3)
  constants.setConstantValue(&atomic_algorithm, type: .ushort, index: 4)

  // MARK: - Block Sizes

  // Compiled with 64x64x16, 2x2 simds/threadgroup.
  var MNK: UInt32 = .init(matrixDim)
  constants.setConstantValue(&MNK, type: .uint, index: 100)
  constants.setConstantValue(&MNK, type: .uint, index: 101)
  constants.setConstantValue(&MNK, type: .uint, index: 102)
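  // The same value is bound to indices 100-102 (presumably M, N, and K),
  // since this benchmark only uses square matrices.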

  // MARK: - Internal

  // For some reason, newer versions of Xcode require this to be set.
  var _simdgroup_use_threadgroup: Bool = true
  constants.setConstantValue(
    &_simdgroup_use_threadgroup, type: .bool, index: 300)
  // This should eventually be removed, so the user doesn't have to access the
  // private API.
  var simds_per_threadgroup: SIMD2<UInt16> = .init(2, 2)
  constants.setConstantValue(&simds_per_threadgroup, type: .ushort2, index: 301)
  var tg_size_mn: UInt16 = 4
  switch variant {
  case .blocked:
    constants.setConstantValue(&tg_size_mn, type: .ushort, index: 200)
  case .streamk:
    break
  }

  // MARK: - PSO Creation

  var name: String
  switch variant {
  case .blocked:
    name = useHalf ? "hgemm_blocked" : "sgemm_blocked"
  case .streamk:
    name = useHalf ? "hgemm" : "sgemm"
  }
  let function = try! library.makeFunction(
    name: name, constantValues: constants)
  let pipeline = try! device.makeComputePipelineState(function: function)

  let elementSize = useHalf ? 2 : 4
  var matrices: [MTLBuffer] = []
  for _ in 0..<3 {
    let matrixSize = matrixDim * matrixDim * elementSize
    let buffer = device.makeBuffer(length: matrixSize)!
    if useHalf {
      let casted = buffer.contents().assumingMemoryBound(to: Float16.self)
      matrixInitRandom(
        matrix: casted, dim1: .init(matrixDim), dim2: .init(matrixDim)
      ) {
        Float16($0)
      }
    } else {
      let casted = buffer.contents().assumingMemoryBound(to: Float32.self)
      matrixInitRandom(
        matrix: casted, dim1: .init(matrixDim), dim2: .init(matrixDim)
      ) {
        Float32($0)
      }
    }
    matrices.append(buffer)
  }

  // Profile GFLOPS on MFA.
  var numThreadgroups: MTLSize
  var threadgroupSize: MTLSize
  switch variant {
  case .blocked:
    let tg_width: Int = 16
    let execution_width = Int(tg_size_mn) * tg_width
    let grid_width = matrixDim / execution_width
    numThreadgroups = MTLSizeMake(grid_width, grid_width, 1)
    threadgroupSize = MTLSizeMake(tg_width, tg_width, 1)
  case .streamk:
    numThreadgroups = MTLSizeMake((matrixDim / 64), (matrixDim / 64), 1)
    threadgroupSize = MTLSizeMake(128, 1, 1)
  }
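  // For the streamk variant, each 128-thread threadgroup (four 32-wide SIMD
  // groups, matching the 2x2 simds/threadgroup above) covers one 64x64 output
  // tile; e.g. matrixDim = 2048 dispatches a 32x32 grid of threadgroups.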

  let precisionRepr = useHalf ? "f16" : "f32"
  let precisionReprMPS = useHalf ? (
    useANE ? "f16 (either ANE or GPU tensor cores)" : "f16") : "f32"

  var commandBuffers: [MTLCommandBuffer] = []
  for _ in 0..<numTrials {
    let commandBuffer = commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(matrices[0], offset: 0, index: 0)
    encoder.setBuffer(matrices[1], offset: 0, index: 1)
    encoder.setBuffer(matrices[2], offset: 0, index: 2)
    encoder.dispatchThreadgroups(
      numThreadgroups,
      threadsPerThreadgroup: threadgroupSize)
    encoder.endEncoding()
    commandBuffer.commit()
    commandBuffers.append(commandBuffer)
  }

  print()
  print("Metal FlashAttention: '\(precisionRepr)'")
  if verbose {
    print("Variant: '\(name)'")
  }
  var minLatency: Double = 1000
  for commandBuffer in commandBuffers {
    commandBuffer.waitUntilCompleted()
    let latencySeconds = commandBuffer.gpuEndTime - commandBuffer.gpuStartTime
    let latencyMicroseconds = Int(latencySeconds / 1e-6)
    if verbose {
      print("us: \(latencyMicroseconds)")
    }
    minLatency = min(latencySeconds, minLatency)
  }
  let floatOps = matrixDim * matrixDim * matrixDim * 2
  let flops = Double(floatOps) / minLatency
  let gflops = Int(flops / 1e9)
  print("GFLOPS: \(gflops)")

  let matrixSize = matrixDim * matrixDim * elementSize
  let outputsCopy = malloc(matrixSize)!
  defer { free(outputsCopy) }
  memcpy(outputsCopy, matrices[2].contents(), matrixSize)
  if verbose {
    // Probe a fixed element deep in the matrix, clamped so that small
    // matrices stay in bounds (a 256x256 matrix has only 65536 elements).
    let probe = min(65536, matrixDim * matrixDim - 1)
    if useHalf {
      let ptr = outputsCopy.assumingMemoryBound(to: Float16.self)
      print("First element of MFA output is currently: \(ptr[0])")
      print("Element \(probe) of MFA output is currently: \(ptr[probe])")
    } else {
      let ptr = outputsCopy.assumingMemoryBound(to: Float32.self)
      print("First element of MFA output is currently: \(ptr[0])")
      print("Element \(probe) of MFA output is currently: \(ptr[probe])")
    }
  }

  // Now, profile against MPSGraph and compare correctness.

  // MARK: - GEMM_initialize

  let graph = MPSGraph()

  class GPUContext {
    var device: MTLDevice
    var commandQueue: MTLCommandQueue
    // no library because we only call into MPS
    var graphDevice: MPSGraphDevice

    static let shared = GPUContext()

    init() {
      let devices = [MTLCreateSystemDefaultDevice()!]
      var bestDevice: MTLDevice?
      for device in devices {
        if device.supportsFamily(.metal3) {
          // if bestDevice != nil && device.isLowPower {
          //   continue
          // }
          bestDevice = device
        }
      }
      guard let bestDevice else {
        preconditionFailure("Could not locate a Metal 3 GPU.")
      }
      self.device = bestDevice
      self.commandQueue = device.makeCommandQueue()!
      self.graphDevice = MPSGraphDevice(mtlDevice: device)
    }
  }

  func reduceShape(_ array: [NSNumber]) -> Int {
    return array.reduce(Int(1)) { lhs, rhs in
      return lhs * rhs.intValue
    }
  }

  class Tensor {
    var tensor: MPSGraphTensor
    var tensorData: MPSGraphTensorData
    var buffer: MTLBuffer
    var bufferSize: Int

    init(tensor: MPSGraphTensor, dataType: MPSDataType) {
      self.tensor = tensor
      var dataSize: Int
      switch dataType {
      case .float16: dataSize = 2
      case .float32: dataSize = 4
      default:
        fatalError("Data type with raw value '\(dataType.rawValue)' unsupported.")
      }
      self.bufferSize = reduceShape(tensor.shape!) * dataSize
      let padding: Int = 512
      let paddedSize = ~(padding - 1) & (bufferSize + padding - 1)
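      // e.g. bufferSize = 1000 is padded up to 1024, the next multiple of 512.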
      let device = GPUContext.shared.device
      self.buffer = device.makeBuffer(
        length: paddedSize, options: .storageModeShared)!
      self.tensorData = MPSGraphTensorData(
        buffer, shape: tensor.shape!, dataType: dataType)
    }

    func shapedType() -> MPSGraphShapedType {
      MPSGraphShapedType(shape: tensor.shape!, dataType: tensorData.dataType)
    }
  }

  func makeShape(_ array: Int64...) -> [NSNumber] {
    return array.map(NSNumber.init(value:))
  }

  let m: Int64 = .init(matrixDim)
  let n: Int64 = .init(matrixDim)
  let k: Int64 = .init(matrixDim)
  let dataType: MPSDataType = useHalf ? .float16 : .float32

  let _lhs = graph.placeholder(
    shape: makeShape(m, k), dataType: dataType, name: "lhs")
  let _rhs = graph.placeholder(
    shape: makeShape(k, n), dataType: dataType, name: "rhs")
  let _out = graph.matrixMultiplication(
    primary: _lhs, secondary: _rhs, name: "out")
  precondition(_out.shape!.first! == _lhs.shape!.first!)
  precondition(_out.shape!.last! == _rhs.shape!.last!)

  let lhs = Tensor(tensor: _lhs, dataType: dataType)
  let rhs = Tensor(tensor: _rhs, dataType: dataType)
  let out = Tensor(tensor: _out, dataType: dataType)

  let feedsDict = [
    lhs.tensor: lhs,
    rhs.tensor: rhs
  ]
  let targetsDict = [
    out.tensor: out
  ]

  func mapFeeds(
    _ dict: [MPSGraphTensor: Tensor]
  ) -> [MPSGraphTensor: MPSGraphShapedType] {
    var output: [MPSGraphTensor: MPSGraphShapedType] = [:]
    for (key, value) in dict {
      output[key] = value.shapedType()
    }
    return output
  }
  let feeds = mapFeeds(feedsDict)

  func makeGraphDesc() -> MPSGraphCompilationDescriptor {
    let desc = MPSGraphCompilationDescriptor()
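    // .level1 permits MPSGraph's extra graph optimizations; with useANE set,
    // this appears to be what allows the GEMM to be placed on the ANE (hence
    // the "(either ANE or GPU tensor cores)" label above).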
    desc.optimizationLevel = useANE ? .level1 : .level0
    desc.optimizationProfile = .performance
    return desc
  }
  let graphExec = graph.compile(
    with: GPUContext.shared.graphDevice,
    feeds: feeds,
    targetTensors: [out.tensor],
    targetOperations: nil,
    compilationDescriptor: makeGraphDesc())

  // MARK: - GEMM_setInputs

  struct Pair {
    var data: UnsafeRawPointer
    var tensor: Tensor
  }
  let pairs: [Pair] = [
    Pair(data: .init(matrices[0].contents()), tensor: lhs),
    Pair(data: .init(matrices[1].contents()), tensor: rhs)
  ]
  for pair in pairs {
    let src = pair.data
    let dst = pair.tensor.buffer.contents()
    let len = pair.tensor.bufferSize
    memcpy(dst, src, len)
  }

  // MARK: - GEMM_runAsync

  func mapArguments(
    _ dict: [MPSGraphTensor: Tensor],
    order: [MPSGraphTensor]
  ) -> [MPSGraphTensorData] {
    return order.map { dict[$0]!.tensorData }
  }

  func dispatchCommandBuffer(
    gpuTime: UnsafeMutablePointer<Double>?,
    checkError: Int8,
    wait: Int8,
    graphExec: MPSGraphExecutable,
    feedsDict: [MPSGraphTensor: Tensor],
    targetsDict: [MPSGraphTensor: Tensor],
    synchronized: inout Bool
  ) {
    let commandQueue = GPUContext.shared.commandQueue
    if forceRunSync {
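      // The synchronous (ANE) path is timed with wall-clock
      // CACurrentMediaTime(), presumably because ANE work is not visible in
      // Metal's GPU timestamps; the cost is averaged over numRepetitionsANE
      // back-to-back executions.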
      let start = CACurrentMediaTime()
      for _ in 0..<numRepetitionsANE - 1 {
        graphExec.runAsync(
          with: commandQueue,
          inputs: mapArguments(feedsDict, order: graphExec.feedTensors!),
          results: mapArguments(targetsDict, order: graphExec.targetTensors!),
          executionDescriptor: nil)
      }
      graphExec.run(
        with: commandQueue,
        inputs: mapArguments(feedsDict, order: graphExec.feedTensors!),
        results: mapArguments(targetsDict, order: graphExec.targetTensors!),
        executionDescriptor: nil)
      let end = CACurrentMediaTime()
      gpuTime!.pointee = (end - start) / Double(numRepetitionsANE)
      synchronized = true
    } else {
      let commandBuffer = commandQueue.makeCommandBuffer()!
      let mpsCommandBuffer = MPSCommandBuffer(commandBuffer: commandBuffer)
      let desc = MPSGraphExecutableExecutionDescriptor()
      if gpuTime != nil || checkError != 0 {
        desc.completionHandler = { _, error in
          if checkError != 0 {
            if let error {
              print(error.localizedDescription)
              fatalError(error.localizedDescription)
            }
            precondition(commandBuffer.status == .completed, "Unexpected status.")
          }
          if let gpuTime {
            let startTime = commandBuffer.gpuStartTime
            let endTime = commandBuffer.gpuEndTime
            gpuTime.pointee = endTime - startTime
          }
        }
      }
      graphExec.encode(
        to: mpsCommandBuffer,
        inputs: mapArguments(feedsDict, order: graphExec.feedTensors!),
        results: mapArguments(targetsDict, order: graphExec.targetTensors!),
        executionDescriptor: desc)
      commandBuffer.commit()
      if wait != 0 {
        commandBuffer.waitUntilCompleted()
        synchronized = true
      } else {
        synchronized = false
      }
    }
  }

  let timesBuffer = malloc(numTrials * MemoryLayout<Double>.stride)!
    .assumingMemoryBound(to: Double.self)
  defer { free(timesBuffer) }
  var synchronized: Bool = true
  for i in 0..<numTrials {
    let wait: Int8 = (i == numTrials - 1) ? 1 : 0
    let gpuTime = timesBuffer + i
    dispatchCommandBuffer(
      gpuTime: gpuTime,
      checkError: 1,
      wait: wait,
      graphExec: graphExec,
      feedsDict: feedsDict,
      targetsDict: targetsDict,
      synchronized: &synchronized)
  }

  // Report performance of MPS.
  do {
    if verbose {
      print()
    }
    print("Metal Performance Shaders: '\(precisionReprMPS)'")
    var minLatency: Double = 1000
    for i in 0..<numTrials {
      let latencySeconds = timesBuffer[i]
      let latencyMicroseconds = Int(latencySeconds / 1e-6)
      if verbose {
        print("us: \(latencyMicroseconds)")
      }
      minLatency = min(latencySeconds, minLatency)
    }
    let floatOps = matrixDim * matrixDim * matrixDim * 2
    let flops = Double(floatOps) / minLatency
    let gflops = Int(flops / 1e9)
    print("GFLOPS: \(gflops)")
    if verbose {
      // Probe the MPS result, clamping the index so that small matrices stay
      // in bounds.
      let probe = min(65536, matrixDim * matrixDim - 1)
      if useHalf {
        let ptr = out.buffer.contents().assumingMemoryBound(to: Float16.self)
        print("First element of MPS output is currently: \(ptr[0])")
        print("Element \(probe) of MPS output is currently: \(ptr[probe])")
      } else {
        let ptr = out.buffer.contents().assumingMemoryBound(to: Float32.self)
        print("First element of MPS output is currently: \(ptr[0])")
        print("Element \(probe) of MPS output is currently: \(ptr[probe])")
      }
    }
  }

  // MARK: - GEMM_compare

  // Compute Euclidean distance, which tolerates small differences due to
  // numerical error.
  func euclideanDistance(
    n: Int,
    x: UnsafePointer<Float>,
    y: UnsafePointer<Float>,
    difference: UnsafeMutablePointer<Float>
  ) -> Float {
    memcpy(difference, x, n * 4)
    var n_copy = Int32(n)
    var a = Float(-1)
    var inc = Int32(1)
    var inc_copy = inc
    // Find x + (-1 * y)
    saxpy_(
      &n_copy, &a, UnsafeMutablePointer(mutating: y), &inc, difference,
      &inc_copy)
    // Find ||x - y||
    return Float(snrm2_(&n_copy, difference, &inc))
  }
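
  // euclideanDistance example: x = [1, 2, 3] and y = [1, 2, 5] produce
  // difference = [0, 0, -2] and return sqrt(0 + 0 + 4) = 2.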

  func printComparisonError(
    distance: Float,
    tolerance: Float,
    n: Int,
    x: UnsafePointer<Float>,
    y: UnsafePointer<Float>,
    difference: UnsafePointer<Float>
  ) {
    // How many elements to print before stopping.
    let diagnosticElems = 16
    print()
    print("""
      Vectors did not match: Euclidean distance '\(distance)' > \
      tolerance '\(tolerance)'.
      """)
    var failedElements = 0
    for i in 0..<n {
      if abs(x[i] - y[i]) > 0.001 * distance || x[i].isNaN || y[i].isNaN {
        print("Element \(i): \(x[i]) - \(y[i]) = \(difference[i])")
        failedElements += 1
        if failedElements >= diagnosticElems {
          break
        }
      }
    }
    if failedElements == 0 {
      print("Could not find elements with enough separation.")
      for index in 0..<diagnosticElems {
        let i = (index == 0) ? 0 : Int.random(in: 0..<n)
        print("Element \(i): \(x[i]) - \(y[i]) = \(difference[i])")
      }
    }
  }

  do {
    let _lhs = matrices[0].contents()
    let _rhs = matrices[1].contents()
    let _out = outputsCopy
    let printError: Int8 = 1

    // Total rounding error scales with the square root of K and linearly with
    // M/N. Account for this when checking the output matrix.
    // TODO: Does transposed multiplication need a different heuristic?
    func matches(
      _ input1: UnsafeRawPointer,
      _ input2: Tensor,
      useK: Bool
    ) -> Bool {
      let dataType = input2.tensorData.dataType
      switch dataType {
      case .float16: break
      case .float32: break
      default:
        fatalError(
          "Data type with raw value '\(dataType.rawValue)' unsupported.")
      }

      // Do not mutate 'x' and 'y' after filling them.
      var bufferSize: Int
      var n: Int
      var x: UnsafeMutablePointer<Float>
      var y: UnsafeMutablePointer<Float>
      if dataType == .float16 {
        let bufferSize_f16 = input2.bufferSize
        let bufferSize_f32 = bufferSize_f16 * 2
        bufferSize = bufferSize_f32
        n = bufferSize_f32 / MemoryLayout<Float>.stride
        x = .allocate(capacity: n)
        y = .allocate(capacity: n)
        // Partially sourced from:
        // https://github.com/hollance/TensorFlow-iOS-Example/blob/master/VoiceMetal/VoiceMetal/Float16.swift
        func copy(dst: UnsafeMutableRawPointer, src: UnsafeMutableRawPointer) {
          var bufferFloat16 = vImage_Buffer(
            data: src, height: 1, width: UInt(n), rowBytes: bufferSize_f16)
          var bufferFloat32 = vImage_Buffer(
            data: dst, height: 1, width: UInt(n), rowBytes: bufferSize_f32)
          let error = vImageConvert_Planar16FtoPlanarF(
            &bufferFloat16, &bufferFloat32, 0)
          if error != kvImageNoError {
            fatalError(
              "Encountered error code \(error) while converting F16 to F32.")
          }
        }
        copy(dst: x, src: .init(mutating: input1))
        copy(dst: y, src: input2.buffer.contents())
      } else {
        bufferSize = input2.bufferSize
        n = bufferSize / MemoryLayout<Float>.stride
        x = .allocate(capacity: n)
        y = .allocate(capacity: n)
        memcpy(x, input1, bufferSize)
        memcpy(y, input2.buffer.contents(), bufferSize)
      }
      defer {
        x.deallocate()
        y.deallocate()
      }
      let difference: UnsafeMutablePointer<Float> = .allocate(capacity: n)
      defer {
        difference.deallocate()
      }
      let distance = euclideanDistance(n: n, x: x, y: y, difference: difference)
      // Tolerance for the Euclidean distance check.
      var tolerance: Float
      if useK {
        let mk = lhs.tensor.shape!
        let kn = rhs.tensor.shape!
        precondition(mk[1] == kn[0], "K dimension does not match.")
        let M = mk[0].intValue
        let N = kn[1].intValue
        let K = mk[1].intValue
        let expectedDeviation = Float(M * N) * sqrt(Float(K))
        if input2.tensorData.dataType == .float32 {
          tolerance = max(0.001, 3e-7 * expectedDeviation)
          // tolerance = max(0.001, 3e-9 * expectedDeviation)
        } else {
          tolerance = max(0.01, 5e-3 * expectedDeviation)
          // tolerance = max(0.01, 5e-4 * expectedDeviation)
        }
      } else {
        if input2.tensorData.dataType == .float32 {
          tolerance = 0.001
        } else {
          tolerance = 0.01
        }
      }
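      // Example: for 2048x2048x2048, expectedDeviation =
      // 2048 * 2048 * sqrt(2048) ≈ 1.9e8, so the f32 tolerance becomes
      // max(0.001, 3e-7 * 1.9e8) ≈ 57.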
      if useK {
        if verbose {
          print()
        }
        print("MFA vs MPS Euclidean Distance: \(distance)")
      }
      if distance < tolerance {
        return true
      } else {
        if printError != 0 {
          printComparisonError(
            distance: distance,
            tolerance: tolerance,
            n: n,
            x: x,
            y: y,
            difference: difference)
        }
        return false
      }
    }

    if matches(_lhs, lhs, useK: false),
       matches(_rhs, rhs, useK: false),
       matches(_out, out, useK: true) {
      // return 1
    } else {
      if printError == 1 {
        func printShape(_ tensor: Tensor, name: String) {
          print("Tensor '\(name)' has shape \(tensor.tensor.shape!).")
        }
        printShape(lhs, name: "lhs")
        printShape(rhs, name: "rhs")
        printShape(out, name: "out")
      }
      // return 0
    }
  }
}

func matrixInitRandom<T: BinaryFloatingPoint>(
  matrix: UnsafeMutablePointer<T>, dim1: Int, dim2: Int,
  _ initialize: (Double) -> T
) {
  for i in 0..<dim1 * dim2 {
    let element = initialize(drand48())
    matrix[i] = element
  }
}
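
// Example usage (hypothetical): fill a 4x4 Float32 matrix with uniform [0, 1)
// values from drand48().
//   let ptr = UnsafeMutablePointer<Float32>.allocate(capacity: 16)
//   matrixInitRandom(matrix: ptr, dim1: 4, dim2: 4) { Float32($0) }
//   ptr.deallocate()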