Last active
July 22, 2024 23:47
-
-
Save philipturner/d408351d68b5b1701bb651d4542e26e6 to your computer and use it in GitHub Desktop.
Calculate the number of floating-point operations in Stable Diffusion, and how those operations are distributed among layers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// main.swift | |
// CalculateDiffusion | |
// | |
// Created by Philip Turner on 6/2/23. | |
// | |
import Foundation | |
import QuartzCore | |
import MetalPerformanceShadersGraph | |
import Accelerate | |
// Parse the log from NNC: | |
// https://gist.github.com/liuliu/f80ceaebe36b177c4d50598ad27c782b | |
// Use this visualization as a reference (if necessary): | |
// https://nn.labml.ai/diffusion/stable_diffusion/model/unet.html | |
// MARK: - Presentation Style

// Whether to force the results to pretend batch size was 1.
let forceBatch1: Bool = false
// Batch size expected in the parsed log (the capture appears to batch two
// images — see the note near the batched-matmul handling below).
let expectedBatchSize: Int = 2

// Whether to show the theoretical minimum latency for 30 steps.
let showLatency: Bool = true
// Peak float throughput (FLOPS) assumed by the latency model.
let flopsForLatency: Float = 10.616832e12
// Assumed RAM bandwidth (bytes/s) for the latency model.
let ramBandwidthForLatency: Float = 409.6e9
// Assumed system-level-cache bandwidth: modeled as 2x RAM bandwidth.
let slcBandwidthForLatency: Float = 2 * 409.6e9

// Whether to convert total application latencies to it/s.
let showItersPerSecond: Bool = false
// How to present large integers ("K"/"M"/... when true, "thousand"/... when
// false) — consumed by `largeIntegerRepr`.
let showConciseMagnitude: Bool = false
// Whether to show float ops for each operation class.
let showExactComputeCost: Bool = false
// Whether to print the total number of operations involved in self-attention.
let showSelfAttention: Bool = false
// Whether to print for formatting in a Markdown document.
let cleanDisplay: Bool = true
let cleanDisplayWinograd: Bool = false

// MARK: - Simulation Hyperparameters

// Whether to increase the float ops to simulate real-world performance.
let simulateUtilization: Bool = true
let simulationFramework: MatrixUtilization.Framework = .mps
let simulationPrecision: MatrixUtilization.Precision = .f16

// Whether to simulate usage of the monolithic MFA kernel, which is slower than
// using fine-tuned block sizes. It also doesn't support fused transposes.
let useMonolithicKernel: Bool = true
// Whether to just run GEMM in MFA and Winograd in MPS.
let useWinogradMPS: Bool = true

// Whether to run MPSGraph benchmarks for Stable Diffusion convolutions.
let benchmarkMPSConvolution: Bool = false
let benchmarkMPSMFAGEMM: Bool = false
let benchmarkMPSSoftmax: Bool = false
let benchmarkMPSTranspose: Bool = false
let mpsGraphOptimizationLevel1: Bool = true

// Parameters for the GEMM benchmark
let benchmarkGEMM_useEnsemble: Bool = false
// The block size to use for monolithic. Change this to validate the
// correctness of each block size.
let benchmarkGEMM_monolithicVariant: BlockSizeVariant = .mfa32x32
/// Block-size configurations exercised by the GEMM benchmark: the MPS
/// baseline plus three MFA tile shapes.
enum BlockSizeVariant: CaseIterable {
  case mps
  case mfa16x16
  case mfa32x32
  case mfa48x48

  /// Human-readable label used in benchmark output.
  func repr() -> String {
    switch self {
    case .mps: return "MPS"
    case .mfa16x16: return "MFA (M=16)x(N=16)x(K=48,64)"
    case .mfa32x32: return "MFA (M=32)x(N=32)x(K=32)"
    case .mfa48x48: return "MFA (M=48)x(N=48)x(K=24,32)"
    }
  }
}
// MARK: - Temporary Extraction of Attention Operations

// Tallies of attention-related kernel dispatches found while parsing the log.
// The operand shapes are encoded in the variable names (e.g. 2x4096x320 is
// the tensor shape, and the trailing token is the operation class).
var num2x4096x320x320_GEMM = 0 // done
var num2x4096x320x320_bias_GEMM = 0 // done
var num2x4096x320_SCALAR_MUL = 0
var num2x8x4096x40_TRANSPOSE = 0
var num2x4096x8x40_TRANSPOSE = 0
var num2x4096x320_ADD = 0
var num1x4096x4096x40_GEMM = 0
var num1x4096x40x4096_GEMM = 0
var num4096x4096_SOFTMAX = 0
/// Tallies operations whose shapes match the sequence-4096 self-attention
/// blocks (8 heads, head dim 40, model dim 320 — shapes visible below).
/// Operations that match no pattern fall through uncounted.
func countSelfAttention(operation: Operation) {
  switch operation.name {
  case "GEMM":
    if operation.inputs.count >= 2 && operation.outputs.count == 1 {
      // Projection GEMM: [2, 4096, 320] x [320, 320] -> [2, 4096, 320].
      if operation.inputs[0].dimensions == [2, 4096, 320],
         operation.inputs[1].dimensions == [320, 320],
         operation.outputs[0].dimensions == [2, 4096, 320] {
        if operation.inputs.count == 2 {
          num2x4096x320x320_GEMM += 1
          return
        } else if operation.inputs.count == 3 {
          // A third input of shape [320] is the bias vector.
          if operation.inputs[2].dimensions == [320] {
            num2x4096x320x320_bias_GEMM += 1
            return
          }
        }
        // Projection-shaped GEMM with an unexpected operand list.
        fatalError("Unrecognized operation: \(operation)")
      }
      if operation.inputs.count == 2 && operation.outputs.count == 1 {
        // Per-head GEMM: [1, 4096, 40] x [1, 4096, 40] -> [1, 4096, 4096]
        // (presumably the attention-score matmul — TODO confirm).
        if operation.inputs[0].dimensions == [1, 4096, 40],
           operation.inputs[1].dimensions == [1, 4096, 40],
           operation.outputs[0].dimensions == [1, 4096, 4096] {
          num1x4096x4096x40_GEMM += 1
          return
        }
        // Per-head GEMM: [1, 4096, 4096] x [1, 4096, 40] -> [1, 4096, 40].
        if operation.inputs[0].dimensions == [1, 4096, 4096],
           operation.inputs[1].dimensions == [1, 4096, 40],
           operation.outputs[0].dimensions == [1, 4096, 40] {
          num1x4096x40x4096_GEMM += 1
          return
        }
      }
    }
    break
  case "SCALAR_MUL":
    if operation.inputs.count == 1 && operation.outputs.count == 1 {
      if operation.inputs[0].dimensions == [2, 4096, 320],
         operation.outputs[0].dimensions == [2, 4096, 320] {
        num2x4096x320_SCALAR_MUL += 1
        return
      }
    }
    break
  case "TRANSPOSE":
    if operation.inputs.count == 1 && operation.outputs.count == 1 {
      // Transpose between [2, 4096, 8, 40] and [2, 8, 4096, 40] layouts.
      if operation.inputs[0].dimensions == [2, 4096, 8, 40],
         operation.outputs[0].dimensions == [2, 8, 4096, 40] {
        num2x8x4096x40_TRANSPOSE += 1
        return
      }
      // The inverse transpose.
      if operation.inputs[0].dimensions == [2, 8, 4096, 40],
         operation.outputs[0].dimensions == [2, 4096, 8, 40] {
        num2x4096x8x40_TRANSPOSE += 1
        return
      }
    }
    break
  case "ADD":
    if operation.inputs.count == 2 && operation.outputs.count == 1 {
      // Elementwise add where all three tensors are [2, 4096, 320].
      if operation.inputs[0].dimensions == operation.inputs[1].dimensions,
         operation.inputs[1].dimensions == [2, 4096, 320],
         operation.outputs[0].dimensions == operation.inputs[0].dimensions {
        num2x4096x320_ADD += 1
        return
      }
    }
    break
  case "SOFTMAX":
    if operation.inputs.count == 1 && operation.outputs.count == 1 {
      if operation.inputs[0].dimensions == [4096, 4096],
         operation.outputs[0].dimensions == [4096, 4096] {
        num4096x4096_SOFTMAX += 1
        return
      }
    }
    break
  default:
    break
  }
}
/// Prints the attention-related dispatch tallies: first the per-shape MPS
/// counts, then the equivalent (much smaller) MFA dispatch counts.
func printSelfAttention() {
  // (label, tally) pairs for every counter, in display order.
  let tallies: [(String, Int)] = [
    ("num2x4096x320x320_GEMM", num2x4096x320x320_GEMM),
    ("num2x4096x320x320_bias_GEMM", num2x4096x320x320_bias_GEMM),
    ("num2x4096x320_SCALAR_MUL", num2x4096x320_SCALAR_MUL),
    ("num2x8x4096x40_TRANSPOSE", num2x8x4096x40_TRANSPOSE),
    ("num2x4096x8x40_TRANSPOSE", num2x4096x8x40_TRANSPOSE),
    ("num2x4096x320_ADD", num2x4096x320_ADD),
    ("num1x4096x4096x40_GEMM", num1x4096x4096x40_GEMM),
    ("num1x4096x40x4096_GEMM", num1x4096x40x4096_GEMM),
    ("num4096x4096_SOFTMAX", num4096x4096_SOFTMAX),
  ]
  var total = 0
  for (_, count) in tallies {
    total += count
  }
  print()
  print("Kernel dispatches with MPS: \(total)")
  for (label, count) in tallies {
    print("\(label) - \(count)")
  }

  // Each attention layer contributed 8 heads x 2 images to this counter.
  let numAttentionLayers = num1x4096x4096x40_GEMM / 8 / 2
  print()
  print("Kernel dispatches with MFA: \(3 * numAttentionLayers)")
  print("2x3x(M=320)x(N=4096)x(K=320)_HGEMM_BATCHED: \(numAttentionLayers)")
  print("2x8x(D=40)x(R=4096)x(C=4096)_FLASH_ATTENTION_BATCHED: \(numAttentionLayers)")
  print("2x1x(M=320)x(N=4096)x(K=320)_HGEMM_BATCHED: \(numAttentionLayers)")
}
// MARK: - Empirical Data

/// Measured fraction of peak GEMM throughput for one matrix size, split by
/// framework (MPS vs. MFA) and precision (F16 vs. F32).
struct MatrixUtilization {
  // proportion of maximum performance
  var mpsF16: Float
  var mpsF32: Float
  var mfaF16: Float
  var mfaF32: Float

  enum Framework {
    case mps
    case mfa
    /// Display label for this framework.
    var repr: String {
      switch self {
      case .mps: return "MPS"
      case .mfa: return "MFA"
      }
    }
  }

  enum Precision {
    case f16
    case f32
    /// Display label for this precision.
    var repr: String {
      switch self {
      case .f16: return "F16"
      case .f32: return "F32"
      }
    }
  }

  /// Emits Swift source (an `.init(...)` expression) that re-creates this
  /// value, so measured numbers can be pasted back into the program.
  func initRepr() -> String {
    let fields: [String] = [
      "mpsF16: " + String(format: "%.3f", mpsF16),
      "mpsF32: " + String(format: "%.3f", mpsF32),
      "mfaF16: " + String(format: "%.3f", mfaF16),
      "mfaF32: " + String(format: "%.3f", mfaF32),
    ]
    return """
      .init(
        \(fields[0]), \(fields[1]), \(fields[2]), \(fields[3]))
      """
  }
}
//// TODO: Profile most of the matrix sizes in Stable Diffusion. This only applies | |
//// to GEMM, and underestimates inefficiency of convolutions. You can control for | |
//// this disparity by referencing the latency of ConvGEMM instead of Winograd. | |
//// | |
//// size = SIMD3(M, N, K) | |
//var matrixSpeeds: [SIMD3<Int>: MatrixUtilization] = [ | |
// SIMD3(1280, 4096, 320) : MatrixUtilization( | |
// mpsF16: 0.691, mpsF32: 0.766, mfaF16: 0.879, mfaF32: 0.814), | |
// SIMD3(1024, 2560, 640) : MatrixUtilization( | |
// mpsF16: 0.700, mpsF32: 0.773, mfaF16: 0.869, mfaF32: 0.800), | |
// SIMD3(4096, 4096, 40) : MatrixUtilization( | |
// mpsF16: 0.358, mpsF32: 0.405, mfaF16: 0.781, mfaF32: 0.605), | |
// SIMD3(4096, 40, 4096) : MatrixUtilization( | |
// mpsF16: 0.108, mpsF32: 0.108, mfaF16: 0.702, mfaF32: 0.527), | |
// SIMD3(1024, 1024, 80) : MatrixUtilization( | |
// mpsF16: 0.425, mpsF32: 0.492, mfaF16: 0.522, mfaF32: 0.496), | |
// SIMD3(1024, 80, 1024) : MatrixUtilization( | |
// mpsF16: 0.169, mpsF32: 0.146, mfaF16: 0.512, mfaF32: 0.433) | |
//] | |
//if useMonolithicKernel { | |
// matrixSpeeds[SIMD3(1280, 4096, 320)]!.mfaF16 = 0.83 | |
// matrixSpeeds[SIMD3(1024, 2560, 640)]!.mfaF16 = 0.82 | |
// matrixSpeeds[SIMD3(4096, 4096, 40)]!.mfaF16 = 0.62 | |
// matrixSpeeds[SIMD3(4096, 40, 4096)]!.mfaF16 = 0.50 | |
// matrixSpeeds[SIMD3(1024, 1024, 80)]!.mfaF16 = 0.54 | |
// matrixSpeeds[SIMD3(1024, 80, 1024)]!.mfaF16 = 0.43 | |
// | |
// matrixSpeeds[SIMD3(1280, 4096, 320)]!.mfaF32 = 0.75 | |
// matrixSpeeds[SIMD3(1024, 2560, 640)]!.mfaF32 = 0.76 | |
// matrixSpeeds[SIMD3(4096, 4096, 40)]!.mfaF32 = 0.50 | |
// matrixSpeeds[SIMD3(4096, 40, 4096)]!.mfaF32 = 0.36 | |
// matrixSpeeds[SIMD3(1024, 1024, 80)]!.mfaF32 = 0.48 | |
// matrixSpeeds[SIMD3(1024, 80, 1024)]!.mfaF32 = 0.40 | |
//} | |
//matrixSpeeds[SIMD3(4096, 320, 320)] = MatrixUtilization( | |
// mpsF16: 0.62, mpsF32: 0.68, mfaF16: 0.79, mfaF32: 0.70) | |
//matrixSpeeds[SIMD3(4096, 1713, 40)] = MatrixUtilization( | |
// mpsF16: 0.32, mpsF32: 0.34, mfaF16: 0.52, mfaF32: 0.40) | |
//matrixSpeeds[SIMD3(4096, 40, 1713)] = MatrixUtilization( | |
// mpsF16: 0.19, mpsF32: 0.097, mfaF16: 0.46, mfaF32: 0.40) | |
//matrixSpeeds[SIMD3(4096, 92, 40)] = MatrixUtilization( | |
// mpsF16: 0.072, mpsF32: 0.064, mfaF16: 0.28, mfaF32: 0.21) | |
//matrixSpeeds[SIMD3(4096, 40, 92)] = MatrixUtilization( | |
// mpsF16: 0.075, mpsF32: 0.073, mfaF16: 0.27, mfaF32: 0.19) | |
//matrixSpeeds[SIMD3(1805, 320, 768)] = MatrixUtilization( | |
// mpsF16: 0.51, mpsF32: 0.56, mfaF16: 0.75, mfaF32: 0.63) | |
//matrixSpeeds[SIMD3(1805, 768, 320)] = MatrixUtilization( | |
// mpsF16: 0.64, mpsF32: 0.71, mfaF16: 0.81, mfaF32: 0.67) | |
// | |
//matrixSpeeds[SIMD3(512, 512, 32)] = MatrixUtilization( | |
// mpsF16: 0.14, mpsF32: 0.14, mfaF16: 0.26, mfaF32: 0.20) | |
//matrixSpeeds[SIMD3(512, 32, 512)] = MatrixUtilization( | |
// mpsF16: 0.081, mpsF32: 0.077, mfaF16: 0.082, mfaF32: 0.075) | |
//matrixSpeeds[SIMD3(2048, 2048, 32)] = MatrixUtilization( | |
// mpsF16: 0.40, mpsF32: 0.46, mfaF16: 0.61, mfaF32: 0.50) | |
//matrixSpeeds[SIMD3(2048, 32, 2048)] = MatrixUtilization( | |
// mpsF16: 0.35, mpsF32: 0.32, mfaF16: 0.35, mfaF32: 0.35) | |
//matrixSpeeds[SIMD3(2048, 2048, 40)] = MatrixUtilization( | |
// mpsF16: 0.32, mpsF32: 0.36, mfaF16: 0.56, mfaF32: 0.46) | |
//matrixSpeeds[SIMD3(2048, 40, 2048)] = MatrixUtilization( | |
// mpsF16: 0.058, mpsF32: 0.059, mfaF16: 0.40, mfaF32: 0.37) | |
//matrixSpeeds[SIMD3(2048, 2048, 52)] = MatrixUtilization( | |
// mpsF16: 0.32, mpsF32: 0.35, mfaF16: 0.52, mfaF32: 0.39) | |
//matrixSpeeds[SIMD3(2048, 52, 2048)] = MatrixUtilization( | |
// mpsF16: 0.14, mpsF32: 0.12, mfaF16: 0.49, mfaF32: 0.47) | |
// Utilization tables built elsewhere in this project — TODO confirm their
// coverage matches the shapes the simulation functions query.
let matrixSpeeds = getMatrixSpeeds(monolithic: useMonolithicKernel)
let convSpeeds = generateConvSpeeds()

// Measured MPS softmax bandwidth keyed by (rows, columns). Units are
// presumably GB/s — verify against the benchmark that produced them.
let softmaxSpeeds: [SIMD2<Int>: SoftmaxUtilization] = [
  SIMD2(32768, 92): SoftmaxUtilization(
    mpsF16Bandwidth: 58.6, mpsF32Bandwidth: 116.8),
  SIMD2(8192, 92): SoftmaxUtilization(
    mpsF16Bandwidth: 36.0, mpsF32Bandwidth: 75.4),
  SIMD2(2048, 92): SoftmaxUtilization(
    mpsF16Bandwidth: 10.1, mpsF32Bandwidth: 19.8),
  SIMD2(512, 92): SoftmaxUtilization(
    mpsF16Bandwidth: 2.6, mpsF32Bandwidth: 4.8),
  SIMD2(32768, 1713): SoftmaxUtilization(
    mpsF16Bandwidth: 142.5, mpsF32Bandwidth: 180.1),
  SIMD2(8192, 1713): SoftmaxUtilization(
    mpsF16Bandwidth: 133.5, mpsF32Bandwidth: 181.8),
  SIMD2(2048, 1713): SoftmaxUtilization(
    mpsF16Bandwidth: 95.9, mpsF32Bandwidth: 173.4),
  SIMD2(512, 1713): SoftmaxUtilization(
    mpsF16Bandwidth: 45.4, mpsF32Bandwidth: 89.4),
  SIMD2(4096, 4096): SoftmaxUtilization(
    mpsF16Bandwidth: 139.9, mpsF32Bandwidth: 172.8),
  SIMD2(1024, 1024): SoftmaxUtilization(
    mpsF16Bandwidth: 56.8, mpsF32Bandwidth: 111.5),
  SIMD2(256, 256): SoftmaxUtilization(
    mpsF16Bandwidth: 3.5, mpsF32Bandwidth: 7.5),
  SIMD2(64, 64): SoftmaxUtilization(
    mpsF16Bandwidth: 0.2, mpsF32Bandwidth: 0.5),
]
// Transpose speeds: | |
// It is sufficient to model each transpose as taking 58 microseconds (these
// are latency-bound and not memory-bound).
// Asymptotic maximum utilization measured with a particular framework and | |
// precision, not necessarily the utilization during every run. If no data is | |
// present, simply report maximum utilization. | |
// | |
//var notSimSizes: [SIMD3<Int>: Bool] = [:] | |
//var numNotSimSizes: Int = 0 | |
/// Inflates `floatOps` by the inverse of the measured utilization for this
/// matrix shape, approximating the operation's real cost. Shapes with no
/// measurement pass through unchanged (and are not counted as simulated).
/// - Parameters:
///   - shape: Lookup key into `matrixSpeeds`.
///   - floatOps: Theoretical float-op count of the GEMM.
///   - isBatched: Currently unused; monolithic MFA is assumed to support
///     batched GEMM.
///   - isTransposed: Fused transposes force a fallback to MPS under the
///     monolithic kernel.
///   - model: Receives the simulated-operation tally.
func simulateMatrixUtilization(
  shape: MatrixShape,
  floatOps: Int,
  isBatched: Bool,
  isTransposed: Bool,
  model: inout Model
) -> Int {
  guard let utilization = matrixSpeeds[shape] else {
    return floatOps
  }
  model.simulatedOperations += 1

  // Fall back to MPS, as the operation isn't supported by MFA yet. The
  // monolithic kernel doesn't support fused transposes.
  var framework = simulationFramework
  if isTransposed, framework == .mfa, useMonolithicKernel {
    framework = .mps
  }

  let ratio: Float
  switch (framework, simulationPrecision) {
  case (.mps, .f16): ratio = utilization.mpsF16
  case (.mps, .f32): ratio = utilization.mpsF32
  case (.mfa, .f16): ratio = utilization.mfaF16
  case (.mfa, .f32): ratio = utilization.mfaF32
  }
  return Int(Float(floatOps) / ratio)
}
/// Estimates the effective cost of a convolution: its theoretical float-op
/// count divided by the measured utilization for `shape`. Traps when the
/// shape is missing from `convSpeeds`.
/// NOTE(review): this mutates the file-level `model` directly, unlike
/// `simulateMatrixUtilization` which takes `model` as `inout` — confirm the
/// asymmetry is intentional.
func simulateConvUtilization(
  shape: ConvolutionShape
) -> Int {
  let floatOps = shape.floatOperations()
  guard let utilization = convSpeeds[shape] else {
    fatalError("Should always have conv speeds.")
  }
  model.simulatedOperations += 1

  var framework = simulationFramework
  // Fall back to MPS for 3x3 (Winograd) windows, as the operation isn't
  // supported by MFA yet.
  if useWinogradMPS, shape.window == 3, framework == .mfa {
    framework = .mps
  }
  // Monolithic GEMM estimates the assumed impact of the initial integration.
  if framework == .mfa, useMonolithicKernel {
    framework = .mps
  }

  let ratio: Float
  switch (framework, simulationPrecision) {
  case (.mps, .f16): ratio = utilization.mpsF16
  case (.mps, .f32): ratio = utilization.mpsF32
  case (.mfa, .f16): ratio = utilization.mfaF16
  case (.mfa, .f32): ratio = utilization.mfaF32
  }
  return Int(Float(floatOps) / ratio)
}
/// Estimates the effective cost of a softmax over a `shape[0] x shape[1]`
/// matrix, modeling the operation as 10 float ops per element divided by the
/// measured ALU utilization. Traps when the shape is missing from
/// `softmaxSpeeds`.
func simulateSoftmaxUtilization(
  shape: SIMD2<Int>
) -> Int {
  let floatOps = 10 * shape[0] * shape[1]
  guard let utilization = softmaxSpeeds[shape] else {
    fatalError("Should always have softmax speeds.")
  }
  model.simulatedOperations += 1

  // No FlashAttention in the initial integration, so monolithic MFA falls
  // back to MPS for softmax.
  var framework = simulationFramework
  if framework == .mfa, useMonolithicKernel {
    framework = .mps
  }

  let ratio: Float
  switch (framework, simulationPrecision) {
  case (.mps, .f16): ratio = utilization.mpsF16ALU
  case (.mps, .f32): ratio = utilization.mpsF32ALU
  case (.mfa, .f16): ratio = utilization.mfaF16ALU
  case (.mfa, .f32): ratio = utilization.mfaF32ALU
  }
  return Int(Float(floatOps) / ratio)
}
// MARK: - Download the Data
#if false
// Download the GitHub gist from the internet:
let gistPath: String = "https://gist.githubusercontent.com/liuliu/f80ceaebe36b177c4d50598ad27c782b"
let gistRawFile: String = "f42b6b1c5d5b9be113ffa2a93618fa3f626ad99b/gistfile1.txt"
let gistRawPath: String = gistPath + "/raw/" + gistRawFile
guard let gistRawURL = URL(string: gistRawPath) else {
  fatalError("URL was bad.")
}
#else
// Open the Gist from a local file (hard-coded developer path).
let gistPath: String = "/Users/philipturner/Documents/DrawThings/nnc-captures"
let gistRawFile: String = "SD_v1_512x512.txt"
let gistRawPath: String = gistPath + "/" + gistRawFile
let gistRawURL = URL(filePath: gistRawPath)
#endif

// Time the read; the reporting line is kept but disabled.
let start = CACurrentMediaTime()
let gistContents = try! String(contentsOf: gistRawURL)
let end = CACurrentMediaTime()
//print("Download latency: \(latencyRepr(end - start))")
// MARK: - Declaration of Helper Functions | |
/// Halts with a diagnostic when `start` was expected to prefix `sequence`
/// but does not. Never returns.
/// - Parameters:
///   - start: The prefix that was expected.
///   - sequence: The text that failed to start with `start`.
///   - line: Caller's line number for the crash report.
///   - function: Caller's function name, included in the message.
///   - file: Caller's file for the crash report (new, defaulted — keeps the
///     existing two-argument call sites working).
func startError(
  _ start: any StringProtocol,
  _ sequence: any StringProtocol,
  line: UInt = #line,
  function: StaticString = #function,
  file: StaticString = #file
) -> Never {
  // Bug fix: the function name was previously passed as the `file:` argument
  // of fatalError, so crash reports showed a function where a file path
  // belongs. Pass the real file, and keep the function in the message.
  fatalError(
    "'\(start)' is not the start of '\(sequence)' (in \(function)).",
    file: file, line: line)
}
/// Traps (via `startError`) when `text` does not begin with `prefix`;
/// otherwise does nothing.
func assertExpectedPrefix<T: StringProtocol>(
  _ prefix: String,
  from text: T
) where T == T.SubSequence {
  if !text.starts(with: prefix) {
    startError(prefix, text)
  }
}
/// Strips `prefix` from the front of `text`, trapping if it is absent.
func removeExpectedPrefix<T: StringProtocol>(
  _ prefix: String,
  from text: inout T
) where T == T.SubSequence {
  assertExpectedPrefix(prefix, from: text)
  text = text.dropFirst(prefix.count)
}
/// Repeatedly strips leading copies of `prefix` from `text` — e.g.
/// `removeIncluding(" ", ...)` consumes a run of spaces.
/// - Parameters:
///   - prefix: The run to consume. An empty prefix is a no-op.
///   - text: Mutated in place.
func removeIncluding<T: StringProtocol>(
  _ prefix: String,
  from text: inout T
) where T == T.SubSequence {
  // Bug fix: an empty prefix always matches yet removes zero characters, so
  // the previous loop spun forever on any non-empty text. (Its `count == 0`
  // check was unreachable for non-empty prefixes: a successful match implies
  // the text is non-empty.)
  guard !prefix.isEmpty else { return }
  while text.starts(with: prefix) {
    text.removeFirst(prefix.count)
  }
}
/// Drops characters from the front of `text` until it begins with `prefix`
/// (or until the text is exhausted).
/// - Parameters:
///   - prefix: The delimiter to stop at.
///   - text: Mutated in place.
func removeExcluding<T: StringProtocol>(
  _ prefix: String,
  from text: inout T
) where T == T.SubSequence {
  while !text.isEmpty && !text.starts(with: prefix) {
    // Bug fix: advance one character at a time. The previous version jumped
    // `prefix.count` characters per step, which could leap over an occurrence
    // of a multi-character prefix that wasn't chunk-aligned, and could trap
    // by removing past the end of a short remainder. (Identical behavior for
    // the single-character delimiters used in this file.)
    text.removeFirst(1)
  }
}
/// Removes and returns the characters preceding the first occurrence of
/// `prefix` (or the whole of `text` when `prefix` never occurs). `text` is
/// left positioned at the delimiter.
/// - Returns: The extracted run as a `String`.
func extractExcluding<T: StringProtocol>(
  _ prefix: String,
  from text: inout T
) -> String where T == T.SubSequence {
  var output: String = ""
  while !text.isEmpty && !text.starts(with: prefix) {
    // Bug fix: scan one character at a time. The previous chunked scan
    // consumed `prefix.count` characters per step, so a multi-character
    // delimiter at an unaligned offset was skipped and too much was copied.
    // (Identical behavior for the single-character delimiters in this file.)
    output.append(text.removeFirst())
  }
  return output
}
/// Formats a large integer with a magnitude suffix, e.g. 1_234_567 ->
/// "1.2 million" (or "1.2 M" when the global `showConciseMagnitude` is set),
/// rounded to one fractional digit of the chosen radix.
func largeIntegerRepr(_ number: Int) -> String {
  // Mutable copy that `round` overwrites; all formatting below reads it
  // through the computed shadow `number`.
  var _number = number
  // Rounds the original argument to the nearest multiple of `granularity`
  // and stores the result in `_number`. NOTE(review): `number` inside this
  // nested function is assumed to resolve to the outer parameter (the shadow
  // below is declared later) — the value is the same either way, since
  // `round` runs before `_number` is ever changed.
  func round(granularity: Int) {
    var floatNumber = Double(number)
    floatNumber /= Double(granularity)
    floatNumber = rint(floatNumber)
    floatNumber *= Double(granularity)
    _number = Int(rint(floatNumber))
  }
  // Computed shadow: from here on, `number` means the rounded value.
  var number: Int { _number }
  if number < 1_000 {
    return String(number)
  } else if number < 1_000_000 {
    let radix = 1_000
    round(granularity: 100)
    let magnitude = showConciseMagnitude ? "K" : "thousand"
    // Integer division/modulo yield "units.tenths" of the radix.
    return "\(number / radix).\(number % radix / 100) \(magnitude)"
  } else if number < 1_000_000_000 {
    let radix = 1_000_000
    round(granularity: radix / 10)
    let magnitude = showConciseMagnitude ? "M" : "million"
    return "\(number / radix).\(number % radix / (radix / 10)) \(magnitude)"
  } else if number < 1_000_000_000_000 {
    let radix = 1_000_000_000
    round(granularity: radix / 10)
    let magnitude = showConciseMagnitude ? "B" : "billion"
    return "\(number / radix).\(number % radix / (radix / 10)) \(magnitude)"
  } else {
    let radix = 1_000_000_000_000
    round(granularity: radix / 10)
    let magnitude = showConciseMagnitude ? "T" : "trillion"
    return "\(number / radix).\(number % radix / (radix / 10)) \(magnitude)"
  }
}
/// Formats a duration given in seconds as "N µs" / "N.N ms" / "N.N s" /
/// "N.N min" / "N.N hr", rounded to one fractional digit of the chosen unit.
func latencyRepr<T: BinaryFloatingPoint>(_ number: T) -> String {
  // Work in integer microseconds; `round` overwrites this and the formatting
  // below reads it through the computed shadow `number`.
  var _number = Int(rint(Double(number) * 1e6)) // microseconds
  // Re-derives microseconds from the original argument, rounded to the
  // nearest multiple of `granularity`. NOTE(review): `number` inside this
  // nested function is assumed to resolve to the outer floating-point
  // parameter (seconds) — consistent with the `* 1e6` conversion here — not
  // the Int shadow declared below.
  func round(granularity: Int) {
    var floatNumber = Double(number) * 1e6
    floatNumber /= Double(granularity)
    floatNumber = rint(floatNumber)
    floatNumber *= Double(granularity)
    _number = Int(rint(floatNumber))
  }
  // One fractional digit of each unit.
  let f = 10
  // Computed shadow: from here on, `number` means rounded microseconds.
  var number: Int { _number }
  if number < 1_000 {
    return "\(number) µs"
  } else if number < 1_000_000 {
    let radix = 1_000
    round(granularity: radix / f)
    return "\(number / radix).\(number % radix / (radix / f)) ms"
  } else if number < 60 * 1_000_000 {
    let radix = 1_000_000
    round(granularity: radix / f)
    return "\(number / radix).\(number % radix / (radix / f)) s"
  } else if number < 3_600 * 1_000_000 {
    let radix = 60 * 1_000_000
    round(granularity: radix / f)
    return "\(number / radix).\(number % radix / (radix / f)) min"
  } else {
    let radix = 3_600 * 1_000_000
    round(granularity: radix / f)
    return "\(number / radix).\(number % radix / (radix / f)) hr"
  }
}
// MARK: - Declaration of Data Structures

/// One operand of a logged operation, parsed from a single log line that
/// carries an arrow, an index, an identifier, a shape like "[2x4096x320]",
/// and a few sampled element values terminated by "..".
struct Tensor {
  // Starts with the dimension that has the largest stride in memory when you
  // increment it. The last dimension has stride 1 when you increment it.
  var dimensions: [Int]
  // Up to three elements sampled from the tensor.
  var elements: [Float]

  // Generate the tensor by parsing a line of text.
  // - Parameters:
  //   - text: The full log line for this operand.
  //   - idInOperation: 0-indexed operand position; the log numbers from 1.
  //   - isInput: Selects the expected arrow ("->" for inputs, "<-" outputs).
  init<T: StringProtocol>(
    text: T,
    idInOperation: Int, // 0-indexed
    isInput: Bool
  ) where T.SubSequence == Substring {
    let expectedStart = "|\(isInput ? "->" : "<-") \(idInOperation + 1)."
    guard text.starts(with: expectedStart) else {
      startError(expectedStart, text)
    }
    var line: Substring = text.dropFirst(expectedStart.count)
    // Ignore the first number.
    removeIncluding(" ", from: &line)
    removeExcluding(" ", from: &line)
    // Get the identifier (parenthesized; the value is discarded).
    removeIncluding(" ", from: &line)
    removeExpectedPrefix("(", from: &line)
    _ = extractExcluding(")", from: &line)
    removeExpectedPrefix(")", from: &line)
    // Get the shape: bracketed, "x"-separated extents.
    removeIncluding(" ", from: &line)
    removeExpectedPrefix("[", from: &line)
    let shapeString = extractExcluding("]", from: &line)
    self.dimensions = shapeString.split(separator: "x").map { repr in
      return Int(repr)!
    }
    removeExpectedPrefix("]", from: &line)
    // Get the sampled elements, up to the ".." terminator.
    self.elements = []
    removeIncluding(" ", from: &line)
    while !line.starts(with: "..") {
      let repr = extractExcluding(" ", from: &line)
      elements.append(Float(repr)!)
      // Remove the next blob of whitespace before repeating the loop.
      removeIncluding(" ", from: &line)
    }
  }
}
/// One graph node parsed from a run of log lines: a "CCV_NNC_*_FORWARD"
/// header, an optional "Wait:" line, one line per input/output operand, and
/// an optional trailing "Emit:" line.
struct Operation {
  var name: String               // e.g. "GEMM"; "_FORWARD" suffix removed
  var idInModel: Int             // 0-indexed position within the model
  var dependencyID: Int          // parenthesized ID from the header line
  var wait: (Int, Int)?          // parsed from "Wait: (a, b)", when present
  var inputs: [Tensor]
  var outputs: [Tensor]
  var emit: (Int, Int)?          // parsed from "Emit: (a, b)", when present
  var attributes: [String] = []  // annotations attached after parsing

  // Parses one operation. Header format (after the "CCV_NNC_" prefix):
  // "<NAME>_FORWARD [<1-based id>]: [<inputs>] -> [<outputs>] (<dep>)".
  init(
    lines: [Substring],
    idInModel: Int // 0-indexed
  ) {
    let expectedStart = "CCV_NNC_"
    guard lines[0].starts(with: expectedStart) else {
      startError(expectedStart, lines[0])
    }
    var line: Substring = lines[0].dropFirst(expectedStart.count)
    // Get the name, then strip the mandatory "_FORWARD" suffix.
    name = extractExcluding(" ", from: &line)
    let expectedSuffix = "_FORWARD"
    precondition(
      name.suffix(expectedSuffix.count) == expectedSuffix,
      "Suffix '\(name.suffix(expectedSuffix.count))' != expected '_FORWARD'.")
    name.removeLast(expectedSuffix.count)
    // Assert that the index is what you expect (the log is 1-indexed).
    removeIncluding(" ", from: &line)
    removeExpectedPrefix("[", from: &line)
    let idInModelRepr = extractExcluding("]", from: &line)
    guard Int(idInModelRepr)! == idInModel + 1 else {
      fatalError("idInModelRepr '\(idInModelRepr)' != idInModel '\(idInModel)'")
    }
    self.idInModel = idInModel
    // Get the number of operands.
    removeExpectedPrefix("]: [", from: &line)
    let inputCount = Int(extractExcluding("]", from: &line))!
    removeExpectedPrefix("] -> [", from: &line)
    let outputCount = Int(extractExcluding("]", from: &line))!
    // Get the dependency ID.
    removeExpectedPrefix("] (", from: &line)
    self.dependencyID = Int(extractExcluding(")", from: &line))!
    removeExpectedPrefix(")", from: &line)
    // Fetch the wait, if it exists. `lineIndex` walks the remaining lines.
    var lineIndex = 1
    if lines[lineIndex].starts(with: "Wait: ") {
      defer { lineIndex += 1 }
      var line = lines[lineIndex]
      removeExpectedPrefix("Wait: (", from: &line)
      let wait1 = Int(extractExcluding(",", from: &line))!
      removeExpectedPrefix(", ", from: &line)
      let wait2 = Int(extractExcluding(")", from: &line))!
      self.wait = (wait1, wait2)
    }
    // Fetch the inputs, one log line per operand.
    self.inputs = []
    for i in 0..<inputCount {
      defer { lineIndex += 1 }
      let line = lines[lineIndex]
      inputs.append(Tensor(text: line, idInOperation: i, isInput: true))
    }
    // Fetch the outputs.
    self.outputs = []
    for i in 0..<outputCount {
      defer { lineIndex += 1 }
      let line = lines[lineIndex]
      outputs.append(Tensor(text: line, idInOperation: i, isInput: false))
    }
    // Fetch the emit, if it exists.
    if lineIndex < lines.count {
      if lines[lineIndex].starts(with: "Emit: ") {
        defer { lineIndex += 1 }
        var line = lines[lineIndex]
        removeExpectedPrefix("Emit: (", from: &line)
        let emit1 = Int(extractExcluding(",", from: &line))!
        removeExpectedPrefix(", ", from: &line)
        let emit2 = Int(extractExcluding(")", from: &line))!
        self.emit = (emit1, emit2)
      }
    }
    // Check that no lines remain.
    guard lineIndex == lines.count else {
      fatalError(
        "'lineIndex' was '\(lineIndex)', but expected '\(lines.count)'")
    }
    // Optionally feed the global attention-shape tallies.
    if showSelfAttention {
      countSelfAttention(operation: self)
    }
  }
}
/// The parsed log: an ordered list of operations plus counters filled in by
/// the analysis passes further down the file.
struct Model {
  // List of the operations, in the order they appear.
  var operations: [Operation] = []
  // The number of operations that are accounted for the FLOPs estimate.
  var accountedOperations: Int = 0
  // The number of operations whose utilization is simulated.
  var simulatedOperations: Int = 0

  // Parses the whole log. Each operation spans from one "CCV_NNC" header
  // line up to (but excluding) the next header, or up to the "Graph Stream"
  // footer, which terminates parsing.
  init(text: String) {
    var lines: [Substring] = text.split(separator: "\n")
    // Fail-safe in case you do something that causes an infinite loop.
    var iterations: Int = 0
    let maxIterations: Int = 1_000_000
    while true {
      defer { iterations += 1 }
      if iterations >= maxIterations {
        fatalError("Infinite loop occurred.")
      }
      // All the lines that compose the current operation.
      // NOTE(review): `lines[0]` traps on an empty/footer-less log — assumes
      // the capture always ends with a "Graph Stream" line.
      var operationLines: [Substring] = []
      let operationRepr = lines[0]
      assertExpectedPrefix("CCV_NNC", from: operationRepr)
      operationLines.append(lines.removeFirst())
      // Keep adding lines until you encounter "CCV_NNC_" or the end message.
      var reachedGraphEnd = false
      var addedLineCount: Int = 0
      for line in lines {
        if line.starts(with: "CCV_NNC") {
          break
        }
        if line.starts(with: "Graph Stream") {
          reachedGraphEnd = true
          break
        }
        operationLines.append(line)
        addedLineCount += 1
      }
      lines.removeFirst(addedLineCount)
      self.operations.append(
        Operation(lines: operationLines, idInModel: operations.count))
      if reachedGraphEnd {
        break
      }
    }
  }
}
/// Accumulates float-op counts for one class of operations (e.g. all GEMMs
/// tagged "self-attention").
struct OperationStatistics {
  var name: String
  var occurrences: Int = 0
  var maxFloatOperations: Int = 0
  var totalFloatOperations: Int = 0

  // Separate actual from theoretical float ops for convolutions, to report
  // Winograd correctly. Zero means "no actual figure was ever recorded".
  private var _actualMaxFloatOperations: Int = 0
  private var _actualTotalFloatOperations: Int = 0

  /// Actual max; falls back to the theoretical max when never recorded.
  var actualMaxFloatOperations: Int {
    if _actualMaxFloatOperations == 0 {
      return maxFloatOperations
    }
    return _actualMaxFloatOperations
  }
  /// Actual total; falls back to the theoretical total when never recorded.
  var actualTotalFloatOperations: Int {
    if _actualTotalFloatOperations == 0 {
      return totalFloatOperations
    }
    return _actualTotalFloatOperations
  }

  init(name: String) {
    self.name = name
  }

  /// Records one instance of this operation class; `actualForConv` supplies
  /// the separate actual (e.g. Winograd) cost for convolutions.
  mutating func append(floatOperations: Int, actualForConv: Int? = nil) {
    occurrences += 1
    if floatOperations > maxFloatOperations {
      maxFloatOperations = floatOperations
    }
    totalFloatOperations += floatOperations
    if let actual = actualForConv {
      if actual > _actualMaxFloatOperations {
        _actualMaxFloatOperations = actual
      }
      _actualTotalFloatOperations += actual
    }
  }

  /// Multi-line report consumed by the top-level printout.
  func summary() -> String {
    return """
      \(name) (\(largeIntegerRepr(occurrences)) instances)
      - total float ops: \(largeIntegerRepr(totalFloatOperations)) | \
      max of any instance: \(largeIntegerRepr(maxFloatOperations))
      """
  }
}
// Parse the entire captured log.
var model = Model(text: gistContents)

// Annotate GEMMs that are part of self-attention and cross-attention. Each
// SOFTMAX is expected to sit directly between its two attention GEMMs.
for (index, operation) in model.operations.enumerated() {
  guard operation.name == "SOFTMAX" else {
    continue
  }
  guard model.operations[index - 1].name == "GEMM",
        model.operations[index + 1].name == "GEMM" else {
    fatalError("SOFTMAX did not occur between two GEMM operations.")
  }
  // Softmax always has one input and one output, with 2 dimensions total.
  let dimensions = operation.inputs[0].dimensions
  precondition(dimensions == operation.outputs[0].dimensions)
  precondition(dimensions.count == 2)
  // A square softmax matrix is taken to indicate self-attention; any other
  // shape, cross-attention.
  var attribute: String
  if dimensions[0] == dimensions[1] {
    attribute = "self-attention"
  } else {
    attribute = "cross-attention"
  }
  // Tag the softmax and both surrounding GEMMs.
  model.operations[index - 1].attributes.append(attribute)
  model.operations[index].attributes.append(attribute)
  model.operations[index + 1].attributes.append(attribute)
}
// Annotate each operation as belonging to a specific layer.
// TODO: Gather per-stage statistics another day.
// MARK: - Gather Statistics

// Per-class accumulators for the GEMM loop below, split by attention role.
var otherGEMMStatistics = OperationStatistics(name: "OTHER GEMM")
var selfAttnStatistics = OperationStatistics(name: "SELF ATTENTION GEMM")
var crossAttnStatistics = OperationStatistics(name: "CROSS ATTENTION GEMM")
for operation in model.operations where operation.name == "GEMM" { | |
precondition((2...3).contains(operation.inputs.count)) | |
precondition(operation.outputs.count == 1) | |
let A_dimensions = operation.inputs[0].dimensions | |
let B_dimensions = operation.inputs[1].dimensions | |
let C_dimensions = operation.outputs[0].dimensions | |
precondition((2...4).contains(A_dimensions.count)) | |
precondition((2...4).contains(B_dimensions.count)) | |
precondition((2...4).contains(C_dimensions.count)) | |
// Inputs can be transposed and the transposing isn't reported. Take a guess | |
// at the size of the multiplication. If that's not possible, throw an error. | |
// The first two inputs should be multiplicands; the last I'm not sure about. | |
let _m_or_k = Array(A_dimensions.suffix(2).sorted()) | |
let _k_or_n = Array(B_dimensions.suffix(2).sorted()) | |
func cannotFindDimensionsError(_ msg: String) -> Never { | |
let dimensions0 = operation.inputs[0].dimensions | |
let dimensions1 = operation.inputs[1].dimensions | |
let repr = "\(dimensions0) and \(dimensions1)" | |
print("Operation:", operation.idInModel) | |
fatalError("Cannot determine M, N, and K for \(repr): \(msg)") | |
} | |
var candidateIndex0: Int | |
var candidateIndex1: Int | |
if _m_or_k[0] == _k_or_n[0] { | |
candidateIndex0 = 0 | |
candidateIndex1 = 0 | |
} else if _m_or_k[0] == _k_or_n[1] { | |
candidateIndex0 = 0 | |
candidateIndex1 = 1 | |
} else if _m_or_k[1] == _k_or_n[0] { | |
candidateIndex0 = 1 | |
candidateIndex1 = 0 | |
} else if _m_or_k[1] == _k_or_n[1] { | |
candidateIndex0 = 1 | |
candidateIndex1 = 1 | |
} else { | |
cannotFindDimensionsError("no common dimension") | |
} | |
if _m_or_k[1 - candidateIndex0] == _k_or_n[1 - candidateIndex1] { | |
let outputDimensions = Array(C_dimensions.suffix(2)) | |
if outputDimensions[0] == _m_or_k[candidateIndex0], | |
outputDimensions[1] == _k_or_n[candidateIndex1] { | |
candidateIndex0 = 1 - candidateIndex0 | |
candidateIndex1 = 1 - candidateIndex1 | |
} else if outputDimensions[0] == _m_or_k[1 - candidateIndex0], | |
outputDimensions[1] == _k_or_n[1 - candidateIndex1] { | |
// pass | |
} else { | |
cannotFindDimensionsError("two common dimensions, mismatched output") | |
} | |
precondition(_m_or_k[candidateIndex0] == _k_or_n[candidateIndex1]) | |
} | |
let M = _m_or_k[1 - candidateIndex0] | |
let N = _k_or_n[1 - candidateIndex1] | |
let K = _m_or_k[candidateIndex0] | |
// Next, factor in the number of multiplications for batched matmul. The log | |
// seems to have batched the execution of two images. | |
if A_dimensions.count == 4 || B_dimensions.count == 4 { | |
precondition(A_dimensions.count == 4 && B_dimensions.count == 4) | |
precondition(A_dimensions.prefix(2) == B_dimensions.prefix(2)) | |
} | |
if A_dimensions.count >= 3 || B_dimensions.count >= 3 { | |
precondition(A_dimensions.count == C_dimensions.count) | |
} | |
if A_dimensions.count >= 3 && B_dimensions.count >= 3 { | |
// If both inputs have the same batch dimension, we can just check A. | |
precondition( | |
A_dimensions[0] == B_dimensions[0], | |
"\(A_dimensions[0]) != \(B_dimensions[0])") | |
} | |
var batchSize: Int | |
if A_dimensions.count >= 3 { | |
batchSize = A_dimensions[0] | |
} else if B_dimensions.count >= 3 { | |
batchSize = B_dimensions[0] | |
} else { | |
batchSize = 1 | |
} | |
if forceBatch1 { | |
precondition(batchSize == 1 || batchSize == expectedBatchSize) | |
batchSize = 1 | |
} | |
let extraDimension = (A_dimensions.count == 4) ? A_dimensions[1] : 1 | |
var floatOperations = batchSize * extraDimension * 2 * M * N * K | |
if simulateUtilization { | |
// let size = SIMD3<Int>(M, N, K) | |
var isTransposed = false | |
if (M == N && N == K) { | |
fatalError("Cannot tell which is transposed.") | |
} | |
let logIsTransposed = true | |
let A_K = A_dimensions[A_dimensions.count - 1] | |
let B_K = B_dimensions[B_dimensions.count - 2] | |
if K != A_K || K != B_K { | |
isTransposed = true | |
if logIsTransposed { | |
// print("Transposed: \(C_dimensions) = \(A_dimensions) x \(B_dimensions)") | |
} | |
} else { | |
if logIsTransposed { | |
// print("Regular: \(C_dimensions) = \(A_dimensions) x \(B_dimensions)") | |
} | |
} | |
let isBatched = (batchSize != 1) || (extraDimension != 1) | |
if isBatched { | |
// print("Batched: \(C_dimensions) = \(A_dimensions) x \(B_dimensions)") | |
} | |
// print() | |
floatOperations = simulateMatrixUtilization( | |
// size: size, | |
shape: MatrixShape(operation: operation), | |
floatOps: floatOperations, | |
isBatched: isBatched, | |
isTransposed: isTransposed, | |
model: &model) | |
} | |
if operation.attributes.contains("self-attention") { | |
if operation.attributes.contains("cross-attention") { | |
fatalError("Operation was tagged as both attention types.") | |
} | |
selfAttnStatistics.append(floatOperations: floatOperations) | |
} else if operation.attributes.contains("cross-attention") { | |
crossAttnStatistics.append(floatOperations: floatOperations) | |
} else { | |
otherGEMMStatistics.append(floatOperations: floatOperations) | |
} | |
} | |
// Bucket convolutions by filter size; only 1x1 and 3x3 occur in this model.
var conv1x1Statistics = OperationStatistics(name: "CONVOLUTION (1x1)")
var conv3x3Statistics = OperationStatistics(name: "CONVOLUTION (3x3)")
for operation in model.operations where operation.name == "CONVOLUTION" {
  // I don't know what the third input is, but I don't need it.
  precondition(operation.inputs.count == 3)
  precondition(operation.outputs.count == 1)
  let I_dimensions = operation.inputs[0].dimensions // input
  let F_dimensions = operation.inputs[1].dimensions // filter
  let O_dimensions = operation.outputs[0].dimensions // output
  precondition(I_dimensions.count == 4) // first dim = batch size
  precondition(F_dimensions.count == 4) // first dim = number of filters
  precondition(O_dimensions.count == 4) // first dim = batch size
  precondition(I_dimensions[0] == O_dimensions[0])
  precondition(I_dimensions[3] == F_dimensions[1])
  precondition(F_dimensions[0] == O_dimensions[3])
  // `parameters` = taps per input channel (1 for 1x1, 9 for 3x3).
  enum FilterType {
    case conv1x1
    case conv3x3
    var parameters: Int {
      self == .conv1x1 ? 1 : 9
    }
  }
  var filter: FilterType
  if (F_dimensions[2], F_dimensions[3]) == (1, 1) {
    filter = .conv1x1
  } else if (F_dimensions[2], F_dimensions[3]) == (3, 3) {
    filter = .conv3x3
  } else {
    fatalError("Unrecognized convolution type.")
  }
  // Calculate the number of float operations.
  var batchSize = I_dimensions[0]
  if forceBatch1 {
    precondition(batchSize == 1 || batchSize == expectedBatchSize)
    batchSize = 1
  }
  // FLOPs based on output image size, ignoring input image size
  let pixels = O_dimensions[1] * O_dimensions[2]
  // The cost to apply a single filter to a single pixel.
  let filterCost = 2 * I_dimensions[3] * filter.parameters
  // The cost to generate one plane of the output image.
  let planeCost = pixels * filterCost
  // The cost of the entire operation.
  let planeCount = O_dimensions[3]
  let floatOperations = batchSize * planeCount * planeCost
  var actualOperations: Int?
  if simulateUtilization {
    let shape = ConvolutionShape(operation: operation)
    actualOperations = simulateConvUtilization(shape: shape)
  }
  switch filter {
  case .conv1x1:
    conv1x1Statistics.append(
      floatOperations: floatOperations, actualForConv: actualOperations)
  case .conv3x3:
    conv3x3Statistics.append(
      floatOperations: floatOperations, actualForConv: actualOperations)
  }
}
// After convolutions, measure cost of softmax, swish, and normalization.
var softmaxStatistics = OperationStatistics(name: "SOFTMAX")
for operation in model.operations where operation.name == "SOFTMAX" {
  precondition(operation.inputs.count == 1)
  precondition(operation.outputs.count == 1)
  let I_dimensions = operation.inputs[0].dimensions
  let O_dimensions = operation.outputs[0].dimensions
  precondition(I_dimensions == O_dimensions)
  precondition(I_dimensions.count == 2)
  // Analysis of softmax compute cost:
  // - max = CMPSEL, 2 operations
  //   +2 float operations
  // - subtraction cannot fuse with another multiply; effectively 2 operations
  //   +2 effective float operations
  // - exponentiation takes one GPU core-cycle to issue, but theoretically not
  //   a throughput bottleneck (consumes 4 cycles or 8 operations, running
  //   concurrently with F32 FFMA)
  //   +2 effective float operations
  // - summation, just like subtraction: counted as 2 operations
  //   +2 effective float operations
  // - multiplication by reciprocal: cannot be fused with another addition or
  //   subtraction; count as 2 operations
  //   +2 effective float operations
  // Only count the O(n) parts of the compute cost, not the O(1) parts.
  //
  // Total: 10 x number of elements
  var floatOperations = 10 * I_dimensions[0] * I_dimensions[1]
  // Softmax for cross-attention has an artifact where the first dimension gets
  // multiplied by the batch size, instead of adding a third dimension.
  if forceBatch1 {
    floatOperations /= expectedBatchSize
  }
  if simulateUtilization {
    precondition(I_dimensions.count == 2, "Unexpected dimension.")
    // The simulation replaces (not scales) the theoretical estimate.
    let shape = SIMD2<Int>(I_dimensions[0], I_dimensions[1])
    floatOperations = simulateSoftmaxUtilization(shape: shape)
  }
  softmaxStatistics.append(floatOperations: floatOperations)
}
var normalizationStatistics = OperationStatistics(name: "NORMALIZATION")
// Returns whether `operation` is a supported normalization layer. Traps if a
// NORM-containing operation of an unrecognized kind ever appears, so new
// variants cannot slip through silently.
func isNormalization(_ operation: Operation) -> Bool {
  switch operation.name {
  case "LAYER_NORM", "GROUP_NORM":
    return true
  default:
    precondition(!operation.name.contains("NORM"))
    return false
  }
}
for operation in model.operations where isNormalization(operation) {
  precondition(operation.inputs.count == 3)
  precondition(operation.outputs.count == 3)
  enum NormalizationType {
    case group
    case layer
  }
  var normalization: NormalizationType
  if operation.name == "LAYER_NORM" {
    normalization = .layer
  } else {
    normalization = .group
  }
  let I_dimensions = operation.inputs[0].dimensions
  let O_dimensions = operation.outputs[0].dimensions
  precondition(I_dimensions == O_dimensions)
  // Group norm acts on 4D image tensors, layer norm on 3D sequence tensors.
  if normalization == .group {
    precondition(I_dimensions.count == 4)
    precondition(O_dimensions.count == 4)
  } else {
    precondition(I_dimensions.count == 3)
    precondition(O_dimensions.count == 3)
  }
  var batchSize = I_dimensions[0]
  if forceBatch1 {
    precondition(batchSize == 1 || batchSize == expectedBatchSize)
    batchSize = 1
  }
  // Elements per batch item (all dims after the batch axis).
  // NOTE(review): `batchSize` is computed but unused below; `elements`
  // excludes the batch axis, so forceBatch1 does not rescale this count —
  // confirm whether that is intentional.
  let elements = I_dimensions[1...].reduce(1, *)
  // Analysis of normalization compute cost:
  // - addition to sum the elements, doesn't matter whether the multiplication
  //   is fused with the addition (O(n)) or occurs afterward (O(1))
  //   +2 effective float operations
  // - subtraction from the mean, not fusable with any multiplication. Assume
  //   the result of this subtraction is cached inside registers.
  //   +2 effective float operations
  // - square the difference from the mean and fuse with the reduction
  //   +2 float operations
  // - multiply the cached difference with the mean by rsqrt(variance)
  //   +2 effective float operations
  // Always underestimate the number of operations, because we're providing a
  // definitive **lower** bound. Reference:
  // https://www.pinecone.io/learn/batch-layer-normalization/
  //
  // Total: 8 x number of elements
  var floatOperations = 8 * elements
  if !simulateUtilization {
    normalizationStatistics.append(floatOperations: floatOperations)
    continue
  }
  // Assume 50% bandwidth utilization with FP16, 75% with FP32. Assume the SLC
  // does not have 800 GB/s of bidirectional bandwidth.
  let utilization: Float = (simulationPrecision == .f32) ? 0.75 : 0.50
  let bandwidth = slcBandwidthForLatency * utilization
  let bytesPerElement = (simulationPrecision == .f32) ? 4 : 2
  // To simulate utilization, instead go by latency (MPS only) and bandwidth.
  // MFA only consumes bandwidth to shuffle the input and output through the
  // system-level cache. MPS has to shuffle every intermediate result that isn't
  // elementwise.
  if simulationFramework == .mfa && !useMonolithicKernel {
    // MFA: one read + one write of the tensor (2x elements).
    let bytesTransferred = 2 * elements * bytesPerElement
    let memoryLatency = Float(bytesTransferred) / bandwidth
    let memoryVirtualOps = Int(flopsForLatency * memoryLatency)
    floatOperations = max(floatOperations, memoryVirtualOps)
  } else {
    // Lower bound dictated by latency.
    floatOperations = max(floatOperations, Int(flopsForLatency * 58e-6))
    // MPS: intermediate results round-trip through memory (6x elements).
    let bytesTransferred = 6 * elements * bytesPerElement
    let memoryLatency = Float(bytesTransferred) / bandwidth
    let memoryVirtualOps = Int(flopsForLatency * memoryLatency)
    floatOperations = max(floatOperations, memoryVirtualOps)
  }
  normalizationStatistics.append(floatOperations: floatOperations)
  model.simulatedOperations += 1
}
var activationStatistics = OperationStatistics(name: "ACTIVATION")
// Returns whether `operation` is one of the two activation functions that
// appear in this model (SiLU/Swish in the UNet, GELU in the transformer MLPs).
func isActivation(_ operation: Operation) -> Bool {
  ["SWISH", "GELU"].contains(operation.name)
}
// MARK: - Display Data
// All operation classes with per-class float-op tallies, sorted by simulated
// cost (descending).
// NOTE(review): `activationStatistics` is declared above but never included
// here — confirm whether activations are intentionally folded into the
// "ELEMENTWISE" latency bucket below instead.
let allStatistics = [
  otherGEMMStatistics, selfAttnStatistics, crossAttnStatistics,
  conv1x1Statistics, conv3x3Statistics,
  softmaxStatistics, normalizationStatistics
].sorted(by: { $0.actualTotalFloatOperations > $1.actualTotalFloatOperations })
let totalOccurrences = model.operations.count
var accounted = allStatistics.map(\.occurrences).reduce(0, +)
// Model not accounted-for operations as having 58 microseconds of latency. You
// need to manually add these to:
// - "operations accounted for"
// - "operations simulated"
// - "distribution"
// Count how many times an operation with the given name appears in the model.
// Replaces five copy-pasted `reduce(into:)` closures with one helper; the
// counts these produce are identical to the originals.
func operationCount(named target: String) -> Int {
  model.operations.reduce(into: 0) { if $1.name == target { $0 += 1 } }
}
let numTransposes = operationCount(named: "TRANSPOSE")
let numScalarMul = operationCount(named: "SCALAR_MUL")
let numAdd = operationCount(named: "ADD")
let numGELU = operationCount(named: "GELU")
let numSwish = operationCount(named: "SWISH")
// Small memory-bound ops, modeled purely by dispatch latency later on.
let numElementwise = numScalarMul + numAdd + numGELU + numSwish
// Sequential dispatches are modeled at 58 µs each (driver/encoding overhead),
// accumulated per iteration and across all 30 denoising steps.
let stableDiffusionIters: Float = 30
var transposeLatencyPerIter = Float(numTransposes) * 58e-6
var transposeLatencyOverall = transposeLatencyPerIter * stableDiffusionIters
var elementwiseLatencyPerIter = Float(numElementwise) * 58e-6
var elementwiseLatencyOverall = elementwiseLatencyPerIter * stableDiffusionIters
accounted += numTransposes
accounted += numElementwise
// Zero out the transpose and elementwise latency for theoretical estimates.
// Also, zero when simulating FlashAttention + MFA fused activations.
let mfaNonMonolithic = (simulationFramework == .mfa && !useMonolithicKernel)
if !simulateUtilization || mfaNonMonolithic {
  transposeLatencyPerIter = 0
  transposeLatencyOverall = 0
  elementwiseLatencyPerIter = 0
  elementwiseLatencyOverall = 0
}
// Next, elementwise operations.
// After this, the last operation is group norm. Set Swift to release mode
// again to minimize Metal API latency.
// Total operations covered by the simulation: the ones individually simulated
// plus the transpose/elementwise dispatches modeled by fixed latency.
func getSimulatedOperations() -> Int {
  model.simulatedOperations + numTransposes + numElementwise
}
let missing = totalOccurrences - accounted
// `totalFloatOperations` is the theoretical count; `actual...` is the
// utilization-adjusted count (they coincide when not simulating).
let totalFloatOperations = allStatistics
  .map(\.totalFloatOperations).reduce(0, +)
let actualTotalFloatOperations = allStatistics
  .map(\.actualTotalFloatOperations).reduce(0, +)
let actualConvRepr = (cleanDisplay ? "" : "Actual: ")
var convRepr = simulateUtilization ? actualConvRepr : "ConvGEMM (1x): "
let convRepr2 = simulateUtilization ? "Actual" : "ConvGEMM"
if cleanDisplay {
  // Force convRepr to empty.
  convRepr = ""
}
if !cleanDisplay {
  print()
  print("Overview:")
  print("""
  - total operations: \(largeIntegerRepr(totalOccurrences))
  - factored into estimate: \(largeIntegerRepr(accounted))
  - not accounted for: \(largeIntegerRepr(missing))
  - float ops\(simulateUtilization ? " (effective)" : ""):
    - \(convRepr)\(largeIntegerRepr(totalFloatOperations))
  """)
}
// Each Winograd variant removes a fixed fraction of the 3x3-conv float ops.
// Note: `winogradReduction`/`minFloatOperations` are reassigned three times;
// after this section they hold the 9x (8/9 reduction) values, which the
// latency and distribution sections below rely on.
// 1 - 1 / 2.25 = 5 / 9
var winogradReduction = conv3x3Statistics.totalFloatOperations * Int(5) / 9
var minFloatOperations = totalFloatOperations - winogradReduction
// Shorten "2.25x" to "2x" for legibility
if !cleanDisplay {
  print("  - Winograd (2x): \(largeIntegerRepr(minFloatOperations))")
}
// 1 - 1 / 4 = 3 / 4
winogradReduction = conv3x3Statistics.totalFloatOperations * Int(3) / 4
minFloatOperations = totalFloatOperations - winogradReduction
if !cleanDisplay {
  print("  - Winograd (4x): \(largeIntegerRepr(minFloatOperations))")
}
// 1 - 1 / 9 = 8 / 9
winogradReduction = conv3x3Statistics.totalFloatOperations * Int(8) / 9
minFloatOperations = totalFloatOperations - winogradReduction
if !cleanDisplay {
  print("  - Winograd (9x): \(largeIntegerRepr(minFloatOperations))")
}
if showLatency {
  let iters = stableDiffusionIters
  // At this point `minFloatOperations` holds the Winograd 9x estimate (last
  // assignment above).
  var secondsConvGEMM = iters * Float(actualTotalFloatOperations)
  var secondsWinograd = iters * Float(minFloatOperations)
  secondsConvGEMM /= flopsForLatency
  secondsWinograd /= flopsForLatency
  // Formats a GPU-compute time plus the fixed sequential-dispatch latencies,
  // either as "s/it, it/s, s" or via latencyRepr.
  func latency(_ time: Float) -> String {
    var latency = time
    latency += transposeLatencyOverall
    latency += elementwiseLatencyOverall
    if showItersPerSecond {
      var repr = ""
      let secondsPerIter = latency / iters
      repr += String(format: "%.3f", secondsPerIter)
      repr += " s/it, "
      let itersPerSecond = iters / latency
      repr += String(format: "%.2f", itersPerSecond)
      repr += " it/s, "
      repr += String(format: "%.2f", latency)
      repr += " s"
      return repr
    } else {
      return latencyRepr(latency)
    }
  }
  let simulatedLabel = simulateUtilization ? "simulated" : "theoretical"
  print("""
  - \(simulatedLabel) latency @ 30 steps:
  """)
  if !(cleanDisplay && cleanDisplayWinograd) {
    print("""
      - \(convRepr)\(latency(secondsConvGEMM))
    """)
  }
  if (cleanDisplay && cleanDisplayWinograd) {
    // print("""
    //   - Winograd (9x): \(latency(secondsWinograd))
    // """)
    print("""
      - \(latency(secondsWinograd))
    """)
  }
}
if simulateUtilization {
  let repr = "\(getSimulatedOperations()) / \(model.operations.count)"
  // Human-readable label for which kernel combination was simulated.
  var addendum = ""
  if simulationFramework == .mfa,
     useWinogradMPS {
    if useMonolithicKernel {
      addendum = " (monolithic GEMM, MPS Conv2D)"
    } else {
      addendum = " (FlashAttention, MPS Conv2D)"
    }
  } else if simulationFramework == .mfa {
    if useMonolithicKernel {
      // Monolithic GEMM with MFA Winograd was never a supported pairing.
      // addendum = " (monolithic GEMM, MFA Winograd 2x)"
      addendum = " (INVALID PARAMETER COMBINATION)"
    } else {
      addendum = " (FlashAttention, MFA Winograd)"
    }
  }
  print("""
  - simulation config:
    - framework: \(simulationFramework.repr)\(addendum)
    - precision: \(simulationPrecision.repr)
    - operations simulated: \(repr)
  """)
}
// Print each operation class's share of total latency, either as a Markdown
// table (cleanDisplay) or as plain "(percent - latency) name" lines.
// `isWinograd` selects the theoretical float-op count instead of the
// utilization-adjusted one.
func printStatistics(
  _ statisticsArray: [OperationStatistics],
  isWinograd: Bool
) {
  var contributions: [(Float, String)] = []
  for statistics in statisticsArray {
    let operationTally = isWinograd
      ? statistics.totalFloatOperations
      : statistics.actualTotalFloatOperations
    let classLatency =
      Float(operationTally) * stableDiffusionIters / flopsForLatency
    contributions.append((classLatency, statistics.name))
  }
  // Sequential, dispatch-bound operations contribute fixed latencies.
  contributions.append((transposeLatencyOverall, "TRANSPOSE"))
  contributions.append((elementwiseLatencyOverall, "ELEMENTWISE"))
  contributions.sort(by: { $0.0 > $1.0 })
  let totalLatency = contributions.reduce(into: 0) { $0 += $1.0 }
  if cleanDisplay {
    print("| Distribution | Latency | Operation Class |")
    print("| ------------ | ------- | --------------- |")
  }
  for (contributionLatency, name) in contributions {
    // Render the latency share with one decimal place, e.g. "12.3%".
    let proportion = contributionLatency / totalLatency
    let thousandths = Int(rint(proportion * 1000))
    let percent = "\(thousandths / 10).\(thousandths % 10)%"
    let latency = latencyRepr(contributionLatency)
    if cleanDisplay {
      print("| \(percent) | \(latency) | \(name) |")
    } else {
      print("(\(percent) - \(latency)) \(name)")
    }
  }
}
if !(cleanDisplay && cleanDisplayWinograd) {
  print()
  if !cleanDisplay {
    print("Distribution (\(convRepr2)):")
  }
  printStatistics(allStatistics, isWinograd: false)
}
if (cleanDisplay && cleanDisplayWinograd) {
  print()
  print("Distribution (Winograd 9x):")
  // Copy the statistics, discount the 3x3 convolutions by the 9x Winograd
  // reduction (the last value assigned to `winogradReduction`), and re-sort.
  var winogradStatistics = allStatistics
  do {
    let conv3x3Index = winogradStatistics.firstIndex(where: {
      $0.name == "CONVOLUTION (3x3)"
    })!
    winogradStatistics[conv3x3Index].totalFloatOperations -= winogradReduction
    winogradStatistics.sort(by: {
      $0.totalFloatOperations > $1.totalFloatOperations
    })
  }
  printStatistics(winogradStatistics, isWinograd: true)
}
if showExactComputeCost {
  print()
  print("Compute cost of each operation type, sorted from greatest to least.")
  for statistics in allStatistics {
    print()
    print(statistics.summary())
  }
}
// Put some separation between the text and the error text Xcode always appends
// to the output.
if showSelfAttention {
  printSelfAttention()
}
// MARK: - Softmax Benchmarks
if benchmarkMPSMFAGEMM {
  profileGEMM(operations: model.operations)
}
if benchmarkMPSSoftmax {
  profileSoftmax(simulationPrecision: simulationPrecision)
}
if benchmarkMPSTranspose {
  profileTranspose(
    operations: model.operations, simulationPrecision: simulationPrecision)
}
// MARK: - Convolution Benchmarks
// Is it worthwhile to port Winograd to MFA?
//
// Enumerate all the image sizes within Stable Diffusion. Then, repeat with each
// reasonable hyperparameter going all the way from -64 to +64. Print to CSV so
// I can graph in Google Sheets.
// Everything below only runs for the convolution benchmark.
if !benchmarkMPSConvolution {
  exit(0)
}
let convDataType: MPSDataType = (simulationPrecision == .f32)
  ? .float32 : .float16
// Shape metadata for one Conv2D dispatch. Hashable so identical shapes can be
// deduplicated and counted. Layouts follow the parsed log: NHWC data/output,
// OIHW filter.
struct ConvolutionShape: Equatable, Hashable {
  var data: [Int] // [B, H_IN, W_IN, C_IN]
  var filter: [Int] // [C_OUT, C_IN, K, K]
  var output: [Int] // [B, H_OUT, W_OUT, C_OUT]
  var channelsIn: Int
  var channelsOut: Int
  var window: Int
  var batch: Int
  var heightIn: Int
  var widthIn: Int
  var heightOut: Int
  var widthOut: Int
  var differentInOut: Bool

  // Convenience initializer pulling the tensor shapes off a parsed
  // CONVOLUTION operation.
  init(operation: Operation) {
    precondition(operation.name == "CONVOLUTION")
    self.init(
      data: operation.inputs[0].dimensions,
      filter: operation.inputs[1].dimensions,
      output: operation.outputs[0].dimensions)
  }

  init(data: [Int], filter: [Int], output: [Int]) {
    self.data = data
    self.filter = filter
    self.output = output
    // Image channels must match the filter's input channels.
    precondition(data[3] == filter[1])
    channelsIn = data[3]
    // Filter count must match the output feature planes.
    precondition(filter[0] == output[3])
    channelsOut = filter[0]
    // Only square filter windows are supported.
    precondition(filter[2] == filter[3])
    window = filter[2]
    // Batch size must be preserved through the convolution.
    precondition(data[0] == output[0])
    batch = data[0]
    heightIn = data[1]
    widthIn = data[2]
    heightOut = output[1]
    widthOut = output[2]
    differentInOut = (heightIn != heightOut) || (widthIn != widthOut)
  }

  func repr() -> String {
    // K x K x C_IN x H_OUT x W_OUT x C_OUT
    // https://www.kaggle.com/general/240788
    var out = "\(batch) x \(window) x \(window) x \(channelsIn)"
    out += " x \(heightOut) x \(widthOut) x \(channelsOut)"
    if differentInOut {
      // The only supported resampling is a 2x downsample along both axes.
      precondition(heightIn != heightOut && widthIn != widthOut)
      precondition(heightIn == 2 * heightOut)
      precondition(widthIn == 2 * widthOut)
      out += " - stride 2"
    }
    return out
  }

  func floatOperations() -> Int {
    // Two float ops (multiply + accumulate) per filter tap per output element.
    var out = 2 * batch * window * window * channelsIn
    out *= heightOut * widthOut * channelsOut
    return out
  }
}
// Tally how many times each distinct convolution shape is dispatched.
var shapes: [ConvolutionShape: Int] = [:]
for operation in model.operations where operation.name == "CONVOLUTION" {
  let shape = ConvolutionShape(
    data: operation.inputs[0].dimensions,
    filter: operation.inputs[1].dimensions,
    output: operation.outputs[0].dimensions)
  // Idiomatic upsert: the defaulting subscript replaces the nil-check dance.
  shapes[shape, default: 0] += 1
}
// Order shapes by total impact (per-dispatch ops x dispatch count), breaking
// ties by per-dispatch ops, then by descending string repr for determinism.
var keysAndValues = zip(shapes.keys, shapes.values).map {
  (k: $0, v: $1)
}.sorted {
  let ops0 = $0.k.floatOperations()
  let ops1 = $1.k.floatOperations()
  let impact0 = ops0 * $0.v
  let impact1 = ops1 * $1.v
  if impact0 != impact1 {
    return impact0 > impact1
  }
  if ops0 != ops1 {
    return ops0 > ops1
  }
  return strcmp($0.k.repr(), $1.k.repr()) > 0
}
_ = keysAndValues as [(ConvolutionShape, Int)]
// Dead branch kept for power-measurement experiments: it restricts the shape
// list to the top 3 and appends two oversized synthetic convolutions.
#if false
let cutoff: Double = 1.10
keysAndValues = Array(keysAndValues[0..<3])
// Increase the computation time to measure power consumption.
let powerTimeAmplification = 1
keysAndValues.append((
  ConvolutionShape(
    data: [1 * powerTimeAmplification, 512, 512, 256],
    filter: [128, 256, 3, 3],
    output: [1 * powerTimeAmplification, 512, 512, 128]),
  0))
keysAndValues.append((
  ConvolutionShape(
    data: [2 * powerTimeAmplification, 128, 128, 1024],
    filter: [256, 1024, 3, 3],
    output: [2 * powerTimeAmplification, 128, 128, 256]),
  0))
#else
// Set to 1.10 so we can gather ALU utilization statistics for all operations in
// Stable Diffusion. (> 1.0 means the impact cutoff below never triggers.)
let cutoff: Double = 1.10 // 0.90
#endif
let totalImpact = keysAndValues.map { (shape, count) in
  count * shape.floatOperations()
}.reduce(0, +)
let totalOps = keysAndValues.map(\.v).reduce(0, +)
let totalShapes = keysAndValues.count
var accumulatedImpact = 0
var accumulatedOps = 0
var accumulatedShapes = 0
print()
print("Total conv ops: \(totalOps) dispatches / \(totalShapes) shapes")
// Two passes over the shapes: pass 0 silently accumulates up to the cutoff
// (so `accumulatedShapes` is known before pass 1 prints), pass 1 prints each
// shape. The summary line only appears when cutoff < 1.0.
for i in 0..<2 {
  if accumulatedImpact != 0 && cutoff < 1.0 {
    let repr = "\(accumulatedOps) dispatches / \(accumulatedShapes) shapes"
    print("90% of compute: \(repr)")
  }
  accumulatedImpact = 0
  accumulatedOps = 0
  accumulatedShapes = 0
  for (shape, count) in keysAndValues {
    let impactRepr = largeIntegerRepr(count * shape.floatOperations())
    let opsRepr = largeIntegerRepr(shape.floatOperations())
    if i == 1 {
      print("\(impactRepr) / \(opsRepr) - \(count) x (\(shape.repr()))")
    }
    accumulatedImpact += count * shape.floatOperations()
    accumulatedOps += count
    accumulatedShapes += 1
    // Stop once the accumulated impact passes the cutoff fraction.
    if Double(accumulatedImpact) > cutoff * Double(totalImpact) {
      if i == 1 {
        print("...")
      }
      break
    }
  }
}
// MARK: - Profile using MPSGraph
// Bypass a Swift runtime crash.
// Mutable container pairing one compiled MPSGraph executable with its I/O
// tensors/data, stored in an array and filled in by the compile loop below.
struct Graph {
  var graph: MPSGraph = MPSGraph()
  var executable: MPSGraphExecutable = MPSGraphExecutable()
  // Set to true once the compile loop has populated `executable`.
  var compiled: Bool = false
  var inputData: [MPSGraphTensorData] = []
  // NOTE(review): the MTLData arrays are never populated in this section —
  // confirm whether they are vestigial.
  var inputMTLData: [MTLBuffer] = []
  var inputTensors: [MPSGraphTensor] = []
  var outputData: [MPSGraphTensorData] = []
  var outputMTLData: [MTLBuffer] = []
  var outputTensors: [MPSGraphTensor] = []
}
// `accumulatedShapes` is left over from the impact loop above: the number of
// shapes that fell under the cutoff (all of them when cutoff > 1.0).
let numShapes = accumulatedShapes
var graphs: [Graph] = []
for _ in 0..<numShapes {
  graphs.append(Graph())
}
let device = MTLCopyAllDevices().first!
let mpsGraphDevice = MPSGraphDevice(mtlDevice: device)
// Serial queue guarding concurrent access to `graphs` during compilation.
let graphsQueue = DispatchQueue(
  label: "com.philipturner.CalculateDiffusion.graphs")
// Compile on 8 CPU cores to reduce latency.
let numCores = 1 //8) wtf 1 thread is faster than 2 or 8
// Build and compile one Conv2D graph per shape. Work is strided across
// `numCores` workers; each worker handles indices z, z + numCores, ...
DispatchQueue.concurrentPerform(iterations: numCores) { z in
  var opIndex = z
  while opIndex < numShapes {
    defer { opIndex += numCores }
    let graph = graphsQueue.sync { graphs[opIndex].graph }
    let shape = keysAndValues[opIndex].k
    let opDesc = MPSGraphConvolution2DOpDescriptor()
    opDesc.dataLayout = .NHWC
    opDesc.weightsLayout = .OIHW
    opDesc.dilationRateInX = 1
    opDesc.dilationRateInY = 1
    opDesc.groups = 1
    opDesc.strideInX = 1
    opDesc.strideInY = 1
    if cutoff < 0.95 {
      // None of the ops in the upper 90% have strides.
      precondition(shape.repr().contains("stride") == false)
    } else {
      // Strided shapes are marked by repr(); see ConvolutionShape.repr().
      if shape.repr().contains("stride") {
        opDesc.strideInX = 2
        opDesc.strideInY = 2
      }
    }
    let type: MPSDataType = convDataType
    let source = graph.placeholder(
      shape: shape.data.map(NSNumber.init), dataType: type, name: "data")
    let weights = graph.placeholder(
      shape: shape.filter.map(NSNumber.init), dataType: type, name: "filter")
    let outputs = graph.convolution2D(
      source, weights: weights, descriptor: opDesc, name: "convolution")
    let tensors = [source, weights, outputs]
    let tensorsShape = [shape.data, shape.filter, shape.output]
    let dataSize = (convDataType == .float32) ? 4 : 2
    // Zero-filled backing data; only throughput matters, not values.
    let tensorData = zip(tensors, tensorsShape).map {
      let numBytes = $1.reduce(1, *) * dataSize
      let emptyData = Data(count: numBytes)
      return MPSGraphTensorData(
        device: mpsGraphDevice, data: emptyData, shape: $1.map(NSNumber.init),
        dataType: convDataType)
    }
    let compileDesc = MPSGraphCompilationDescriptor()
    if mpsGraphOptimizationLevel1 {
      compileDesc.optimizationLevel = .level1
    } else {
      compileDesc.optimizationLevel = .level0
    }
    compileDesc.waitForCompilationCompletion = true
    var feeds: [MPSGraphTensor : MPSGraphShapedType] = [:]
    for (tensor, data) in zip(tensors, tensorData) {
      feeds[tensor] = MPSGraphShapedType(
        shape: data.shape, dataType: convDataType)
    }
    // The output is a target, not a feed — remove it from the feed dict.
    feeds[outputs] = nil
    let executable = graph.compile(
      with: mpsGraphDevice, feeds: feeds, targetTensors: [tensors[2]],
      targetOperations: nil, compilationDescriptor: compileDesc)
    // Publish results back to the shared array under the serial queue.
    graphsQueue.sync {
      graphs[opIndex].executable = executable
      graphs[opIndex].compiled = true
      graphs[opIndex].inputData = [tensorData[0], tensorData[1]]
      graphs[opIndex].inputTensors = [tensors[0], tensors[1]]
      graphs[opIndex].outputData = [tensorData[2]]
      graphs[opIndex].outputTensors = [tensors[2]]
    }
  }
}
let commandQueue = device.makeCommandQueue()!
// Encode one compiled convolution graph on the shared command queue. When
// `sync` is true, block the CPU until the GPU finishes — used for warmup and
// for the final dispatch of each timed batch.
func dispatch(graph: Graph, sync: Bool) {
  let inputs = graph.inputData
  let results = graph.outputData
  for tensorData in inputs + results {
    precondition(tensorData.dataType == convDataType)
  }
  // Fix: the original built an MPSGraphExecutableExecutionDescriptor and set
  // its waitUntilCompleted, but never passed it to encode(...) — the call used
  // `executionDescriptor: nil`, leaving the descriptor dead code.
  // Synchronization is actually achieved by waiting on the command buffer
  // below, so the unused descriptor has been removed; behavior is unchanged.
  let mpsCommandBuffer = MPSCommandBuffer(from: commandQueue)
  graph.executable.encode(
    to: mpsCommandBuffer, inputs: inputs, results: results,
    executionDescriptor: nil)
  mpsCommandBuffer.commit()
  if sync {
    mpsCommandBuffer.waitUntilCompleted()
  }
}
// Run through the entire data set once to warm up. Then, run 5 warmup
// iterations, and finally profile, with 16 iterations per trial, and
// 4 trials total. We don't have control over the encoding process so we have to
// go by CPU -> GPU -> CPU latency. Dividing by 16 should amortize this a bit.
do {
  let start = CACurrentMediaTime()
  // NOTE(review): every dispatch here is sync, so splitting off the last
  // element appears unnecessary; `graphs.last!` also traps on an empty list.
  for graph in graphs[0..<graphs.count - 1] {
    dispatch(graph: graph, sync: true)
  }
  dispatch(graph: graphs.last!, sync: true)
  let end = CACurrentMediaTime()
  print()
  print("Startup latency: \(latencyRepr(end - start))")
}
// Amortize the execution latency over 'iterations' duplicates of the command.
// For each shape: dispatch `iterations - 1` commands asynchronously, then one
// synchronous command, and time the whole batch. The best (minimum) per-
// iteration time across `trials` is reported alongside the percentage of the
// theoretical peak (flopsForLatency) it achieves.
func profile(iterations: Int, trials: Int) {
  print()
  print("MPSGraph Optimization Level: \(mpsGraphOptimizationLevel1 ? 1 : 0)")
  print("Latency - GFLOPS / Theoretical - Convolution2D Shape")
  for (i, graph) in graphs.enumerated() {
    var minTime = Double.infinity
    for _ in 0..<trials {
      let start = CACurrentMediaTime()
      for _ in 0..<iterations - 1 {
        dispatch(graph: graph, sync: false)
      }
      dispatch(graph: graph, sync: true)
      let end = CACurrentMediaTime()
      let time = (end - start) / Double(iterations)
      minTime = min(minTime, time)
    }
    var repr = latencyRepr(minTime)
    let flops = Double(keysAndValues[i].k.floatOperations()) / minTime
    let percent = 100 * flops / Double(flopsForLatency)
    repr += " - " + String(format: "%.1f", percent) + "%"
    repr += " - " + keysAndValues[i].k.repr()
    print(repr)
  }
}
// First call is an extra warmup; the second produces the reported numbers.
profile(iterations: 5, trials: 1)
profile(iterations: 16, trials: 4)
/// Builds a table of measured MPS convolution ALU utilizations (parsed from
/// the hard-coded benchmark logs below) paired with estimated MFA
/// utilizations, keyed by convolution shape.
///
/// Each raw-data line has the form produced by the convolution `profile`
/// function above:
/// `latency - utilization% - batch x window x window x Cin x Hout x Wout x Cout [- stride 2]`.
func generateConvSpeeds() -> [ConvolutionShape: MatrixUtilization] {
  // It might be a good idea to eventually abstract this away into a file.
  let rawData_f16 = """
  1.5 ms - 95.4% - 2 x 3 x 3 x 320 x 64 x 64 x 320
  1.7 ms - 84.3% - 2 x 3 x 3 x 1280 x 16 x 16 x 1280
  1.7 ms - 85.3% - 2 x 3 x 3 x 640 x 32 x 32 x 640
  5.7 ms - 100.3% - 2 x 3 x 3 x 640 x 64 x 64 x 640
  6.4 ms - 89.4% - 2 x 3 x 3 x 1280 x 32 x 32 x 1280
  2.9 ms - 97.0% - 2 x 3 x 3 x 640 x 64 x 64 x 320
  5.0 ms - 56.6% - 2 x 3 x 3 x 2560 x 16 x 16 x 1280
  4.4 ms - 98.1% - 2 x 3 x 3 x 960 x 64 x 64 x 320
  7.5 ms - 57.2% - 2 x 3 x 3 x 1920 x 32 x 32 x 640
  671 µs - 53.0% - 2 x 3 x 3 x 1280 x 8 x 8 x 1280
  3.3 ms - 85.7% - 2 x 3 x 3 x 1280 x 32 x 32 x 640
  2.5 ms - 85.8% - 2 x 3 x 3 x 960 x 32 x 32 x 640
  3.4 ms - 62.1% - 2 x 3 x 3 x 1920 x 16 x 16 x 1280
  2.1 ms - 34.5% - 2 x 3 x 3 x 2560 x 8 x 8 x 1280
  302 µs - 52.2% - 2 x 1 x 1 x 640 x 32 x 32 x 640
  300 µs - 52.7% - 2 x 1 x 1 x 320 x 64 x 64 x 320
  352 µs - 45.0% - 2 x 1 x 1 x 1280 x 16 x 16 x 1280
  847 µs - 83.9% - 2 x 3 x 3 x 640 x 16 x 16 x 1280
  863 µs - 82.4% - 2 x 3 x 3 x 320 x 32 x 32 x 640
  505 µs - 62.6% - 2 x 1 x 1 x 640 x 64 x 64 x 320
  707 µs - 44.7% - 2 x 1 x 1 x 2560 x 16 x 16 x 1280
  710 µs - 66.8% - 2 x 1 x 1 x 960 x 64 x 64 x 320
  748 µs - 63.4% - 2 x 1 x 1 x 1920 x 32 x 32 x 640
  1.0 ms - 35.0% - 2 x 3 x 3 x 640 x 16 x 16 x 640 - stride 2
  742 µs - 48.0% - 2 x 3 x 3 x 320 x 32 x 32 x 320 - stride 2
  2.1 ms - 16.9% - 2 x 3 x 3 x 1280 x 8 x 8 x 1280 - stride 2
  531 µs - 59.5% - 2 x 1 x 1 x 1280 x 32 x 32 x 640
  413 µs - 57.4% - 2 x 1 x 1 x 960 x 32 x 32 x 640
  528 µs - 44.9% - 2 x 1 x 1 x 1920 x 16 x 16 x 1280
  378 µs - 20.9% - 2 x 1 x 1 x 2560 x 8 x 8 x 1280
  208 µs - 38.0% - 2 x 1 x 1 x 640 x 16 x 16 x 1280
  193 µs - 40.9% - 2 x 1 x 1 x 320 x 32 x 32 x 640
  190 µs - 20.7% - 2 x 1 x 1 x 1280 x 8 x 8 x 1280
  139 µs - 12.8% - 2 x 3 x 3 x 4 x 64 x 64 x 320
  146 µs - 12.2% - 2 x 3 x 3 x 320 x 64 x 64 x 4
  """
  let rawData_f32 = """
  1.5 ms - 94.4% - 2 x 3 x 3 x 320 x 64 x 64 x 320
  1.8 ms - 81.2% - 2 x 3 x 3 x 1280 x 16 x 16 x 1280
  1.7 ms - 85.0% - 2 x 3 x 3 x 640 x 32 x 32 x 640
  5.7 ms - 99.2% - 2 x 3 x 3 x 640 x 64 x 64 x 640
  6.4 ms - 88.4% - 2 x 3 x 3 x 1280 x 32 x 32 x 1280
  3.0 ms - 96.1% - 2 x 3 x 3 x 640 x 64 x 64 x 320
  5.2 ms - 54.4% - 2 x 3 x 3 x 2560 x 16 x 16 x 1280
  4.4 ms - 96.4% - 2 x 3 x 3 x 960 x 64 x 64 x 320
  7.4 ms - 57.8% - 2 x 3 x 3 x 1920 x 32 x 32 x 640
  676 µs - 52.6% - 2 x 3 x 3 x 1280 x 8 x 8 x 1280
  3.4 ms - 84.1% - 2 x 3 x 3 x 1280 x 32 x 32 x 640
  2.5 ms - 84.4% - 2 x 3 x 3 x 960 x 32 x 32 x 640
  3.6 ms - 59.4% - 2 x 3 x 3 x 1920 x 16 x 16 x 1280
  2.1 ms - 34.1% - 2 x 3 x 3 x 2560 x 8 x 8 x 1280
  317 µs - 49.8% - 2 x 1 x 1 x 640 x 32 x 32 x 640
  306 µs - 51.6% - 2 x 1 x 1 x 320 x 64 x 64 x 320
  379 µs - 41.7% - 2 x 1 x 1 x 1280 x 16 x 16 x 1280
  858 µs - 82.9% - 2 x 3 x 3 x 640 x 16 x 16 x 1280
  868 µs - 81.9% - 2 x 3 x 3 x 320 x 32 x 32 x 640
  536 µs - 59.0% - 2 x 1 x 1 x 640 x 64 x 64 x 320
  773 µs - 40.9% - 2 x 1 x 1 x 2560 x 16 x 16 x 1280
  765 µs - 62.0% - 2 x 1 x 1 x 960 x 64 x 64 x 320
  826 µs - 57.4% - 2 x 1 x 1 x 1920 x 32 x 32 x 640
  1.1 ms - 33.7% - 2 x 3 x 3 x 640 x 16 x 16 x 640 - stride 2
  772 µs - 46.1% - 2 x 3 x 3 x 320 x 32 x 32 x 320 - stride 2
  2.2 ms - 15.9% - 2 x 3 x 3 x 1280 x 8 x 8 x 1280 - stride 2
  568 µs - 55.6% - 2 x 1 x 1 x 1280 x 32 x 32 x 640
  445 µs - 53.3% - 2 x 1 x 1 x 960 x 32 x 32 x 640
  576 µs - 41.1% - 2 x 1 x 1 x 1920 x 16 x 16 x 1280
  424 µs - 18.6% - 2 x 1 x 1 x 2560 x 8 x 8 x 1280
  216 µs - 36.5% - 2 x 1 x 1 x 640 x 16 x 16 x 1280
  200 µs - 39.5% - 2 x 1 x 1 x 320 x 32 x 32 x 640
  205 µs - 19.2% - 2 x 1 x 1 x 1280 x 8 x 8 x 1280
  141 µs - 12.6% - 2 x 3 x 3 x 4 x 64 x 64 x 320
  156 µs - 11.4% - 2 x 3 x 3 x 320 x 64 x 64 x 4
  """
  // Pairs a raw benchmark log with the precision it was measured at.
  struct Dataset {
    var rawData: String
    var precision: MatrixUtilization.Precision
  }
  // Measured MPS utilization per shape, one table per precision.
  var output_f16: [ConvolutionShape: Float] = [:]
  var output_f32: [ConvolutionShape: Float] = [:]
  let datasets = [
    Dataset(rawData: rawData_f16, precision: .f16),
    Dataset(rawData: rawData_f32, precision: .f32),
  ]
  for dataset in datasets {
    for var line in dataset.rawData.split(separator: "\n") {
      // Drop the latency column; parsing resumes at the first '-'.
      // NOTE(review): relies on string helpers (removeExcluding,
      // removeExpectedPrefix, extractExcluding) declared elsewhere in this
      // file; their exact consume/return semantics are assumed here.
      removeExcluding("-", from: &line)
      removeExpectedPrefix("- ", from: &line)
      let percent = extractExcluding("%", from: &line)
      // Convert "95.4" to a 0-1 fraction.
      let utilization = Float(percent)! / 100
      // Every benchmark line was recorded with batch size 2, so the leading
      // "2" is consumed along with the separator.
      removeExpectedPrefix("% - 2", from: &line)
      let batchSize = 2
      // Pops the next " x <int>" field off the front of the line.
      func getNextNumber() -> Int {
        removeExpectedPrefix(" x ", from: &line)
        let number = extractExcluding(" ", from: &line)
        return Int(number)!
      }
      // Fields: window x window x Cin x Hout x Wout x Cout.
      let window = getNextNumber()
      precondition(window == getNextNumber())
      let channelsIn = getNextNumber()
      let heightOut = getNextNumber()
      let widthOut = getNextNumber()
      let channelsOut = getNextNumber()
      var heightIn = heightOut
      var widthIn = widthOut
      // An optional "- stride 2" suffix means the input was 2x larger in each
      // spatial dimension. NOTE(review): scanning up to the character "2"
      // assumes "stride 2" is the only suffix format that ever appears.
      let suffix = extractExcluding("2", from: &line)
      if suffix.contains("stride") {
        heightIn *= 2
        widthIn *= 2
      }
      // Append the data to MPS statistics.
      let data = [batchSize, heightIn, widthIn, channelsIn]
      let filter = [channelsOut, channelsIn, window, window]
      let output = [batchSize, heightOut, widthOut, channelsOut]
      let shape = ConvolutionShape(data: data, filter: filter, output: output)
      switch dataset.precision {
      case .f16:
        output_f16[shape] = utilization
      case .f32:
        output_f32[shape] = utilization
      }
    }
  }
  // Assume MFA has:
  // - 190% FP16 for 3x3
  // - 95% FP32 for 3x3
  // - 85% FP16 for 1x1
  // - 70% FP32 for 1x1
  //
  // This is a conservative estimate, because MFA might use Winograd 2.78-4.0x.
  var output: [ConvolutionShape: MatrixUtilization] = [:]
  for shape in output_f16.keys {
    let mpsF16 = output_f16[shape]!
    let mpsF32 = output_f32[shape]!
    var mfaF16: Float
    var mfaF32: Float
    if shape.window == 3 {
      mfaF16 = 1.90
      mfaF32 = 0.95
    } else {
      mfaF16 = 0.85
      mfaF32 = 0.70
    }
    output[shape] = MatrixUtilization(
      mpsF16: mpsF16, mpsF32: mpsF32, mfaF16: mfaF16, mfaF32: mfaF32)
  }
  return output
}
/// Effective ALU utilizations for softmax, derived from memory bandwidth for
/// MPS (whose softmax is bandwidth-bound) and assumed directly for MFA.
struct SoftmaxUtilization {
  var mpsF16ALU: Float
  var mpsF32ALU: Float
  var mfaF16ALU: Float
  var mfaF32ALU: Float

  /// - Parameters:
  ///   - mpsF16Bandwidth: measured FP16 read+write bandwidth, in GB/s.
  ///   - mpsF32Bandwidth: measured FP32 read+write bandwidth, in GB/s.
  ///   - mfaF16ALU: assume MFA reaches 95% ALU for half precision inside
  ///     FlashAttention.
  ///   - mfaF32ALU: assume MFA reaches 90% ALU for single precision.
  init(
    mpsF16Bandwidth: Float, // in GB/s
    mpsF32Bandwidth: Float, // in GB/s
    mfaF16ALU: Float = 0.95,
    mfaF32ALU: Float = 0.90
  ) {
    // Converts a measured bandwidth into a fraction of peak ALU throughput,
    // assuming 10 float ops per element processed.
    func aluUtilization(bandwidth: Float, bytesPerElement: Int) -> Float {
      // Bandwidth sums the read + write bandwidth from RAM or SLC.
      let elementsPerS = bandwidth * 1e9 / 2 / Float(bytesPerElement)
      let floatOpsPerS = 10 * elementsPerS
      return floatOpsPerS / flopsForLatency
    }
    mpsF16ALU = aluUtilization(bandwidth: mpsF16Bandwidth, bytesPerElement: 2)
    mpsF32ALU = aluUtilization(bandwidth: mpsF32Bandwidth, bytesPerElement: 4)
    self.mfaF16ALU = mfaF16ALU
    self.mfaF32ALU = mfaF32ALU
  }
}
/// The (rows x columns) softmax shapes to benchmark: each row count against
/// both sequence lengths, followed by four square shapes.
func getSoftmaxShapes() -> [SIMD2<Int>] {
  var shapes: [SIMD2<Int>] = []
  for columns in [92, 1713] {
    for rows in [32768, 8192, 2048, 512] {
      shapes.append(SIMD2(rows, columns))
    }
  }
  for side in [4096, 1024, 256, 64] {
    shapes.append(SIMD2(side, side))
  }
  return shapes
}
/// Benchmarks MPSGraph softmax over every shape from `getSoftmaxShapes()` and
/// prints latency plus achieved memory bandwidth for each shape.
/// - Parameter simulationPrecision: selects FP16 or FP32 tensors.
func profileSoftmax(simulationPrecision: MatrixUtilization.Precision) {
  let softmaxDataType: MPSDataType = (simulationPrecision == .f32)
    ? .float32 : .float16
  let shapes = getSoftmaxShapes()
  let numShapes = shapes.count
  // One compiled executable per shape.
  var graphs: [Graph] = []
  for _ in 0..<numShapes {
    graphs.append(Graph())
  }
  let device = MTLCopyAllDevices().first!
  let mpsGraphDevice = MPSGraphDevice(mtlDevice: device)
  // Serializes access to `graphs` from the (potentially) concurrent
  // compilation loop below.
  let graphsQueue = DispatchQueue(
    label: "com.philipturner.CalculateDiffusion.graphs")
  // Empirically, compiling on 1 thread was faster than on 2 or 8.
  let numCores = min(numShapes, 1)
  DispatchQueue.concurrentPerform(iterations: numCores) { z in
    // Each worker strides through the shape list, offset by its index.
    var opIndex = z
    while opIndex < numShapes {
      defer { opIndex += numCores }
      let graph = graphsQueue.sync { graphs[opIndex].graph }
      let shape = shapes[opIndex]
      let type: MPSDataType = softmaxDataType
      let nsShape = [shape[0], shape[1]].map(NSNumber.init)
      let source = graph.placeholder(
        shape: nsShape, dataType: type, name: "data")
      // Row-wise softmax (along axis 1).
      let output = graph.softMax(with: source, axis: 1, name: "softmax")
      var tensorData: [MPSGraphTensorData] = []
      let dataSize = (softmaxDataType == .float32) ? 4 : 2
      // Index 0 feeds the input; index 1 receives the output. Zero-filled
      // buffers are fine because only timing matters.
      for _ in 0..<2 {
        let numBytes = shape[0] * shape[1] * dataSize
        let emptyData = Data(count: numBytes)
        tensorData.append(MPSGraphTensorData(
          device: mpsGraphDevice, data: emptyData, shape: nsShape,
          dataType: softmaxDataType))
      }
      let compileDesc = MPSGraphCompilationDescriptor()
      if mpsGraphOptimizationLevel1 {
        compileDesc.optimizationLevel = .level1
      } else {
        compileDesc.optimizationLevel = .level0
      }
      compileDesc.waitForCompilationCompletion = true
      let feeds = [source: MPSGraphShapedType(
        shape: nsShape, dataType: softmaxDataType)]
      let executable = graph.compile(
        with: mpsGraphDevice, feeds: feeds, targetTensors: [output],
        targetOperations: nil, compilationDescriptor: compileDesc)
      // Publish the compiled executable and its buffers back to the shared
      // array under the serial queue.
      graphsQueue.sync {
        graphs[opIndex].executable = executable
        graphs[opIndex].compiled = true
        graphs[opIndex].inputData = [tensorData[0]]
        graphs[opIndex].inputTensors = [source]
        graphs[opIndex].outputData = [tensorData[1]]
        graphs[opIndex].outputTensors = [output]
      }
    }
  }
  let commandQueue = device.makeCommandQueue()!
  // Encodes one softmax executable onto a fresh command buffer and commits it;
  // blocks the CPU when `sync` is true.
  func dispatch(graph: Graph, sync: Bool) {
    let inputs = graph.inputData
    let results = graph.outputData
    for tensorData in inputs + results {
      precondition(tensorData.dataType == softmaxDataType)
    }
    // NOTE(review): this descriptor is configured but never passed to
    // `encode` below; synchronization is handled via the command buffer.
    let desc = MPSGraphExecutableExecutionDescriptor()
    desc.waitUntilCompleted = sync
    let mpsCommandBuffer = MPSCommandBuffer(from: commandQueue)
    graph.executable.encode(
      to: mpsCommandBuffer, inputs: inputs, results: results,
      executionDescriptor: nil)
    mpsCommandBuffer.commit()
    if sync {
      mpsCommandBuffer.waitUntilCompleted()
    }
  }
  // Run through the entire data set once to warm up. Then, run 5 warmup
  // iterations, and finally profile, with 16 iterations per trial, and
  // 4 trials total. We don't have control over the encoding process so we must
  // go by CPU -> GPU -> CPU latency. Dividing by 16 should amortize this a bit.
  do {
    let start = CACurrentMediaTime()
    for graph in graphs[0..<graphs.count - 1] {
      dispatch(graph: graph, sync: true)
    }
    dispatch(graph: graphs.last!, sync: true)
    let end = CACurrentMediaTime()
    print()
    print("Startup latency: \(latencyRepr(end - start))")
  }
  // Amortize the execution latency over 'iterations' duplicates of the command.
  func profile(iterations: Int, trials: Int) {
    print()
    print("MPSGraph Optimization Level: \(mpsGraphOptimizationLevel1 ? 1 : 0)")
    print("Latency - GFLOPS / Theoretical - Softmax Shape")
    for (i, graph) in graphs.enumerated() {
      // Keep the fastest trial, to filter out scheduling noise.
      var minTime = Double.infinity
      for _ in 0..<trials {
        let start = CACurrentMediaTime()
        for _ in 0..<iterations - 1 {
          dispatch(graph: graph, sync: false)
        }
        dispatch(graph: graph, sync: true)
        let end = CACurrentMediaTime()
        let time = (end - start) / Double(iterations)
        minTime = min(minTime, time)
      }
      let shape = shapes[i]
      let numElements = shape[0] * shape[1]
      let elementBytes = (simulationPrecision == .f32) ? 4 : 2
      // Softmax reads and writes each element once: 2x the tensor's bytes.
      let bytesTransferred = 2 * elementBytes * numElements
      let bandwidth = Double(bytesTransferred) / 1e9 / minTime
      var repr = latencyRepr(minTime)
      repr += " - \(String(format: "%.1f", bandwidth)) GB/s"
      repr += " - \(shape[0]) x \(shape[1])"
      print(repr)
    }
  }
  profile(iterations: 5, trials: 1)
  profile(iterations: 16, trials: 4)
  profile(iterations: 64, trials: 8)
  // Need to disable API validation and use Swift release mode to get accurate
  // readings. Softmaxes are somewhat latency-bound.
  // FP16:
  // 206 µs - 58.6 GB/s - 32768 x 92
  // 84 µs - 36.0 GB/s - 8192 x 92
  // 75 µs - 10.1 GB/s - 2048 x 92
  // 73 µs - 2.6 GB/s - 512 x 92
  // 1.6 ms - 142.5 GB/s - 32768 x 1713
  // 420 µs - 133.5 GB/s - 8192 x 1713
  // 146 µs - 95.9 GB/s - 2048 x 1713
  // 77 µs - 45.4 GB/s - 512 x 1713
  // 480 µs - 139.9 GB/s - 4096 x 4096
  // 74 µs - 56.8 GB/s - 1024 x 1024
  // 74 µs - 3.5 GB/s - 256 x 256
  // 74 µs - 0.2 GB/s - 64 x 64
  // FP32:
  // 206 µs - 116.8 GB/s - 32768 x 92
  // 80 µs - 75.4 GB/s - 8192 x 92
  // 76 µs - 19.8 GB/s - 2048 x 92
  // 79 µs - 4.8 GB/s - 512 x 92
  // 2.5 ms - 180.1 GB/s - 32768 x 1713
  // 617 µs - 181.8 GB/s - 8192 x 1713
  // 162 µs - 173.4 GB/s - 2048 x 1713
  // 78 µs - 89.4 GB/s - 512 x 1713
  // 777 µs - 172.8 GB/s - 4096 x 4096
  // 75 µs - 111.5 GB/s - 1024 x 1024
  // 70 µs - 7.5 GB/s - 256 x 256
  // 72 µs - 0.5 GB/s - 64 x 64
}
/// A 4-D tensor shape whose middle two axes (1 and 2) are swapped by the
/// transpose operation. Stores the input layout; the output layout is derived.
struct TransposeShape: Equatable & Hashable {
  var batch: Int
  var firstInInput: Int
  var firstInOutput: Int
  var heads: Int

  /// - Parameter inputShape: exactly four dimensions, in input order.
  init(inputShape: [Int]) {
    precondition(inputShape.count == 4)
    batch = inputShape[0]
    firstInInput = inputShape[1]
    firstInOutput = inputShape[2]
    heads = inputShape[3]
  }

  /// Human-readable form, e.g. "2 x 4096 x 8 x 40".
  func repr() -> String {
    inputShape.map(String.init).joined(separator: " x ")
  }

  var inputShape: [Int] {
    [batch, firstInInput, firstInOutput, heads]
  }
  var inputShapeNS: [NSNumber] {
    inputShape.map(NSNumber.init)
  }
  // The output swaps axes 1 and 2 relative to the input.
  var outputShape: [Int] {
    [batch, firstInOutput, firstInInput, heads]
  }
  var outputShapeNS: [NSNumber] {
    outputShape.map(NSNumber.init)
  }
}
// Measured transpose latencies, stored in seconds.
struct TransposeLatency {
  var mpsF16: Float
  var mpsF32: Float
  var mfaF16: Float // always zero
  var mfaF32: Float // always zero

  /// - Parameters:
  ///   - f16: measured FP16 latency, in microseconds.
  ///   - f32: measured FP32 latency, in microseconds.
  init(f16: Int, f32: Int) {
    // Convert microseconds to seconds.
    mpsF16 = Float(f16) / 1e6
    mpsF32 = Float(f32) / 1e6
    // The MFA entries are zero by construction (see the property comments).
    mfaF16 = 0
    mfaF32 = 0
  }
}
// Print both the microseconds/milliseconds and something the script can parse.
/// Benchmarks MPSGraph transposes for every unique TRANSPOSE shape found in
/// the trace, printing latency per shape.
/// - Parameters:
///   - operations: the parsed trace; only entries named "TRANSPOSE" are used.
///   - simulationPrecision: selects FP16 or FP32 tensors.
func profileTranspose(
  operations: [Operation],
  simulationPrecision: MatrixUtilization.Precision
) {
  let transposeDataType: MPSDataType = (simulationPrecision == .f32)
    ? .float32 : .float16
  // Deduplicate shapes; the dictionary is used as a set.
  var shapesDict: [TransposeShape: Bool] = [:]
  for operation in operations where operation.name == "TRANSPOSE" {
    precondition(operation.inputs.count == 1)
    precondition(operation.outputs.count == 1)
    precondition(operation.inputs[0].dimensions.count == 4)
    precondition(operation.outputs[0].dimensions.count == 4)
    let inShape = operation.inputs[0].dimensions
    let outShape = operation.outputs[0].dimensions
    // Every transpose in the trace must swap dimensions 1 and 2.
    precondition(inShape[0] == outShape[0])
    precondition(inShape[1] == outShape[2])
    precondition(inShape[2] == outShape[1])
    precondition(inShape[3] == outShape[3])
    shapesDict[TransposeShape(inputShape: inShape)] = true
  }
  let shapes = shapesDict.keys.map { $0 }
  let numShapes = shapes.count
  // One compiled executable per shape.
  var graphs: [Graph] = []
  for _ in 0..<numShapes {
    graphs.append(Graph())
  }
  let device = MTLCopyAllDevices().first!
  let mpsGraphDevice = MPSGraphDevice(mtlDevice: device)
  // Serializes access to `graphs` from the (potentially) concurrent
  // compilation loop below.
  let graphsQueue = DispatchQueue(
    label: "com.philipturner.CalculateDiffusion.graphs")
  // Empirically, compiling on 1 thread was faster than on 2 or 8.
  let numCores = min(numShapes, 1)
  DispatchQueue.concurrentPerform(iterations: numCores) { z in
    // Each worker strides through the shape list, offset by its index.
    var opIndex = z
    while opIndex < numShapes {
      defer { opIndex += numCores }
      let graph = graphsQueue.sync { graphs[opIndex].graph }
      let shape = shapes[opIndex]
      let type: MPSDataType = transposeDataType
      let source = graph.placeholder(
        shape: shape.inputShapeNS, dataType: type, name: "data")
      // Swap axes 1 and 2, mirroring the trace's TRANSPOSE operations.
      let output = graph.transposeTensor(
        source, dimension: 1, withDimension: 2, name: "transpose")
      var tensorData: [MPSGraphTensorData] = []
      let dataSize = (transposeDataType == .float32) ? 4 : 2
      // Index 0 feeds the input; index 1 receives the (transposed) output.
      for i in 0..<2 {
        let numBytes = shape.inputShape.reduce(1, *) * dataSize
        let emptyData = Data(count: numBytes)
        let shapeNS = (i == 0) ? shape.inputShapeNS : shape.outputShapeNS
        tensorData.append(MPSGraphTensorData(
          device: mpsGraphDevice, data: emptyData, shape: shapeNS,
          dataType: transposeDataType))
      }
      let compileDesc = MPSGraphCompilationDescriptor()
      if mpsGraphOptimizationLevel1 {
        compileDesc.optimizationLevel = .level1
      } else {
        compileDesc.optimizationLevel = .level0
      }
      compileDesc.waitForCompilationCompletion = true
      let feeds = [source: MPSGraphShapedType(
        shape: shape.inputShapeNS, dataType: transposeDataType)]
      let executable = graph.compile(
        with: mpsGraphDevice, feeds: feeds, targetTensors: [output],
        targetOperations: nil, compilationDescriptor: compileDesc)
      // Publish the compiled executable and its buffers back to the shared
      // array under the serial queue.
      graphsQueue.sync {
        graphs[opIndex].executable = executable
        graphs[opIndex].compiled = true
        graphs[opIndex].inputData = [tensorData[0]]
        graphs[opIndex].inputTensors = [source]
        graphs[opIndex].outputData = [tensorData[1]]
        graphs[opIndex].outputTensors = [output]
      }
    }
  }
  let commandQueue = device.makeCommandQueue()!
  // Encodes one transpose executable onto a fresh command buffer and commits
  // it; blocks the CPU when `sync` is true.
  func dispatch(graph: Graph, sync: Bool) {
    let inputs = graph.inputData
    let results = graph.outputData
    for tensorData in inputs + results {
      precondition(tensorData.dataType == transposeDataType)
    }
    // NOTE(review): this descriptor is configured but never passed to
    // `encode` below; synchronization is handled via the command buffer.
    let desc = MPSGraphExecutableExecutionDescriptor()
    desc.waitUntilCompleted = sync
    let mpsCommandBuffer = MPSCommandBuffer(from: commandQueue)
    graph.executable.encode(
      to: mpsCommandBuffer, inputs: inputs, results: results,
      executionDescriptor: nil)
    mpsCommandBuffer.commit()
    if sync {
      mpsCommandBuffer.waitUntilCompleted()
    }
  }
  // Run through the entire data set once to warm up. Then, run 5 warmup
  // iterations, and finally profile, with 16 iterations per trial, and
  // 4 trials total. We don't have control over the encoding process so we must
  // go by CPU -> GPU -> CPU latency. Dividing by 16 should amortize this a bit.
  do {
    let start = CACurrentMediaTime()
    for graph in graphs[0..<graphs.count - 1] {
      dispatch(graph: graph, sync: true)
    }
    dispatch(graph: graphs.last!, sync: true)
    let end = CACurrentMediaTime()
    print()
    print("Startup latency: \(latencyRepr(end - start))")
  }
  // Amortize the execution latency over 'iterations' duplicates of the command.
  func profile(iterations: Int, trials: Int) {
    print()
    print("MPSGraph Optimization Level: \(mpsGraphOptimizationLevel1 ? 1 : 0)")
    print("Latency - Latency Seconds - Transpose Shape")
    for (i, graph) in graphs.enumerated() {
      // Keep the fastest trial, to filter out scheduling noise.
      var minTime = Double.infinity
      for _ in 0..<trials {
        let start = CACurrentMediaTime()
        for _ in 0..<iterations - 1 {
          dispatch(graph: graph, sync: false)
        }
        dispatch(graph: graph, sync: true)
        let end = CACurrentMediaTime()
        let time = (end - start) / Double(iterations)
        minTime = min(minTime, time)
      }
      var repr = latencyRepr(minTime)
      // The second column is machine-parseable (fixed-point seconds).
      repr += " - \(String(format: "%.6f", minTime))"
      repr += " - \(shapes[i].repr())"
      print(repr)
    }
  }
  // Need to disable API validation and use Swift release mode to get accurate
  // readings. Transposes are latency-bound.
  profile(iterations: 5, trials: 1)
  profile(iterations: 16, trials: 4)
  profile(iterations: 128, trials: 16)
}
// MARK: - Profile HGEMM on MPSGraph and libMetalFlashAttention | |
/// Loads the precompiled MetalFlashAttention library for `device`.
///
/// The library location may be overridden with the `MFA_METALLIB_PATH`
/// environment variable, so the script can run on machines other than the
/// author's; otherwise it falls back to the original hard-coded path.
/// Crashes (`try!`) if the library is missing — in this one-off profiling
/// script, that is a fatal configuration error.
func loadLibMetalFlashAttention(device: MTLDevice) -> MTLLibrary {
  let path: String
  if let override = ProcessInfo.processInfo.environment["MFA_METALLIB_PATH"] {
    path = override
  } else {
    let parentDir = "/Users/philipturner/Documents/Xcode Projects"
    let projectDir = "/CalculateDiffusion/CalculateDiffusion"
    let file = "/libMetalFlashAttention.metallib"
    path = parentDir + projectDir + file
  }
  let url = URL(filePath: path)
  return try! device.makeLibrary(URL: url)
}
/// The dimensions of a (possibly batched) GEMM: C[M x N] = A[M x K] * B[K x N].
struct MatrixShape: Equatable & Hashable {
  // Batch count; nil for a plain 2-D multiplication.
  var B: Int?
  var M: Int
  var N: Int
  var K: Int
  // Whether the stride for instances of matrix B is nonzero.
  var strideInput2: Bool
  init(B: Int?, M: Int, N: Int, K: Int, _ strideInput2: Bool) {
    self.B = B
    self.M = M
    self.N = N
    self.K = K
    self.strideInput2 = strideInput2
  }
  /// Infers (B, M, N, K) from a GEMM operation's tensor shapes.
  /// NOTE(review): `Operation` is declared elsewhere in this file; this
  /// assumes inputs[0] is A, inputs[1] is B, and outputs[0] is C.
  init(operation: Operation) {
    var shape1 = operation.inputs[0].dimensions
    var shape2 = operation.inputs[1].dimensions
    var shapeO = operation.outputs[0].dimensions
    // Validate the leading (batch) dimensions for each supported rank of B.
    if shape2.count == 2 {
      // pass
    } else if shape2.count == 3 {
      precondition(shape2[0] == 1)
    } else if shape2.count == 4 {
      precondition(shape1[0] == 1)
      precondition(shape2[0] == 1)
      precondition(shapeO[0] == 1)
      precondition(shape1[1] == shape2[1])
      precondition(shape1[1] == shapeO[1])
    } else {
      fatalError("Unrecognized shape.")
    }
    // Promote a 2-D B to 3-D so its rank matches a 3-D A.
    if shape1.count != shape2.count {
      if shape1.count == 3 && shape2.count == 2 {
        shape2 = [1] + shape2
      } else {
        fatalError("Unrecognized shape.")
      }
    }
    precondition(shape1.count == shapeO.count)
    // Drop a unit leading dimension shared by all three tensors.
    if shape1.count > 2 {
      if shape1[0] == 1 && shape2[0] == 1 && shapeO[0] == 1 {
        shape1 = Array(shape1[1...])
        shape2 = Array(shape2[1...])
        shapeO = Array(shapeO[1...])
      }
    }
    if shape1.count > 2 {
      precondition(shape1[0] == shapeO[0])
    }
    // Find which of A's two trailing dimensions matches one of B's: that
    // shared dimension is K. Indices are relative to the reversed shapes
    // (0 = innermost axis).
    var K_index1Rev: Int
    var K_index2Rev: Int
    let shape1Reversed = Array(shape1.reversed())
    let shape2Reversed = Array(shape2.reversed())
    if shape1Reversed[0] == shape2Reversed[0] {
      K_index1Rev = 0
      K_index2Rev = 0
    } else if shape1Reversed[1] == shape2Reversed[0] {
      K_index1Rev = 1
      K_index2Rev = 0
    } else if shape1Reversed[0] == shape2Reversed[1] {
      K_index1Rev = 0
      K_index2Rev = 1
    } else if shape1Reversed[1] == shape2Reversed[1] {
      K_index1Rev = 1
      K_index2Rev = 1
    } else {
      fatalError("No matching dimensions.")
    }
    let shapeOReversed = Array(shapeO.reversed())
    // Ambiguous case: the candidate M and N are equal, so the first match
    // above may have picked the wrong axis for K. Check against the output
    // shape and flip the guess if it disagrees.
    if shape1Reversed[1 - K_index1Rev] == shape2Reversed[1 - K_index2Rev] {
      let maybeM = shape1Reversed[1 - K_index1Rev]
      let maybeN = shape2Reversed[1 - K_index2Rev]
      if shapeOReversed[0] == maybeN && shapeOReversed[1] == maybeM {
        // Good to go; this is the correct estimate of K.
      } else {
        K_index1Rev = 1 - K_index1Rev
        K_index2Rev = 1 - K_index2Rev
      }
    }
    self.M = shape1Reversed[1 - K_index1Rev]
    self.N = shape2Reversed[1 - K_index2Rev]
    self.K = shape1Reversed[K_index1Rev]
    // The output must be M x N in its last two dimensions.
    precondition(shapeOReversed[0] == N)
    precondition(shapeOReversed[1] == M)
    // Recover the batch dimension (if any) and whether B is broadcast.
    if shape1.count == 2 {
      self.B = nil
      self.strideInput2 = false
    } else if shape1.count == 3 {
      precondition(shape1[0] != 1)
      self.B = shape1[0]
      // B advances per batch element only when it isn't broadcast.
      self.strideInput2 = (shape2[0] > 1)
      if shape2[0] > 1 {
        precondition(shape1[0] == shape2[0])
      }
    } else {
      fatalError("This should never happen.")
    }
  }
  // Print text that I can use directly as source code.
  func initRepr() -> String {
    let reprB = (B != nil) ? "\(B!)" : "nil"
    return """
    MatrixShape(B: \(reprB), M: \(M), N: \(N), K: \(K), \(strideInput2))
    """
  }
}
// Like the other functions, this should print a list of latencies with MPS | |
// ALU utilizations and GEMM shapes. To get accurate benchmarks, we need to | |
// compile the Swift code in RELEASE mode and disable Metal API validation. | |
// | |
// WARNING: I am enabling Swift DEBUG mode and API validation initially. | |
// | |
// Profile both MPSGraph and MFA, with all precisions. To allow incremental | |
// progress over the test suite, specify a range of shapes that can be tested | |
// during the call. | |
// | |
// TODO: Change MatrixUtilization so it accounts for batch size too. | |
// | |
// This function also should print out an initializer for a dictionary of | |
// `MatrixUtilization`. I would need a different initializer based on whether | |
// I'm restricting it to the monolithic kernel. It assumes that transposing of | |
// the matrix barely changes performance in MFA, which seems plausible. It also | |
// assumes that fusing the bias vector with the matmul is trivial. Both the MPS | |
// and MFA multiplications elide the transpose and fused activation. | |
func profileGEMM(operations: [Operation]) { | |
let gemmOperations = operations.filter { $0.name == "GEMM" } | |
// Keep the shapes in a consistent order for presenting. | |
// - key: shape | |
// - value: first time it appears in the trace | |
var uniqueOperationsDict: [MatrixShape: Int] = [:] | |
var uniqueOperations: [MatrixShape] = [] | |
for (i, operation) in gemmOperations.enumerated() { | |
let shape = MatrixShape(operation: operation) | |
if uniqueOperationsDict[shape] == nil { | |
uniqueOperationsDict[shape] = i | |
uniqueOperations.append(shape) | |
} | |
} | |
/// Bundles the per-device Metal objects shared across the GEMM benchmarks:
/// the device, a command queue, the MFA metallib, and an MPSGraph device.
class GPUContext {
  var device: MTLDevice
  var commandQueue: MTLCommandQueue
  // The precompiled MetalFlashAttention library (see
  // loadLibMetalFlashAttention above).
  var library: MTLLibrary
  var mpsGraphDevice: MPSGraphDevice
  init(device: MTLDevice) {
    self.device = device
    self.commandQueue = device.makeCommandQueue()!
    self.library = loadLibMetalFlashAttention(device: device)
    self.mpsGraphDevice = MPSGraphDevice(mtlDevice: device)
  }
}
class Resources { | |
static let monolithicVariant: Resources | |
.Variant = benchmarkGEMM_monolithicVariant | |
var context: GPUContext | |
var shape: MatrixShape | |
var mfaEnsemble: Bool // true: run 3x the iters and report the fastest time | |
var precision: MatrixUtilization.Precision | |
var mpsDataType: MPSDataType { precision == .f32 ? .float32 : .float16 } | |
var executable: MPSGraphExecutable | |
var pipelines: [Variant: MTLComputePipelineState] // 1-3 to test | |
var matrixOffsets: [SIMD4<UInt64>] // one array per GEMM in the batch | |
var gridSizes: [Variant: MTLSize] // threadgroups per grid | |
var pipelineNames: [Variant: String] | |
var bufferA: MTLBuffer | |
var bufferB: MTLBuffer | |
var bufferC_mps: MTLBuffer | |
var bufferC_mfa: MTLBuffer | |
var tensorData: [MPSGraphTensorData] | |
var inputs: [MPSGraphTensorData] { [tensorData[0], tensorData[1]] } | |
var results: [MPSGraphTensorData] { [tensorData[2]] } | |
#if arch(arm64) | |
// Bypass compiler complaint. | |
typealias Half = Float16 | |
#else | |
typealias Half = Float | |
#endif | |
static let randomTape16: UnsafeMutableRawPointer/*<Half>*/ = { | |
var output = malloc(15687 * 2).assumingMemoryBound(to: Half.self) | |
for i in 0..<15687 { | |
output[i] = Half.random(in: 0..<1) | |
} | |
return .init(output) | |
}() | |
static let randomTape32: UnsafeMutableRawPointer/*<Float>*/ = { | |
var output = malloc(15687 * 4).assumingMemoryBound(to: Float.self) | |
for i in 0..<15687 { | |
output[i] = Float.random(in: 0..<1) | |
} | |
return .init(output) | |
}() | |
// This waits until all blocking operations are finished, so initialize | |
// multiple test instances in parallel. | |
init( | |
context: GPUContext, | |
shape: MatrixShape, | |
mfaEnsemble: Bool, | |
precision: MatrixUtilization.Precision | |
) { | |
self.context = context | |
self.shape = shape | |
self.mfaEnsemble = mfaEnsemble | |
self.precision = precision | |
var dimsA = [shape.M, shape.K] | |
var dimsB = [shape.K, shape.N] | |
var dimsC = [shape.M, shape.N] | |
if let shapeB = shape.B { | |
dimsA = [shapeB] + dimsA | |
dimsB = [shape.strideInput2 ? shapeB : 1] + dimsB | |
dimsC = [shapeB] + dimsC | |
} | |
func matrixStride(dims: [Int], actuallyStride: Bool) -> Int { | |
actuallyStride ? (dims.last! * dims.reversed()[1]) : 0 | |
} | |
let _si2 = shape.strideInput2 | |
let strideA = matrixStride(dims: dimsA, actuallyStride: true) | |
let strideB = matrixStride(dims: dimsB, actuallyStride: _si2) | |
let strideC = matrixStride(dims: dimsC, actuallyStride: true) | |
if let shapeB = shape.B { | |
matrixOffsets = (0..<shapeB).map({ z in | |
SIMD4( | |
UInt64(z * strideA), UInt64(z * strideB), UInt64(z * strideC), 0) | |
} as (Int) -> SIMD4<UInt64>) | |
} else { | |
matrixOffsets = [] | |
} | |
func bufferBytes(dims: [Int]) -> Int { | |
dims.reduce(1, *) * (precision == .f32 ? 4 : 2) | |
} | |
func makeBuffer(dims: [Int]) -> MTLBuffer { | |
let bytes = bufferBytes(dims: dims) | |
return context.device.makeBuffer(length: bytes)! | |
} | |
self.bufferA = makeBuffer(dims: dimsA) | |
self.bufferB = makeBuffer(dims: dimsB) | |
self.bufferC_mps = makeBuffer(dims: dimsC) | |
self.bufferC_mfa = makeBuffer(dims: dimsC) | |
//func fillBuffer<T: BinaryFloatingPoint>( | |
// _ buffer: MTLBuffer, type: T.Type | |
//) where T.RawSignificand: FixedWidthInteger { | |
// let elements = buffer.length / MemoryLayout<T>.stride | |
// let pointer = buffer.contents().assumingMemoryBound(to: T.self) | |
// for i in 0..<elements { | |
// pointer[i] = T.random(in: 0..<1) | |
// } | |
//} | |
// | |
// The Swift compiler is broken. | |
// Also, this is super slow in debug mode. Instead, we create a repeating
// tape and `memcpy` numerous times until finishing the buffer. | |
func fillBuffer(_ buffer: MTLBuffer) { | |
let tapeLength = (precision == .f32) ? (15687 * 4) : (15687 * 2) | |
let tape = (precision == .f32) ? | |
Resources.randomTape32 : Resources.randomTape16 | |
var remaining = buffer.length | |
var cursor = buffer.contents() | |
while remaining > 0 { | |
defer { | |
remaining -= tapeLength | |
cursor += tapeLength | |
} | |
memcpy(cursor, tape, min(tapeLength, remaining)) | |
} | |
} | |
fillBuffer(bufferA) | |
fillBuffer(bufferB) | |
func makeTensorData( | |
_ buffer: MTLBuffer, dims: [Int], dataType: MPSDataType | |
) -> MPSGraphTensorData { | |
let nsShape = dims.map(NSNumber.init) | |
return MPSGraphTensorData(buffer, shape: nsShape, dataType: dataType) | |
} | |
// Wrap all three buffers in MPSGraph tensor-data objects, using the MPS
// output buffer for C so MFA and MPS results can be compared later.
let mpsDataType: MPSDataType = (precision == .f32) ? .float32 : .float16
self.tensorData = [
  makeTensorData(bufferA, dims: dimsA, dataType: mpsDataType),
  makeTensorData(bufferB, dims: dimsB, dataType: mpsDataType),
  makeTensorData(bufferC_mps, dims: dimsC, dataType: mpsDataType)
]
// Build a one-op graph (C = A x B) and compile it up front, so the
// benchmark loop measures dispatch + execution rather than compilation.
let mpsGraph = MPSGraph()
let nsShapeA = dimsA.map(NSNumber.init)
let nsShapeB = dimsB.map(NSNumber.init)
let nsShapeC = dimsC.map(NSNumber.init)
let tensorA = mpsGraph.placeholder(
  shape: nsShapeA, dataType: mpsDataType, name: "A")
let tensorB = mpsGraph.placeholder(
  shape: nsShapeB, dataType: mpsDataType, name: "B")
let tensorC = mpsGraph.matrixMultiplication(
  primary: tensorA, secondary: tensorB, name: "C")
let compileDesc = MPSGraphCompilationDescriptor()
compileDesc.optimizationLevel = .level1
// Block until compilation finishes so later timing is not polluted.
compileDesc.waitForCompilationCompletion = true
let feeds = [
  tensorA: MPSGraphShapedType(shape: nsShapeA, dataType: mpsDataType),
  tensorB: MPSGraphShapedType(shape: nsShapeB, dataType: mpsDataType)
]
let targetTensors = [tensorC]
self.executable = mpsGraph.compile(
  with: context.mpsGraphDevice, feeds: feeds,
  targetTensors: targetTensors, targetOperations: nil,
  compilationDescriptor: compileDesc)
// Bake the GEMM dimensions into the MFA kernels as function constants.
let constants = MTLFunctionConstantValues()
var _m = UInt32(shape.M)
var _n = UInt32(shape.N)
var _k = UInt32(shape.K)
constants.setConstantValue(&_m, type: .uint, index: 100)
constants.setConstantValue(&_n, type: .uint, index: 101)
constants.setConstantValue(&_k, type: .uint, index: 102)
// Function constants we must set because of a bug in the Metal compiler.
var _garbage = UInt64(0)
constants.setConstantValue(&_garbage, type: .ushort, index: 201)
constants.setConstantValue(&_garbage, type: .ushort, index: 211)
constants.setConstantValue(&_garbage, type: .ushort, index: 9000)
// Kernel names follow the pattern "<s|h>gemm_<block>[_batched]".
let prefix = (precision == .f32) ? "sgemm_" : "hgemm_"
if shape.B != nil {
  precondition(shape.B! > 1)
}
let suffix = (shape.B != nil) ? "_batched" : ""
// Either benchmark an ensemble of three block sizes, or one monolithic
// variant chosen globally.
var variants: [Variant]
if mfaEnsemble {
  variants = [.mfa16x16, .mfa32x32, .mfa48x48]
} else {
  variants = [Resources.monolithicVariant]
}
self.pipelines = [:]
self.gridSizes = [:]
self.pipelineNames = [:]
for variant in variants {
  // Threadgroup block shapes differ per precision (register pressure).
  var block: String
  switch precision {
  case .f32:
    switch variant {
    case .mps: fatalError()
    case .mfa16x16: block = "16x48"
    case .mfa32x32: block = "32x32"
    case .mfa48x48: block = "48x24"
    }
  case .f16:
    switch variant {
    case .mps: fatalError()
    case .mfa16x16: block = "16x64"
    case .mfa32x32: block = "32x32"
    case .mfa48x48: block = "48x32"
    }
  }
  let name = prefix + block + suffix
  self.pipelineNames[variant] = name
  let function = try! context.library.makeFunction(
    name: name, constantValues: constants)
  let pipeline = try! context.device.makeComputePipelineState(
    function: function)
  self.pipelines[variant] = pipeline
  // Integer ceiling division, for partial edge blocks.
  func ceilDivide(target: Int, granularity: Int) -> Int {
    (target + granularity - 1) / granularity
  }
  // The dispatch grid covers the output matrix in MN-square blocks; the
  // depth dimension covers the batch.
  var blockMN: Int
  switch variant {
  case .mps: fatalError()
  case .mfa16x16: blockMN = 16
  case .mfa32x32: blockMN = 32
  case .mfa48x48: blockMN = 48
  }
  let gridSize = MTLSize(
    width: ceilDivide(target: shape.N, granularity: blockMN),
    height: ceilDivide(target: shape.M, granularity: blockMN),
    depth: shape.B ?? 1)
  self.gridSizes[variant] = gridSize
}
} | |
// Which backend a profiling pass dispatches through.
enum DispatchType {
  case mps
  case mfa
  /// Short display name used in benchmark logs.
  func repr() -> String {
    self == .mps ? "MPS" : "MFA"
  }
}
typealias Variant = BlockSizeVariant
// Per-dispatch configuration: set via `prepareDispatch`, cleared via
// `resetDispatch` so a stale configuration crashes loudly (IUO).
private var commandQueue: MTLCommandQueue!
private var iterations: Int!
private var dispatchType: DispatchType!
// Aggregate data about latency between each dispatch type that's been run.
// Each trial will update the value in the dictionary.
private var latenciesDict: [Variant: Double] = [:]
// Serializes updates to `latenciesDict` (GPU completion handlers may run
// on arbitrary threads).
static let queue = DispatchQueue(
  label: "com.philipturner.CalculateDiffusion.Resources.queue")
// Best recorded MPS latency; traps if MPS was never profiled.
func mpsLatency() -> Double {
  latenciesDict[.mps]!
}
// Best recorded MFA latency across every block-size variant that has a
// measurement. The monolithic variant must have been profiled; the others
// are optional.
func mfaLatency() -> Double {
  var best = latenciesDict[Resources.monolithicVariant]!
  for candidate in BlockSizeVariant.allCases
  where candidate != .mps && candidate != Resources.monolithicVariant {
    if let measured = latenciesDict[candidate] {
      best = min(best, measured)
    }
  }
  return best
}
// Text suitable for presentation: "<latency> - <percent>% - <shape>".
func cleanRepr(variant: Variant) -> String {
  let measured = latenciesDict[variant]!
  // 2 * B * M * N * K FLOPs per GEMM.
  let floatOps = 2 * Double((shape.B ?? 1) * shape.M * shape.N * shape.K)
  let achievedFlops = floatOps / measured
  let percent = 100 * achievedFlops / Double(flopsForLatency)
  let pieces = [
    latencyRepr(measured),
    String(format: "%.1f", percent) + "%",
    shape.initRepr(),
  ]
  return pieces.joined(separator: " - ")
}
// Text you can copy and paste as Swift code. This requires merging the
// data with another `Resources` using the other precision. If you want to
// test only a subset of the matrix shapes, you must test that subset for
// F16, repeat for F32, then print the output.
func dataRepr(other: Resources) -> String {
  precondition(precision != other.precision)
  // Convert a latency into a fraction of the device's peak FLOPS.
  func proportionOfPeak(latency: Double) -> Float {
    let floatOps = 2 * Double((shape.B ?? 1) * shape.M * shape.N * shape.K)
    let flops = floatOps / latency
    return Float(flops / Double(flopsForLatency))
  }
  // Sort `self`/`other` by precision so the fields land correctly.
  let (f16One, f32One) = (precision == .f16) ? (self, other) : (other, self)
  let record = MatrixUtilization(
    mpsF16: proportionOfPeak(latency: f16One.mpsLatency()),
    mpsF32: proportionOfPeak(latency: f32One.mpsLatency()),
    mfaF16: proportionOfPeak(latency: f16One.mfaLatency()),
    mfaF32: proportionOfPeak(latency: f32One.mfaLatency()))
  return shape.initRepr() + ": " + record.initRepr() + ","
}
// Configure the next batch of `profile(sync:)` calls.
// - Parameters:
//   - iterations: Number of GEMM dispatches encoded per run.
//   - dispatchType: Whether to go through MPSGraph or the MFA kernels.
func prepareDispatch(
  iterations: Int,
  dispatchType: DispatchType
) {
  self.iterations = iterations
  self.dispatchType = dispatchType
}
// Ensure you don't forget to set `prepareDispatch` next time.
// Nil-ing the IUO properties makes any stale configuration crash loudly.
func resetDispatch() {
  self.iterations = nil
  self.dispatchType = nil
}
// MPS commands will be encoded like how Draw Things dispatches commands - a
// new command buffer per command. This may be the source of the latency
// bottleneck.
func profile(sync: Bool) {
  switch dispatchType! {
  case .mps:
    _dispatchMPS()
  case .mfa:
    // In ensemble mode, run every block size and only synchronize on the
    // last one; otherwise run the single monolithic variant.
    let schedule: [(Variant, Bool)] = mfaEnsemble
      ? [(.mfa16x16, false), (.mfa32x32, false), (.mfa48x48, sync)]
      : [(Resources.monolithicVariant, sync)]
    for (variant, shouldSync) in schedule {
      _dispatchMFAVariant(variant, sync: shouldSync)
    }
  }
}
// MPS is always synchronous because we can't force it all into one command
// buffer - sigh; the lack of flexibility MPS gives over encoding.
private func _dispatchMPS() {
  // Wall-clock timing: includes per-iteration command-buffer overhead,
  // matching how Draw Things dispatches MPS work.
  let start = CACurrentMediaTime()
  for i in 0..<iterations {
    let mpsCommandBuffer = MPSCommandBuffer(from: context.commandQueue)
    executable.encode(
      to: mpsCommandBuffer, inputs: inputs, results: results,
      executionDescriptor: nil)
    mpsCommandBuffer.commit()
    // Buffers on one queue complete in order, so waiting on the last one
    // waits for all of them.
    if i == iterations - 1 {
      mpsCommandBuffer.waitUntilCompleted()
    }
  }
  let end = CACurrentMediaTime()
  let time = (end - start) / Double(iterations)
  // Keep the fastest observed latency across trials.
  Resources.queue.sync {
    let previous = latenciesDict[.mps] ?? Double.infinity
    latenciesDict[.mps] = min(time, previous)
  }
}
// Encode all iterations of one MFA variant into a single command buffer,
// timing with the GPU's own start/end timestamps (not wall-clock time).
// - Parameters:
//   - variant: Which block-size pipeline to dispatch.
//   - sync: Whether to block until the command buffer completes.
private func _dispatchMFAVariant(_ variant: Variant, sync: Bool) {
  let commandBuffer = context.commandQueue.makeCommandBuffer()!
  let encoder = commandBuffer.makeComputeCommandEncoder()!
  // Fix: the loop counter was an unused `i` (compiler warning).
  for _ in 0..<iterations {
    encoder.setComputePipelineState(pipelines[variant]!)
    encoder.setBuffer(bufferA, offset: 0, index: 0)
    encoder.setBuffer(bufferB, offset: 0, index: 1)
    encoder.setBuffer(bufferC_mfa, offset: 0, index: 2)
    // Batched kernels additionally need the per-matrix byte offsets.
    if shape.B != nil && shape.B! > 1 {
      let elementSize = MemoryLayout<SIMD4<UInt64>>.stride
      let numBytes = matrixOffsets.count * elementSize
      encoder.setBytes(&matrixOffsets, length: numBytes, index: 3)
    }
    let gridSize = gridSizes[variant]!
    encoder.dispatchThreadgroups(
      gridSize, threadsPerThreadgroup: MTLSizeMake(128, 1, 1))
  }
  encoder.endEncoding()
  commandBuffer.addCompletedHandler { [self] commandBuffer in
    let start = commandBuffer.gpuStartTime
    let end = commandBuffer.gpuEndTime
    let time = (end - start) / Double(iterations)
    // Keep the fastest observed latency for this variant.
    Resources.queue.sync {
      let previous = latenciesDict[variant] ?? Double.infinity
      latenciesDict[variant] = min(time, previous)
    }
  }
  commandBuffer.commit()
  if sync {
    commandBuffer.waitUntilCompleted()
  }
}
// Check the outputs are the same using Euclidean distance.
// Widens both outputs to F32 scratch arrays, computes ||mfa - mps|| with
// BLAS, and prints diagnostics if the distance exceeds a size-scaled
// tolerance.
func validate() {
  let B = shape.B ?? 1
  let M = shape.M
  let N = shape.N
  let K = shape.K
  // Euclidean distance heuristic.
  let input1 = bufferC_mfa.contents()
  let input2 = bufferC_mps.contents()
  // Tolerance scales with output element count and accumulation depth K.
  var tolerance = Float(B * M * N) * sqrt(Float(K))
  if precision == .f32 {
    tolerance = max(0.001, 3e-7 * tolerance)
  } else {
    // Up the tolerance a little for FP16
    tolerance = max(0.01, 1e-2 * tolerance)
    // tolerance = max(0.01, 5e-3 * tolerance)
  }
  var bufferSize: Int
  var n: Int
  var x: UnsafeMutablePointer<Float>
  var y: UnsafeMutablePointer<Float>
  if precision == .f16 {
    let bufferSize_f16 = bufferC_mps.length
    let bufferSize_f32 = bufferSize_f16 * 2
    bufferSize = bufferSize_f32
    n = bufferSize_f32 / MemoryLayout<Float>.stride
    x = .allocate(capacity: n)
    y = .allocate(capacity: n)
    // Partially sourced from:
    // https://github.com/hollance/TensorFlow-iOS-Example/blob/master/VoiceMetal/VoiceMetal/Float16.swift
    func copy(dst: UnsafeMutableRawPointer, src: UnsafeMutableRawPointer) {
      var bufferFloat16 = vImage_Buffer(
        data: src, height: 1, width: UInt(n), rowBytes: bufferSize_f16)
      var bufferFloat32 = vImage_Buffer(
        data: dst, height: 1, width: UInt(n), rowBytes: bufferSize_f32)
      let error = vImageConvert_Planar16FtoPlanarF(
        &bufferFloat16, &bufferFloat32, 0)
      if error != kvImageNoError {
        fatalError(
          "Encountered error code \(error) while converting F16 to F32.")
      }
    }
    copy(dst: x, src: .init(mutating: input1))
    copy(dst: y, src: .init(mutating: input2))
  } else {
    bufferSize = bufferC_mps.length
    n = bufferSize / MemoryLayout<Float>.stride
    x = .allocate(capacity: n)
    y = .allocate(capacity: n)
    memcpy(x, input1, bufferSize)
    memcpy(y, input2, bufferSize)
  }
  defer {
    x.deallocate()
    y.deallocate()
  }
  // difference = x - y, via BLAS saxpy: y := a*x + y with a = -1.
  var difference = [Float](repeating: 0, count: B * M * N)
  memcpy(&difference, x, B * M * N * 4)
  var n_copy = Int32(B * M * N)
  var a = Float(-1)
  var inc = Int32(1)
  var inc_copy = inc
  // Find x + (-1 * y)
  saxpy_(&n_copy, &a, y, &inc, &difference, &inc_copy)
  // Find ||x - y||
  let distance = Float(snrm2_(&n_copy, &difference, &inc))
  let verbose = false
  guard distance >= tolerance else {
    // Success path: optionally spot-check a few random elements.
    if verbose {
      print("Size B=\(B), M=\(M), N=\(N), K=\(K) succeeded.")
      for _ in 0..<3 {
        let index = Int.random(in: 0..<B * M * N)
        var message = "Element \(index): "
        message += "MFA=\(x[index]), MPS=\(y[index])"
        print(message)
      }
    }
    return
  }
  // How many elements to print before stopping.
  let diagnosticElems = 16
  print()
  print("Size B=\(B), M=\(M), N=\(N), K=\(K) failed.")
  print(shape)
  print(matrixOffsets)
  print(gridSizes)
  // Fix: previously printed the raw Optional ("Optional(...)"/"nil") and
  // triggered the implicit Any coercion warning.
  print(pipelineNames[Resources.monolithicVariant] ?? "<no monolithic pipeline>")
  print("""
    Vectors did not match: Euclidean distance '\(distance)' > \
    tolerance '\(tolerance)'.
    """)
  var failedElements = 0
  for i in 0..<Int(B * M * N) {
    if abs(x[i] - y[i]) > 0.001 * distance {
      print("Element \(i): \(x[i]) - \(y[i]) = \(difference[i])")
      failedElements += 1
      if failedElements >= diagnosticElems {
        break
      }
    }
  }
  if failedElements == 0 {
    print("Could not find elements with enough separation.")
    for index in 0..<Int(diagnosticElems) {
      let i = (index == 0) ? 0 : Int.random(in: 0..<B * M * N)
      print("Element \(i): \(x[i]) - \(y[i]) = \(difference[i])")
    }
  }
}
// uniqueOperations[7] = MatrixShape(B: 2, M: 32, N: 32, K: 32, true)
let device = MTLCopyAllDevices().first!
let context = GPUContext(device: device)
// One `Resources` slot per unique GEMM shape, for each precision.
var f16ResourcesArray: [Resources?] = uniqueOperations.map { _ in nil }
var f32ResourcesArray: [Resources?] = uniqueOperations.map { _ in nil }
// It's faster to run this on single-core.
let numShapes = uniqueOperations.count
for opIndex in 0..<numShapes {
  let shape = uniqueOperations[opIndex]
  let mfaEnsemble = benchmarkGEMM_useEnsemble
  let f16Resources = Resources(
    context: context, shape: shape, mfaEnsemble: mfaEnsemble,
    precision: .f16)
  let f32Resources = Resources(
    context: context, shape: shape, mfaEnsemble: mfaEnsemble,
    precision: .f32)
  // print("Finished \(opIndex)")
  f16ResourcesArray[opIndex] = f16Resources
  f32ResourcesArray[opIndex] = f32Resources
}
// let testedOpStart = min(uniqueOperations.count, 0)
// let testedOpEnd = min(uniqueOperations.count, 10)
// let testedOpStart = min(uniqueOperations.count, 0)
// let testedOpEnd = min(uniqueOperations.count, 10)
// Sub-range of shapes to benchmark; defaults to all of them.
let testedOpStart = 0
let testedOpEnd = uniqueOperations.count
// Run `trials` profiling passes of `iterations` dispatches each, for every
// tested shape in both precisions, only synchronizing on the final trial.
func profile(
  iterations: Int, trials: Int,
  type: Resources.DispatchType, logProgress: Bool = false
) {
  if logProgress {
    print()
    print("Running benchmark for \(type.repr()).")
  }
  for resourcesArray in [f16ResourcesArray, f32ResourcesArray] {
    for shapeIndex in testedOpStart..<testedOpEnd {
      let resources = resourcesArray[shapeIndex]!
      resources.prepareDispatch(iterations: iterations, dispatchType: type)
      for trial in 0..<trials {
        resources.profile(sync: trial == trials - 1)
      }
      resources.resetDispatch()
      guard logProgress else { continue }
      let kernelName = (resources.precision == .f32) ? "SGEMM" : "HGEMM"
      print("Shapes finished for \(kernelName): \(shapeIndex + 1)")
    }
  }
}
// Compare MFA against MPS outputs for every tested shape and precision.
func validate() {
  for resourcesArray in [f16ResourcesArray, f32ResourcesArray] {
    for shapeIndex in testedOpStart..<testedOpEnd {
      resourcesArray[shapeIndex]!.validate()
    }
  }
}
// Warm-up passes (one dispatch each) plus a correctness check before any
// timing-sensitive runs.
profile(iterations: 1, trials: 1, type: .mps)
profile(iterations: 1, trials: 1, type: .mfa)
validate()
// Timed MFA runs with increasing iteration counts.
profile(iterations: 5, trials: 1, type: .mfa, logProgress: true)
profile(iterations: 10, trials: 1, type: .mfa, logProgress: true)
profile(iterations: 20, trials: 4, type: .mfa, logProgress: true)
// Run MPS after MFA warmed up the GPU, otherwise your estimate will be biased
// toward "MFA is faster than MPS".
profile(iterations: 5, trials: 1, type: .mps, logProgress: true)
profile(iterations: 10, trials: 1, type: .mps, logProgress: true)
profile(iterations: 20, trials: 4, type: .mps, logProgress: true)
validate()
// print(uniqueOperations.firstIndex(where: { $0.B == 2 && $0.M == 4096 && $0.N == 320 && $0.K == 320 }))
typealias Variant = Resources.Variant
#if true
// Human-readable per-variant utilization tables for HGEMM and SGEMM.
var variants: [Variant]
if benchmarkGEMM_useEnsemble {
  variants = Array(Variant.allCases)
} else {
  variants = [Variant.mps, Resources.monolithicVariant]
}
for variant in variants {
  print()
  print("Performance of variant: \(variant.repr())")
  print("HGEMM:")
  for i in 0..<uniqueOperations.count {
    let f16 = f16ResourcesArray[i]!
    print(" \(f16.cleanRepr(variant: variant))")
  }
  print("SGEMM:")
  for i in 0..<uniqueOperations.count {
    let f32 = f32ResourcesArray[i]!
    print(" \(f32.cleanRepr(variant: variant))")
  }
}
#endif
#if true
// Copy-pasteable Swift dictionary entries, one per shape.
print()
print("Matrix utilizations for usage inside Swift code:")
for i in 0..<uniqueOperations.count {
  let f16 = f16ResourcesArray[i]!
  let f32 = f32ResourcesArray[i]!
  print(f16.dataRepr(other: f32))
}
#endif
} | |
// Returns the recorded GEMM utilization table for the requested MFA
// configuration: one monolithic block size, or the three-kernel ensemble.
func getMatrixSpeeds(monolithic: Bool) -> [MatrixShape: MatrixUtilization] {
  monolithic ? _getMatrixSpeedsMonolithic() : _getMatrixSpeedsEnsemble()
}
// NOTE: Switch to Swift release mode and disable Metal API validation
// before generating these benchmarks, otherwise they will be biased in favor
// of MFA being faster than MPS.
// Hard-coded benchmark results: fraction of peak FLOPS achieved by MPS and
// MFA (single monolithic block size), keyed by GEMM shape.
func _getMatrixSpeedsMonolithic() -> [MatrixShape: MatrixUtilization] {
  return [
    MatrixShape(B: nil, M: 2, N: 1280, K: 320, false): .init(
      mpsF16: 0.002, mpsF32: 0.002, mfaF16: 0.017, mfaF32: 0.014),
    MatrixShape(B: nil, M: 2, N: 1280, K: 1280, false): .init(
      mpsF16: 0.008, mpsF32: 0.008, mfaF16: 0.021, mfaF32: 0.018),
    MatrixShape(B: nil, M: 1805, N: 320, K: 768, false): .init(
      mpsF16: 0.406, mpsF32: 0.425, mfaF16: 0.749, mfaF32: 0.629),
    MatrixShape(B: nil, M: 1805, N: 640, K: 768, false): .init(
      mpsF16: 0.496, mpsF32: 0.536, mfaF16: 0.783, mfaF32: 0.649),
    MatrixShape(B: nil, M: 1805, N: 1280, K: 768, false): .init(
      mpsF16: 0.577, mpsF32: 0.639, mfaF16: 0.806, mfaF32: 0.661),
    MatrixShape(B: nil, M: 2, N: 320, K: 1280, false): .init(
      mpsF16: 0.002, mpsF32: 0.002, mfaF16: 0.005, mfaF32: 0.004),
    MatrixShape(B: nil, M: 2, N: 640, K: 1280, false): .init(
      mpsF16: 0.004, mpsF32: 0.004, mfaF16: 0.010, mfaF32: 0.009),
    MatrixShape(B: 2, M: 4096, N: 320, K: 320, false): .init(
      mpsF16: 0.573, mpsF32: 0.571, mfaF16: 0.802, mfaF32: 0.690),
    MatrixShape(B: nil, M: 4096, N: 4096, K: 40, false): .init(
      mpsF16: 0.304, mpsF32: 0.344, mfaF16: 0.615, mfaF32: 0.500),
    MatrixShape(B: nil, M: 4096, N: 40, K: 4096, false): .init(
      mpsF16: 0.089, mpsF32: 0.089, mfaF16: 0.495, mfaF32: 0.351),
    MatrixShape(B: 8, M: 4096, N: 1713, K: 40, true): .init(
      mpsF16: 0.323, mpsF32: 0.352, mfaF16: 0.528, mfaF32: 0.399),
    MatrixShape(B: 8, M: 4096, N: 40, K: 1713, true): .init(
      mpsF16: 0.168, mpsF32: 0.161, mfaF16: 0.585, mfaF32: 0.412),
    MatrixShape(B: 8, M: 4096, N: 92, K: 40, true): .init(
      mpsF16: 0.116, mpsF32: 0.111, mfaF16: 0.437, mfaF32: 0.346),
    MatrixShape(B: 8, M: 4096, N: 40, K: 92, true): .init(
      mpsF16: 0.117, mpsF32: 0.118, mfaF16: 0.394, mfaF32: 0.310),
    MatrixShape(B: 2, M: 4096, N: 1280, K: 320, false): .init(
      mpsF16: 0.672, mpsF32: 0.703, mfaF16: 0.822, mfaF32: 0.673),
    MatrixShape(B: 2, M: 4096, N: 320, K: 1280, false): .init(
      mpsF16: 0.695, mpsF32: 0.708, mfaF16: 0.830, mfaF32: 0.686),
    MatrixShape(B: 2, M: 1024, N: 640, K: 640, false): .init(
      mpsF16: 0.572, mpsF32: 0.578, mfaF16: 0.804, mfaF32: 0.697),
    MatrixShape(B: nil, M: 1024, N: 1024, K: 80, false): .init(
      mpsF16: 0.197, mpsF32: 0.180, mfaF16: 0.544, mfaF32: 0.471),
    MatrixShape(B: nil, M: 1024, N: 80, K: 1024, false): .init(
      mpsF16: 0.105, mpsF32: 0.094, mfaF16: 0.427, mfaF32: 0.398),
    MatrixShape(B: 8, M: 1024, N: 1713, K: 80, true): .init(
      mpsF16: 0.476, mpsF32: 0.540, mfaF16: 0.639, mfaF32: 0.525),
    MatrixShape(B: 8, M: 1024, N: 80, K: 1713, true): .init(
      mpsF16: 0.129, mpsF32: 0.112, mfaF16: 0.671, mfaF32: 0.513),
    MatrixShape(B: 8, M: 1024, N: 92, K: 80, true): .init(
      mpsF16: 0.126, mpsF32: 0.116, mfaF16: 0.439, mfaF32: 0.377),
    MatrixShape(B: 8, M: 1024, N: 80, K: 92, true): .init(
      mpsF16: 0.094, mpsF32: 0.073, mfaF16: 0.409, mfaF32: 0.357),
    MatrixShape(B: 2, M: 1024, N: 2560, K: 640, false): .init(
      mpsF16: 0.692, mpsF32: 0.709, mfaF16: 0.832, mfaF32: 0.686),
    MatrixShape(B: 2, M: 1024, N: 640, K: 2560, false): .init(
      mpsF16: 0.693, mpsF32: 0.686, mfaF16: 0.823, mfaF32: 0.681),
    MatrixShape(B: 2, M: 256, N: 1280, K: 1280, false): .init(
      mpsF16: 0.582, mpsF32: 0.528, mfaF16: 0.804, mfaF32: 0.713),
    MatrixShape(B: nil, M: 256, N: 256, K: 160, false): .init(
      mpsF16: 0.025, mpsF32: 0.025, mfaF16: 0.255, mfaF32: 0.213),
    MatrixShape(B: nil, M: 256, N: 160, K: 256, false): .init(
      mpsF16: 0.028, mpsF32: 0.026, mfaF16: 0.185, mfaF32: 0.163),
    MatrixShape(B: 8, M: 256, N: 1713, K: 160, true): .init(
      mpsF16: 0.450, mpsF32: 0.492, mfaF16: 0.721, mfaF32: 0.578),
    MatrixShape(B: 8, M: 256, N: 160, K: 1713, true): .init(
      mpsF16: 0.185, mpsF32: 0.154, mfaF16: 0.751, mfaF32: 0.533),
    MatrixShape(B: 8, M: 256, N: 92, K: 160, true): .init(
      mpsF16: 0.075, mpsF32: 0.076, mfaF16: 0.445, mfaF32: 0.390),
    MatrixShape(B: 8, M: 256, N: 160, K: 92, true): .init(
      mpsF16: 0.072, mpsF32: 0.069, mfaF16: 0.408, mfaF32: 0.316),
    MatrixShape(B: 2, M: 256, N: 5120, K: 1280, false): .init(
      mpsF16: 0.694, mpsF32: 0.698, mfaF16: 0.830, mfaF32: 0.671),
    MatrixShape(B: 2, M: 256, N: 1280, K: 5120, false): .init(
      mpsF16: 0.686, mpsF32: 0.574, mfaF16: 0.813, mfaF32: 0.638),
    MatrixShape(B: 2, M: 64, N: 1280, K: 1280, false): .init(
      mpsF16: 0.349, mpsF32: 0.341, mfaF16: 0.659, mfaF32: 0.612),
    MatrixShape(B: nil, M: 64, N: 64, K: 160, false): .init(
      mpsF16: 0.002, mpsF32: 0.001, mfaF16: 0.018, mfaF32: 0.017),
    MatrixShape(B: nil, M: 64, N: 160, K: 64, false): .init(
      mpsF16: 0.002, mpsF32: 0.002, mfaF16: 0.031, mfaF32: 0.029),
    MatrixShape(B: 8, M: 64, N: 1713, K: 160, true): .init(
      mpsF16: 0.293, mpsF32: 0.300, mfaF16: 0.608, mfaF32: 0.505),
    MatrixShape(B: 8, M: 64, N: 160, K: 1713, true): .init(
      mpsF16: 0.241, mpsF32: 0.230, mfaF16: 0.409, mfaF32: 0.339),
    MatrixShape(B: 8, M: 64, N: 92, K: 160, true): .init(
      mpsF16: 0.020, mpsF32: 0.021, mfaF16: 0.160, mfaF32: 0.143),
    MatrixShape(B: 8, M: 64, N: 160, K: 92, true): .init(
      mpsF16: 0.019, mpsF32: 0.020, mfaF16: 0.195, mfaF32: 0.169),
    MatrixShape(B: 2, M: 64, N: 5120, K: 1280, false): .init(
      mpsF16: 0.573, mpsF32: 0.508, mfaF16: 0.795, mfaF32: 0.611),
    MatrixShape(B: 2, M: 64, N: 1280, K: 5120, false): .init(
      mpsF16: 0.477, mpsF32: 0.463, mfaF16: 0.683, mfaF32: 0.631),
  ]
}
// Hard-coded benchmark results: fraction of peak FLOPS achieved by MPS and
// MFA (three-kernel block-size ensemble), keyed by GEMM shape. Same
// measurement caveats as `_getMatrixSpeedsMonolithic`.
func _getMatrixSpeedsEnsemble() -> [MatrixShape: MatrixUtilization] {
  return [
    MatrixShape(B: nil, M: 2, N: 1280, K: 320, false): .init(
      mpsF16: 0.002, mpsF32: 0.002, mfaF16: 0.022, mfaF32: 0.018),
    MatrixShape(B: nil, M: 2, N: 1280, K: 1280, false): .init(
      mpsF16: 0.008, mpsF32: 0.009, mfaF16: 0.035, mfaF32: 0.024),
    MatrixShape(B: nil, M: 1805, N: 320, K: 768, false): .init(
      mpsF16: 0.400, mpsF32: 0.425, mfaF16: 0.659, mfaF32: 0.606),
    MatrixShape(B: nil, M: 1805, N: 640, K: 768, false): .init(
      mpsF16: 0.495, mpsF32: 0.537, mfaF16: 0.727, mfaF32: 0.691),
    MatrixShape(B: nil, M: 1805, N: 1280, K: 768, false): .init(
      mpsF16: 0.580, mpsF32: 0.638, mfaF16: 0.812, mfaF32: 0.729),
    MatrixShape(B: nil, M: 2, N: 320, K: 1280, false): .init(
      mpsF16: 0.002, mpsF32: 0.002, mfaF16: 0.010, mfaF32: 0.007),
    MatrixShape(B: nil, M: 2, N: 640, K: 1280, false): .init(
      mpsF16: 0.004, mpsF32: 0.004, mfaF16: 0.019, mfaF32: 0.014),
    MatrixShape(B: 2, M: 4096, N: 320, K: 320, false): .init(
      mpsF16: 0.580, mpsF32: 0.578, mfaF16: 0.803, mfaF32: 0.695),
    MatrixShape(B: nil, M: 4096, N: 4096, K: 40, false): .init(
      mpsF16: 0.305, mpsF32: 0.342, mfaF16: 0.692, mfaF32: 0.579),
    MatrixShape(B: nil, M: 4096, N: 40, K: 4096, false): .init(
      mpsF16: 0.089, mpsF32: 0.090, mfaF16: 0.496, mfaF32: 0.414),
    MatrixShape(B: 8, M: 4096, N: 1713, K: 40, true): .init(
      mpsF16: 0.323, mpsF32: 0.352, mfaF16: 0.662, mfaF32: 0.544),
    MatrixShape(B: 8, M: 4096, N: 40, K: 1713, true): .init(
      mpsF16: 0.167, mpsF32: 0.161, mfaF16: 0.643, mfaF32: 0.565),
    MatrixShape(B: 8, M: 4096, N: 92, K: 40, true): .init(
      mpsF16: 0.115, mpsF32: 0.113, mfaF16: 0.491, mfaF32: 0.445),
    MatrixShape(B: 8, M: 4096, N: 40, K: 92, true): .init(
      mpsF16: 0.116, mpsF32: 0.118, mfaF16: 0.520, mfaF32: 0.428),
    MatrixShape(B: 2, M: 4096, N: 1280, K: 320, false): .init(
      mpsF16: 0.674, mpsF32: 0.703, mfaF16: 0.849, mfaF32: 0.740),
    MatrixShape(B: 2, M: 4096, N: 320, K: 1280, false): .init(
      mpsF16: 0.700, mpsF32: 0.710, mfaF16: 0.830, mfaF32: 0.720),
    MatrixShape(B: 2, M: 1024, N: 640, K: 640, false): .init(
      mpsF16: 0.576, mpsF32: 0.584, mfaF16: 0.804, mfaF32: 0.697),
    MatrixShape(B: nil, M: 1024, N: 1024, K: 80, false): .init(
      mpsF16: 0.202, mpsF32: 0.210, mfaF16: 0.569, mfaF32: 0.485),
    MatrixShape(B: nil, M: 1024, N: 80, K: 1024, false): .init(
      mpsF16: 0.107, mpsF32: 0.096, mfaF16: 0.530, mfaF32: 0.396),
    MatrixShape(B: 8, M: 1024, N: 1713, K: 80, true): .init(
      mpsF16: 0.472, mpsF32: 0.549, mfaF16: 0.748, mfaF32: 0.612),
    MatrixShape(B: 8, M: 1024, N: 80, K: 1713, true): .init(
      mpsF16: 0.129, mpsF32: 0.112, mfaF16: 0.668, mfaF32: 0.577),
    MatrixShape(B: 8, M: 1024, N: 92, K: 80, true): .init(
      mpsF16: 0.116, mpsF32: 0.117, mfaF16: 0.473, mfaF32: 0.433),
    MatrixShape(B: 8, M: 1024, N: 80, K: 92, true): .init(
      mpsF16: 0.093, mpsF32: 0.075, mfaF16: 0.431, mfaF32: 0.365),
    MatrixShape(B: 2, M: 1024, N: 2560, K: 640, false): .init(
      mpsF16: 0.693, mpsF32: 0.708, mfaF16: 0.848, mfaF32: 0.738),
    MatrixShape(B: 2, M: 1024, N: 640, K: 2560, false): .init(
      mpsF16: 0.691, mpsF32: 0.685, mfaF16: 0.822, mfaF32: 0.682),
    MatrixShape(B: 2, M: 256, N: 1280, K: 1280, false): .init(
      mpsF16: 0.579, mpsF32: 0.526, mfaF16: 0.804, mfaF32: 0.712),
    MatrixShape(B: nil, M: 256, N: 256, K: 160, false): .init(
      mpsF16: 0.025, mpsF32: 0.027, mfaF16: 0.300, mfaF32: 0.242),
    MatrixShape(B: nil, M: 256, N: 160, K: 256, false): .init(
      mpsF16: 0.025, mpsF32: 0.029, mfaF16: 0.287, mfaF32: 0.234),
    MatrixShape(B: 8, M: 256, N: 1713, K: 160, true): .init(
      mpsF16: 0.452, mpsF32: 0.482, mfaF16: 0.750, mfaF32: 0.633),
    MatrixShape(B: 8, M: 256, N: 160, K: 1713, true): .init(
      mpsF16: 0.181, mpsF32: 0.157, mfaF16: 0.752, mfaF32: 0.536),
    MatrixShape(B: 8, M: 256, N: 92, K: 160, true): .init(
      mpsF16: 0.079, mpsF32: 0.071, mfaF16: 0.392, mfaF32: 0.337),
    MatrixShape(B: 8, M: 256, N: 160, K: 92, true): .init(
      mpsF16: 0.076, mpsF32: 0.076, mfaF16: 0.411, mfaF32: 0.316),
    MatrixShape(B: 2, M: 256, N: 5120, K: 1280, false): .init(
      mpsF16: 0.697, mpsF32: 0.691, mfaF16: 0.830, mfaF32: 0.717),
    MatrixShape(B: 2, M: 256, N: 1280, K: 5120, false): .init(
      mpsF16: 0.686, mpsF32: 0.580, mfaF16: 0.813, mfaF32: 0.636),
    MatrixShape(B: 2, M: 64, N: 1280, K: 1280, false): .init(
      mpsF16: 0.345, mpsF32: 0.335, mfaF16: 0.660, mfaF32: 0.614),
    MatrixShape(B: nil, M: 64, N: 64, K: 160, false): .init(
      mpsF16: 0.002, mpsF32: 0.002, mfaF16: 0.030, mfaF32: 0.026),
    MatrixShape(B: nil, M: 64, N: 160, K: 64, false): .init(
      mpsF16: 0.002, mpsF32: 0.002, mfaF16: 0.044, mfaF32: 0.035),
    MatrixShape(B: 8, M: 64, N: 1713, K: 160, true): .init(
      mpsF16: 0.285, mpsF32: 0.289, mfaF16: 0.588, mfaF32: 0.469),
    MatrixShape(B: 8, M: 64, N: 160, K: 1713, true): .init(
      mpsF16: 0.233, mpsF32: 0.234, mfaF16: 0.430, mfaF32: 0.339),
    MatrixShape(B: 8, M: 64, N: 92, K: 160, true): .init(
      mpsF16: 0.019, mpsF32: 0.020, mfaF16: 0.194, mfaF32: 0.151),
    MatrixShape(B: 8, M: 64, N: 160, K: 92, true): .init(
      mpsF16: 0.020, mpsF32: 0.021, mfaF16: 0.197, mfaF32: 0.169),
    MatrixShape(B: 2, M: 64, N: 5120, K: 1280, false): .init(
      mpsF16: 0.576, mpsF32: 0.510, mfaF16: 0.794, mfaF32: 0.611),
    MatrixShape(B: 2, M: 64, N: 1280, K: 5120, false): .init(
      mpsF16: 0.475, mpsF32: 0.471, mfaF16: 0.684, mfaF32: 0.631),
  ]
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I was using it to predict the latency of Stable Diffusion inference in Draw Things, last summer. I got results that were within a factor of 2 of the actual latency, which was amazing. It allowed me to get detailed, low-level insights into exactly where the performance bottlenecks were. And, to predict which optimizations would be worth my time implementing.
For example:
The data used to parameterize this (mostly the dimensions of each matrix in SDv1.4) were from NNC logs. I kept those gists private (because I was actually using this as the test case), but here is a snippet. You'd have to find a similar mechanism in whatever ML framework you are using, to understand which problem sizes to get GFLOPS for.
To build theoretical models of execution speed (rooflines), you need to know the hardware characteristics. Whether an operation is compute bound or memory bound, and how much compute/memory the profiled device has. GEMM is very easy to predict, you have M x N x K instructions (2 x M x N x K FLOPs). Divide that by a billion times the manufacturer's reported GFLOPS, and you have a semi-quantitative estimate of how long the GEMM takes. That's mostly what this script is about, just with something much more complex than a single GEMM. You would likely build your own script from scratch, by studying the code structure and algorithms in this GitHub Gist.