Skip to content

Instantly share code, notes, and snippets.

@malfet
Created May 31, 2024 13:47
Show Gist options
  • Save malfet/25bcbce305e7425acf8616c9d5517652 to your computer and use it in GitHub Desktop.
Save malfet/25bcbce305e7425acf8616c9d5517652 to your computer and use it in GitHub Desktop.
import Metal
import MetalPerformanceShadersGraph
/// Computes `exp` element-wise on the GPU with a Metal compute kernel compiled at runtime.
///
/// The call is synchronous: it commits the command buffer and waits for it to complete
/// before returning. Any setup or execution failure terminates the process via `fatalError`.
///
/// - Parameters:
///   - device: Metal device used to compile the shader and create the command queue.
///   - ibuf: Buffer bound at index 0; the kernel reads one `float` per thread from it.
///   - obuf: Buffer bound at index 1; the kernel writes one `float` per thread to it.
///   - nelem: Number of elements to process (one GPU thread per element). Values `<= 0` are a no-op.
///   - fastMathEnabled: Forwarded to `MTLCompileOptions.fastMathEnabled`.
func calculateExpMetal(device: MTLDevice, ibuf: MTLBuffer, obuf: MTLBuffer, nelem: Int, fastMathEnabled: Bool = false) {
    // Dispatching an empty grid is invalid, so there is nothing to do for zero elements.
    guard nelem > 0 else { return }

    let shader_source = """
    #include <metal_stdlib>
    using namespace metal;
    kernel void do_exp(constant float *input [[buffer(0)]],
    device float *output [[buffer(1)]],
    uint thread_index [[thread_position_in_grid]]) {
    output[thread_index] = exp(input[thread_index]);
    }
    """
    let options = MTLCompileOptions()
    options.languageVersion = .version3_1
    options.fastMathEnabled = fastMathEnabled

    let library: MTLLibrary
    do {
        library = try device.makeLibrary(source: shader_source, options: options)
    } catch {
        fatalError("Can't compile shader: \(error)")
    }
    guard let mfunc = library.makeFunction(name: "do_exp") else { fatalError("Can't find function") }

    let pipeline: MTLComputePipelineState
    do {
        pipeline = try device.makeComputePipelineState(function: mfunc)
    } catch {
        fatalError("Can't make compute pipeline state: \(error)")
    }

    guard let queue = device.makeCommandQueue() else { fatalError("Can't make queue") }
    guard let cmdBuffer = queue.makeCommandBuffer() else { fatalError("Can't make command buffer") }
    guard let computeEncoder = cmdBuffer.makeComputeCommandEncoder() else { fatalError("Can't make compute encoder") }
    computeEncoder.setComputePipelineState(pipeline)
    computeEncoder.setBuffer(ibuf, offset: 0, index: 0)
    computeEncoder.setBuffer(obuf, offset: 0, index: 1)

    // A threadgroup may not exceed the pipeline's limit; using `nelem` directly breaks
    // for large inputs. `dispatchThreads` handles the partially filled last threadgroup.
    let threadsPerGroup = min(nelem, pipeline.maxTotalThreadsPerThreadgroup)
    computeEncoder.dispatchThreads(MTLSizeMake(nelem, 1, 1),
                                   threadsPerThreadgroup: MTLSizeMake(threadsPerGroup, 1, 1))
    computeEncoder.endEncoding()
    cmdBuffer.commit()
    cmdBuffer.waitUntilCompleted()

    // Surface GPU-side failures instead of leaving `obuf` silently unwritten.
    if let error = cmdBuffer.error {
        fatalError("Command buffer failed: \(error)")
    }
}
/// Computes `exp` element-wise through an `MPSGraph` built from a single exponent node.
///
/// - Parameters:
///   - device: Metal device used to create the command queue the graph runs on.
///   - ibuf: Buffer wrapped as the graph's input tensor data (1-D, `nelem` float32 values).
///   - obuf: Buffer wrapped as the graph's result tensor data (1-D, `nelem` float32 values).
///   - nelem: Number of elements in the 1-D tensors.
func calculateExpMPS(device: MTLDevice, ibuf: MTLBuffer, obuf: MTLBuffer, nelem: Int) {
    // Placeholder, input data and output data all share the same 1-D shape.
    let tensorShape = [nelem as NSNumber]

    let graph = MPSGraph()
    let inputTensor = graph.placeholder(shape: tensorShape, dataType: .float32, name: nil)
    let resultTensor = graph.exponent(with: inputTensor, name: nil)

    // Wrap the caller's Metal buffers so the graph reads from / writes to them directly.
    let feeds = [inputTensor: MPSGraphTensorData(ibuf, shape: tensorShape, dataType: .float32)]
    let results = [resultTensor: MPSGraphTensorData(obuf, shape: tensorShape, dataType: .float32)]

    guard let commandQueue = device.makeCommandQueue() else { fatalError("Can't make queue") }
    // NOTE(review): callers read `obuf` right after this returns, so this overload of
    // `run` is presumably blocking — confirm against the MPSGraph documentation.
    graph.run(with: commandQueue, feeds: feeds, targetOperations: nil, resultsDictionary: results)
}
// Driver: fills an input buffer with log(x) samples, evaluates exp() via MPSGraph and via the
// Metal kernel (fast-math on and off), then prints how the results differ from each other
// and from the CPU `exp`.
guard let device = MTLCopyAllDevices().first else { fatalError("No Metal device found") }
print("Using device \(device.name)")

let nelem = 256
let bufferLength = nelem * MemoryLayout<Float>.stride

// Input: log(0.1), log(0.2), ... so that exp(input[i]) should recover (i + 1) * 0.1.
guard let ibuf = device.makeBuffer(length: bufferLength, options: [.storageModeShared]) else { fatalError("Can't alloc") }
let ibuf_data = ibuf.contents().assumingMemoryBound(to: Float.self)
for i in 0..<nelem {
    ibuf_data[i] = log(Float(i)*0.1 + 0.1)
}

// One output buffer per implementation under comparison.
guard let obuf_fast = device.makeBuffer(length: bufferLength, options: [.storageModeShared]) else { fatalError("Can't alloc") }
guard let obuf_prec = device.makeBuffer(length: bufferLength, options: [.storageModeShared]) else { fatalError("Can't alloc") }
guard let obuf_mps = device.makeBuffer(length: bufferLength, options: [.storageModeShared]) else { fatalError("Can't alloc") }

calculateExpMPS(device: device, ibuf: ibuf, obuf: obuf_mps, nelem: nelem)
calculateExpMetal(device: device, ibuf: ibuf, obuf: obuf_fast, nelem: nelem, fastMathEnabled: true)
calculateExpMetal(device: device, ibuf: ibuf, obuf: obuf_prec, nelem: nelem, fastMathEnabled: false)

let obuf_fast_data = obuf_fast.contents().assumingMemoryBound(to: Float.self)
let obuf_prec_data = obuf_prec.contents().assumingMemoryBound(to: Float.self)
let obuf_mps_data = obuf_mps.contents().assumingMemoryBound(to: Float.self)

// Report at most the first 100 elements, never reading past the end of the buffers
// if `nelem` is lowered below 100.
let reportCount = min(nelem, 100)
for i in 0..<reportCount {
    let cpu_exp = exp(ibuf_data[i])
    let fast_prec_diff = obuf_fast_data[i] - obuf_prec_data[i]
    let mps_prec_diff = obuf_mps_data[i] - obuf_prec_data[i]
    let prec_cpu_diff = obuf_prec_data[i] - cpu_exp
    print("exp(\(ibuf_data[i])) = \(cpu_exp) cpu_prec_diff = \(prec_cpu_diff) fast vs prec diff = \(fast_prec_diff) mps diff = \(mps_prec_diff)")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment