Last active
August 1, 2023 14:47
-
-
Save philipturner/939d4ffda26e66f10a142c82d8d498e9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//
//  ContentView.swift
//  SIMDFuturesA15
//
//  Created by Philip Turner on 6/9/23.
//
import SwiftUI | |
import Foundation | |
import QuartzCore | |
import MetalPerformanceShadersGraph | |
import Accelerate | |
/// Root view of the benchmark app.
///
/// Evaluating `body` has a deliberate side effect: the closure that produces
/// the label text runs `testMFALibrary()` and then terminates the process
/// with `exit(0)`. The closure therefore never produces a string and the
/// label is never displayed.
struct ContentView: View {
  var body: some View {
    VStack {
      Image(systemName: "globe")
        .imageScale(.large)
        .foregroundStyle(.tint)
      Text({ () -> String in
        // Run every benchmark, then quit. `exit(0)` returns `Never`, so no
        // `return` statement is needed (or reachable) after it.
        testMFALibrary()
        exit(0)
      }())
    }
    .padding()
  }
}
// Load and profile the MFA library, compare against MPSGraph for square
// matrices of each precision. Report the number of GFLOPS achieved.
//
// The values below are only defaults; `testMFALibrary()` overwrites
// `matrixDim` (and, for its largest problem, `numTrials` and
// `numRepetitionsANE`) before each call to `profile(...)`.

/// Side length of the square GEMM (M = N = K = `matrixDim`).
var matrixDim: Int = 2048

/// Number of timed dispatches per benchmark; the minimum latency is reported.
var numTrials: Int = 5

/// Number of back-to-back MPSGraph executions averaged into one timing sample
/// when `profile(...)` is called with `forceRunSync: true`.
var numRepetitionsANE: Int = 5

/// Enables per-trial latencies and extra diagnostic output.
let verbose = false
/// Runs the full benchmark sweep.
///
/// For each problem size this sets the global `matrixDim`, prints a header,
/// and calls `profile(...)` once per configuration. The random generator is
/// re-seeded with `srand48(0)` before every `profile` call so each run
/// starts from the same `drand48()` sequence.
///
/// Sizes 1024 and 2048 are profiled in three configurations (f32, f16, and
/// f16 with `useANE`/`forceRunSync`); every other size is profiled only in the
/// last configuration. The 4096 case lowers `numTrials` and
/// `numRepetitionsANE` before running.
func testMFALibrary() {
  typealias Config = (useHalf: Bool, useANE: Bool, forceRunSync: Bool)

  let aneOnly: [Config] = [
    (useHalf: true, useANE: true, forceRunSync: true)
  ]
  let allPrecisions: [Config] = [
    (useHalf: false, useANE: false, forceRunSync: false),
    (useHalf: true, useANE: false, forceRunSync: false),
    (useHalf: true, useANE: true, forceRunSync: true)
  ]

  // Profiles one problem size under each of the given configurations.
  func run(dim: Int, configs: [Config]) {
    matrixDim = dim
    print()
    print("GEMM dimensions: \(matrixDim)x\(matrixDim)x\(matrixDim)")
    for config in configs {
      srand48(0)
      profile(
        useHalf: config.useHalf,
        useANE: config.useANE,
        forceRunSync: config.forceRunSync)
    }
  }

  let sweep: [(dim: Int, configs: [Config])] = [
    (dim: 256, configs: aneOnly),
    (dim: 512, configs: aneOnly),
    (dim: 768, configs: aneOnly),
    (dim: 1024, configs: allPrecisions),
    (dim: 1024 + 256, configs: aneOnly),
    (dim: 1024 + 512, configs: aneOnly),
    (dim: 1024 + 512 + 256, configs: aneOnly),
    (dim: 2048, configs: allPrecisions)
  ]
  for entry in sweep {
    run(dim: entry.dim, configs: entry.configs)
  }

  // The largest problem uses fewer trials and repetitions.
  numTrials = 3
  numRepetitionsANE = 2
  run(dim: 4096, configs: aneOnly)
}
/// Benchmarks one square GEMM of side `matrixDim` with the Metal
/// FlashAttention (MFA) kernel and with MPSGraph, prints the GFLOPS reached by
/// each, and prints the Euclidean distance between the two output matrices.
///
/// - Parameters:
///   - useHalf: When `true`, buffers hold `Float16` and the `hgemm` kernel is
///     used; otherwise buffers hold `Float32` and `sgemm` is used.
///   - useANE: Selects MPSGraph optimization level `.level1` instead of
///     `.level0`, and changes the label printed for the MPS result.
///   - forceRunSync: When `true`, MPSGraph is timed on the CPU with
///     `CACurrentMediaTime()` around `numRepetitionsANE` executions and the
///     average is recorded. When `false`, the graph is encoded into a command
///     buffer and its GPU start/end timestamps are recorded.
///
/// Reads the globals `matrixDim`, `numTrials`, `numRepetitionsANE` and
/// `verbose`. Traps if the Metal device, the bundled
/// `libMetalFlashAttention.metallib`, or any buffer cannot be created.
func profile(useHalf: Bool, useANE: Bool, forceRunSync: Bool) {
  enum Variant {
    case blocked // fallback to ensure this works
    case streamk
  }
  let variant = Variant.streamk

  let device = MTLCreateSystemDefaultDevice()!
  let commandQueue = device.makeCommandQueue()!
  let libraryURL = Bundle.main.url(
    forResource: "libMetalFlashAttention", withExtension: "metallib")!
  let library = try! device.makeLibrary(URL: libraryURL)

  // MARK: - Universal Constants

  let constants = MTLFunctionConstantValues()
  var device_architecture: UInt16 = 1
  var device_generation: UInt16 = 8
  var device_cores: UInt16 = 5
  var atomic_algorithm: UInt16 = 1 // Nanite
  constants.setConstantValue(&device_architecture, type: .ushort, index: 1)
  constants.setConstantValue(&device_generation, type: .ushort, index: 2)
  constants.setConstantValue(&device_cores, type: .ushort, index: 3)
  constants.setConstantValue(&atomic_algorithm, type: .ushort, index: 4)

  // MARK: - Block Sizes

  // Compiled with 64x64x16, 2x2 simds/threadgroup.
  // The same value is written to function constants 100, 101 and 102 because
  // the problem is square.
  var MNK: UInt32 = .init(matrixDim)
  constants.setConstantValue(&MNK, type: .uint, index: 100)
  constants.setConstantValue(&MNK, type: .uint, index: 101)
  constants.setConstantValue(&MNK, type: .uint, index: 102)

  // MARK: - Internal

  // For some reason, newer versions of Xcode require this to be set.
  var _simdgroup_use_threadgroup: Bool = true
  constants.setConstantValue(
    &_simdgroup_use_threadgroup, type: .bool, index: 300)

  // This should eventually be removed, so the user doesn't have to access the
  // private API.
  var simds_per_threadgroup: SIMD2<UInt16> = .init(2, 2)
  constants.setConstantValue(&simds_per_threadgroup, type: .ushort2, index: 301)

  var tg_size_mn: UInt16 = 4
  switch variant {
  case .blocked:
    constants.setConstantValue(&tg_size_mn, type: .ushort, index: 200)
  case .streamk:
    break
  }

  // MARK: - PSO Creation

  let name: String
  switch variant {
  case .blocked:
    name = useHalf ? "hgemm_blocked" : "sgemm_blocked"
  case .streamk:
    name = useHalf ? "hgemm" : "sgemm"
  }
  let function = try! library.makeFunction(
    name: name, constantValues: constants)
  let pipeline = try! device.makeComputePipelineState(function: function)

  // Allocate three matrices and fill each of them with values from
  // `drand48()`. Indices 0 and 1 are later copied into the MPSGraph inputs and
  // index 2 is read back as the MFA output.
  let elementSize = useHalf ? 2 : 4
  let matrixSize = matrixDim * matrixDim * elementSize
  var matrices: [MTLBuffer] = []
  for _ in 0..<3 {
    let buffer = device.makeBuffer(length: matrixSize)!
    if useHalf {
      let casted = buffer.contents().assumingMemoryBound(to: Float16.self)
      matrixInitRandom(
        matrix: casted, dim1: .init(matrixDim), dim2: .init(matrixDim)
      ) {
        Float16($0)
      }
    } else {
      let casted = buffer.contents().assumingMemoryBound(to: Float32.self)
      matrixInitRandom(
        matrix: casted, dim1: .init(matrixDim), dim2: .init(matrixDim)
      ) {
        Float32($0)
      }
    }
    matrices.append(buffer)
  }

  // Profile GFLOPS on MFA.
  let numThreadgroups: MTLSize
  let threadgroupSize: MTLSize
  switch variant {
  case .blocked:
    let tg_width: Int = 16
    let execution_width = Int(tg_size_mn) * tg_width
    let grid_width = matrixDim / execution_width
    numThreadgroups = MTLSizeMake(grid_width, grid_width, 1)
    threadgroupSize = MTLSizeMake(tg_width, tg_width, 1)
  case .streamk:
    // Integer division: the grid only covers the whole matrix when
    // `matrixDim` is a multiple of 64.
    numThreadgroups = MTLSizeMake((matrixDim / 64), (matrixDim / 64), 1)
    threadgroupSize = MTLSizeMake(128, 1, 1)
  }

  let precisionRepr = useHalf ? "f16" : "f32"
  let precisionReprMPS = useHalf ? (
    useANE ? "f16 (either ANE or GPU tensor cores)" : "f16") : "f32"

  // Commit every trial first, then wait on each, so the command buffers are
  // all enqueued before the CPU blocks.
  var commandBuffers: [MTLCommandBuffer] = []
  for _ in 0..<numTrials {
    let commandBuffer = commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(matrices[0], offset: 0, index: 0)
    encoder.setBuffer(matrices[1], offset: 0, index: 1)
    encoder.setBuffer(matrices[2], offset: 0, index: 2)
    encoder.dispatchThreadgroups(
      numThreadgroups,
      threadsPerThreadgroup: threadgroupSize)
    encoder.endEncoding()
    commandBuffer.commit()
    commandBuffers.append(commandBuffer)
  }

  print()
  print("Metal FlashAttention: '\(precisionRepr)'")
  if verbose {
    print("Variant: '\(name)'")
  }
  var minLatency: Double = 1000
  for commandBuffer in commandBuffers {
    commandBuffer.waitUntilCompleted()
    let latencySeconds = commandBuffer.gpuEndTime - commandBuffer.gpuStartTime
    let latencyMicroseconds = Int(latencySeconds / 1e-6)
    if verbose {
      print("us: \(latencyMicroseconds)")
    }
    minLatency = min(latencySeconds, minLatency)
  }
  let floatOps = matrixDim * matrixDim * matrixDim * 2
  let flops = Double(floatOps) / Double(minLatency)
  let gflops = Int(flops / 1e9)
  print("GFLOPS: \(gflops)")

  // Snapshot the MFA output so it can be compared against MPSGraph later.
  let outputsCopy = malloc(matrixSize)!
  defer { free(outputsCopy) }
  memcpy(outputsCopy, matrices[2].contents(), matrixSize)

  // Prints two sample elements of the MFA output snapshot. Index 65536 is only
  // read when the matrix actually has more than 65536 elements; for
  // `matrixDim == 256` the buffer holds exactly 65536 elements and that index
  // would be one past the end.
  func printMFAOutputSample() {
    let elementCount = matrixSize / elementSize
    let sampleIndex = 65536
    if useHalf {
      let ptr = outputsCopy.assumingMemoryBound(to: Float16.self)
      print("First element of MFA output is currently: \(ptr[0])")
      if sampleIndex < elementCount {
        print(
          "65536th element of MFA output is currently: \(ptr[sampleIndex])")
      }
    } else {
      let ptr = outputsCopy.assumingMemoryBound(to: Float32.self)
      print("First element of MFA output is currently: \(ptr[0])")
      if sampleIndex < elementCount {
        print(
          "65536th element of MFA output is currently: \(ptr[sampleIndex])")
      }
    }
  }
  if verbose {
    printMFAOutputSample()
  }

  // Now, profile against MPSGraph and compare correctness.

  // MARK: - GEMM_initialize

  let graph = MPSGraph()

  // Process-wide Metal objects used for the MPSGraph half of the benchmark.
  class GPUContext {
    var device: MTLDevice
    var commandQueue: MTLCommandQueue
    // no library because we only call into MPS
    var graphDevice: MPSGraphDevice

    static let shared = GPUContext()

    init() {
      let devices = [MTLCreateSystemDefaultDevice()!]
      var bestDevice: MTLDevice?
      for device in devices {
        if device.supportsFamily(.metal3) {
          // if bestDevice != nil && device.isLowPower {
          //   continue
          // }
          bestDevice = device
        }
      }
      guard let bestDevice else {
        preconditionFailure("Could not locate a Metal 3 GPU.")
      }
      self.device = bestDevice
      self.commandQueue = device.makeCommandQueue()!
      self.graphDevice = MPSGraphDevice(mtlDevice: device)
    }
  }

  // Product of all dimensions in a shape, i.e. the element count.
  func reduceShape(_ array: [NSNumber]) -> Int {
    return array.reduce(Int(1)) { lhs, rhs in
      return lhs * rhs.intValue
    }
  }

  // A graph tensor paired with the shared-storage buffer that backs it.
  class Tensor {
    var tensor: MPSGraphTensor
    var tensorData: MPSGraphTensorData
    var buffer: MTLBuffer
    // Size in bytes of the tensor contents, excluding allocation padding.
    var bufferSize: Int

    init(tensor: MPSGraphTensor, dataType: MPSDataType) {
      self.tensor = tensor

      var dataSize: Int
      switch dataType {
      case .float16: dataSize = 2
      case .float32: dataSize = 4
      default:
        fatalError("Data type with raw value '\(dataType.rawValue)' unsupported.")
      }
      self.bufferSize = reduceShape(tensor.shape!) * dataSize

      // Round the allocation up to the next multiple of 512 bytes.
      let padding: Int = 512
      let paddedSize = ~(padding - 1) & (bufferSize + padding - 1)
      let device = GPUContext.shared.device
      self.buffer = device.makeBuffer(
        length: paddedSize, options: .storageModeShared)!
      self.tensorData = MPSGraphTensorData(
        buffer, shape: tensor.shape!, dataType: dataType)
    }

    func shapedType() -> MPSGraphShapedType {
      MPSGraphShapedType(shape: tensor.shape!, dataType: tensorData.dataType)
    }
  }

  func makeShape(_ array: Int64...) -> [NSNumber] {
    return array.map(NSNumber.init(value:))
  }

  let m: Int64 = .init(matrixDim)
  let n: Int64 = .init(matrixDim)
  let k: Int64 = .init(matrixDim)
  let dataType: MPSDataType = useHalf ? .float16 : .float32

  let _lhs = graph.placeholder(
    shape: makeShape(m, k), dataType: dataType, name: "lhs")
  let _rhs = graph.placeholder(
    shape: makeShape(k, n), dataType: dataType, name: "rhs")
  let _out = graph.matrixMultiplication(
    primary: _lhs, secondary: _rhs, name: "out")
  precondition(_out.shape!.first! == _lhs.shape!.first!)
  precondition(_out.shape!.last! == _rhs.shape!.last!)

  let lhs = Tensor(tensor: _lhs, dataType: dataType)
  let rhs = Tensor(tensor: _rhs, dataType: dataType)
  let out = Tensor(tensor: _out, dataType: dataType)
  let feedsDict = [
    lhs.tensor: lhs,
    rhs.tensor: rhs
  ]
  let targetsDict = [
    out.tensor: out
  ]

  func mapFeeds(
    _ dict: [MPSGraphTensor: Tensor]
  ) -> [MPSGraphTensor: MPSGraphShapedType] {
    var output: [MPSGraphTensor: MPSGraphShapedType] = [:]
    for (key, value) in dict {
      output[key] = value.shapedType()
    }
    return output
  }
  let feeds = mapFeeds(feedsDict)

  func makeGraphDesc() -> MPSGraphCompilationDescriptor {
    let desc = MPSGraphCompilationDescriptor()
    desc.optimizationLevel = useANE ? .level1 : .level0
    desc.optimizationProfile = .performance
    return desc
  }
  let graphExec = graph.compile(
    with: GPUContext.shared.graphDevice,
    feeds: feeds,
    targetTensors: [out.tensor],
    targetOperations: nil,
    compilationDescriptor: makeGraphDesc())

  // MARK: - GEMM_setInputs

  // Copy the MFA inputs into the MPSGraph input buffers so both
  // implementations multiply the same matrices.
  struct Pair {
    var data: UnsafeRawPointer
    var tensor: Tensor
  }
  let pairs: [Pair] = [
    Pair(data: .init(matrices[0].contents()), tensor: lhs),
    Pair(data: .init(matrices[1].contents()), tensor: rhs)
  ]
  for pair in pairs {
    let src = pair.data
    let dst = pair.tensor.buffer.contents()
    let len = pair.tensor.bufferSize
    memcpy(dst, src, len)
  }

  // MARK: - GEMM_runAsync

  // Looks up the tensor data for each graph tensor, in the given order.
  func mapArguments(
    _ dict: [MPSGraphTensor: Tensor],
    order: [MPSGraphTensor]
  ) -> [MPSGraphTensorData] {
    return order.map { dict[$0]!.tensorData }
  }

  // Executes the compiled graph once (or `numRepetitionsANE` times when
  // `forceRunSync` is set) and stores the measured time, in seconds, through
  // `gpuTime`. `synchronized` is set to whether this call blocked until the
  // work finished.
  func dispatchCommandBuffer(
    gpuTime: UnsafeMutablePointer<Double>?,
    checkError: Int8,
    wait: Int8,
    graphExec: MPSGraphExecutable,
    feedsDict: [MPSGraphTensor: Tensor],
    targetsDict: [MPSGraphTensor: Tensor],
    synchronized: inout Bool
  ) {
    let commandQueue = GPUContext.shared.commandQueue
    if forceRunSync {
      // Wall-clock timing: all but the last repetition are asynchronous, and
      // the final blocking `run` ends the measured interval. The result is the
      // average time per repetition.
      let start = CACurrentMediaTime()
      for _ in 0..<numRepetitionsANE - 1 {
        graphExec.runAsync(
          with: commandQueue,
          inputs: mapArguments(feedsDict, order: graphExec.feedTensors!),
          results: mapArguments(targetsDict, order: graphExec.targetTensors!),
          executionDescriptor: nil)
      }
      graphExec.run(
        with: commandQueue,
        inputs: mapArguments(feedsDict, order: graphExec.feedTensors!),
        results: mapArguments(targetsDict, order: graphExec.targetTensors!),
        executionDescriptor: nil)
      let end = CACurrentMediaTime()
      gpuTime!.pointee = (end - start) / Double(numRepetitionsANE)
      synchronized = true
    } else {
      let commandBuffer = commandQueue.makeCommandBuffer()!
      let mpsCommandBuffer = MPSCommandBuffer(commandBuffer: commandBuffer)
      let desc = MPSGraphExecutableExecutionDescriptor()
      if gpuTime != nil || checkError != 0 {
        // NOTE(review): `gpuTime` is written from this completion handler;
        // the caller reads it after waiting on the last command buffer only.
        // Confirm the handler is guaranteed to have run by then.
        desc.completionHandler = { _, error in
          if checkError != 0 {
            if let error {
              print(error.localizedDescription)
              fatalError(error.localizedDescription)
            }
            precondition(commandBuffer.status == .completed, "Unexpected status.")
          }
          if let gpuTime {
            let startTime = commandBuffer.gpuStartTime
            let endTime = commandBuffer.gpuEndTime
            gpuTime.pointee = endTime - startTime
          }
        }
      }
      graphExec.encode(
        to: mpsCommandBuffer,
        inputs: mapArguments(feedsDict, order: graphExec.feedTensors!),
        results: mapArguments(targetsDict, order: graphExec.targetTensors!),
        executionDescriptor: desc)
      commandBuffer.commit()
      if wait != 0 {
        commandBuffer.waitUntilCompleted()
        synchronized = true
      } else {
        synchronized = false
      }
    }
  }

  // One timing slot (in seconds) per trial.
  let timesBuffer = malloc(numTrials * MemoryLayout<Double>.stride)!
    .assumingMemoryBound(to: Double.self)
  defer { free(timesBuffer) }
  var synchronized: Bool = true
  for i in 0..<numTrials {
    // Only the final trial blocks the CPU.
    let wait: Int8 = (i == numTrials - 1) ? 1 : 0
    let gpuTime = timesBuffer + i
    dispatchCommandBuffer(
      gpuTime: gpuTime,
      checkError: 1,
      wait: wait,
      graphExec: graphExec,
      feedsDict: feedsDict,
      targetsDict: targetsDict,
      synchronized: &synchronized)
  }

  // Report performance of MPS.
  do {
    if verbose {
      print()
    }
    print("Metal Performance Shaders: '\(precisionReprMPS)'")
    var minLatency: Double = 1000
    for i in 0..<numTrials {
      let latencySeconds = timesBuffer[i]
      let latencyMicroseconds = Int(latencySeconds / 1e-6)
      if verbose {
        print("us: \(latencyMicroseconds)")
      }
      minLatency = min(latencySeconds, minLatency)
    }
    let floatOps = matrixDim * matrixDim * matrixDim * 2
    let flops = Double(floatOps) / Double(minLatency)
    let gflops = Int(flops / 1e9)
    print("GFLOPS: \(gflops)")
    if verbose {
      printMFAOutputSample()
    }
  }

  // MARK: - GEMM_compare

  // Compute Euclidean distance, which tolerates small differences due to
  // numerical error. On return, `difference` holds `x - y` element-wise.
  func euclideanDistance(
    n: Int,
    x: UnsafePointer<Float>,
    y: UnsafePointer<Float>,
    difference: UnsafeMutablePointer<Float>
  ) -> Float {
    memcpy(difference, x, n * 4)
    var n_copy = Int32(n)
    var a = Float(-1)
    var inc = Int32(1)
    var inc_copy = inc

    // Find x + (-1 * y)
    saxpy_(
      &n_copy, &a, UnsafeMutablePointer(mutating: y), &inc, difference,
      &inc_copy)

    // Find ||x - y||
    return Float(snrm2_(&n_copy, difference, &inc))
  }

  // Prints up to 16 elements that differ noticeably between `x` and `y`, or a
  // random selection of elements if none stand out.
  func printComparisonError(
    distance: Float,
    tolerance: Float,
    n: Int,
    x: UnsafePointer<Float>,
    y: UnsafePointer<Float>,
    difference: UnsafePointer<Float>
  ) {
    // How many elements to print before stopping.
    let diagnosticElems = 16
    print()
    print("""
      Vectors did not match: Euclidean distance '\(distance)' > \
      tolerance '\(tolerance)'.
      """)

    var failedElements = 0
    for i in 0..<n {
      if abs(x[i] - y[i]) > 0.001 * distance || x[i].isNaN || y[i].isNaN {
        print("Element \(i): \(x[i]) - \(y[i]) = \(difference[i])")
        failedElements += 1
        if failedElements >= diagnosticElems {
          break
        }
      }
    }
    if failedElements == 0 {
      print("Could not find elements with enough separation.")
      for index in 0..<diagnosticElems {
        let i = (index == 0) ? 0 : Int.random(in: 0..<n)
        print("Element \(i): \(x[i]) - \(y[i]) = \(difference[i])")
      }
    }
  }

  do {
    let _lhs = matrices[0].contents()
    let _rhs = matrices[1].contents()
    let _out = outputsCopy
    let printError: Int8 = 1

    // Total rounding error scales with the square root of K and linearly with
    // M/N. Account for this when checking the output matrix.
    // TODO: Does transposed multiplication need a different heuristic?
    //
    // Compares raw memory at `input1` against the contents of `input2`'s
    // buffer, after widening both to Float32. `useK` selects the
    // size-dependent tolerance and prints the distance.
    func matches(
      _ input1: UnsafeRawPointer,
      _ input2: Tensor,
      useK: Bool
    ) -> Bool {
      let dataType = input2.tensorData.dataType
      switch dataType {
      case .float16: break
      case .float32: break
      default:
        fatalError(
          "Data type with raw value '\(dataType.rawValue)' unsupported.")
      }

      // Do not mutate 'x' and 'y' after filling them.
      var bufferSize: Int
      var n: Int
      var x: UnsafeMutablePointer<Float>
      var y: UnsafeMutablePointer<Float>
      if dataType == .float16 {
        let bufferSize_f16 = input2.bufferSize
        let bufferSize_f32 = bufferSize_f16 * 2
        bufferSize = bufferSize_f32
        n = bufferSize_f32 / MemoryLayout<Float>.stride
        x = .allocate(capacity: n)
        y = .allocate(capacity: n)

        // Partially sourced from:
        // https://github.com/hollance/TensorFlow-iOS-Example/blob/master/VoiceMetal/VoiceMetal/Float16.swift
        //
        // Treats the data as a single-row image and converts F16 to F32.
        func copy(dst: UnsafeMutableRawPointer, src: UnsafeMutableRawPointer) {
          var bufferFloat16 = vImage_Buffer(
            data: src, height: 1, width: UInt(n), rowBytes: bufferSize_f16)
          var bufferFloat32 = vImage_Buffer(
            data: dst, height: 1, width: UInt(n), rowBytes: bufferSize_f32)

          let error = vImageConvert_Planar16FtoPlanarF(
            &bufferFloat16, &bufferFloat32, 0)
          if error != kvImageNoError {
            fatalError(
              "Encountered error code \(error) while converting F16 to F32.")
          }
        }
        copy(dst: x, src: .init(mutating: input1))
        copy(dst: y, src: input2.buffer.contents())
      } else {
        bufferSize = input2.bufferSize
        n = bufferSize / MemoryLayout<Float>.stride
        x = .allocate(capacity: n)
        y = .allocate(capacity: n)
        memcpy(x, input1, bufferSize)
        memcpy(y, input2.buffer.contents(), bufferSize)
      }
      defer {
        x.deallocate()
        y.deallocate()
      }

      let difference: UnsafeMutablePointer<Float> = .allocate(capacity: n)
      defer {
        difference.deallocate()
      }
      let distance = euclideanDistance(n: n, x: x, y: y, difference: difference)

      var tolerance: Float
      if useK {
        let mk = lhs.tensor.shape!
        let kn = rhs.tensor.shape!
        precondition(mk[1] == kn[0], "K dimension does not match.")

        let M = mk[0].intValue
        let N = kn[1].intValue
        let K = mk[1].intValue
        let expectedDeviation = Float(M * N) * sqrt(Float(K))
        if input2.tensorData.dataType == .float32 {
          tolerance = max(0.001, 3e-7 * expectedDeviation)
          // tolerance = max(0.001, 3e-9 * expectedDeviation)
        } else {
          tolerance = max(0.01, 5e-3 * expectedDeviation)
          // tolerance = max(0.01, 5e-4 * expectedDeviation)
        }
      } else {
        if input2.tensorData.dataType == .float32 {
          tolerance = 0.001
        } else {
          tolerance = 0.01
        }
      }

      if useK {
        if verbose {
          print()
        }
        print("MFA vs MPS Euclidean Distance: \(distance)")
      }
      if distance < tolerance {
        return true
      } else {
        if printError != 0 {
          printComparisonError(
            distance: distance,
            tolerance: tolerance,
            n: n,
            x: x,
            y: y,
            difference: difference)
        }
        return false
      }
    }

    // The inputs must be identical copies; only the output comparison uses
    // the size-dependent tolerance.
    if matches(_lhs, lhs, useK: false),
       matches(_rhs, rhs, useK: false),
       matches(_out, out, useK: true) {
      // return 1
    } else {
      if printError == 1 {
        func printShape(_ tensor: Tensor, name: String) {
          print("Tensor '\(name)' has shape \(tensor.tensor.shape!).")
        }
        printShape(lhs, name: "lhs")
        printShape(rhs, name: "rhs")
        printShape(out, name: "out")
      }
      // return 0
    }
  }
}
/// Fills a `dim1` x `dim2` matrix with values drawn from `drand48()`.
///
/// Elements are written in order from index 0 to `dim1 * dim2 - 1`, with one
/// `drand48()` call per element, so seeding with `srand48` beforehand makes
/// the contents reproducible.
///
/// - Parameters:
///   - matrix: Destination storage with room for at least `dim1 * dim2`
///     elements.
///   - dim1: First matrix dimension; must not be negative.
///   - dim2: Second matrix dimension; must not be negative.
///   - initialize: Converts each `Double` sample into the element type `T`.
func matrixInitRandom<T: BinaryFloatingPoint>(
  matrix: UnsafeMutablePointer<T>, dim1: Int, dim2: Int,
  _ initialize: (Double) -> T
) {
  // Two negative dimensions would otherwise multiply to a positive count and
  // write through the pointer without any diagnostic.
  precondition(
    dim1 >= 0 && dim2 >= 0, "Matrix dimensions must be non-negative.")
  let elementCount = dim1 * dim2
  for i in 0..<elementCount {
    matrix[i] = initialize(drand48())
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment