Skip to content

Instantly share code, notes, and snippets.

@malfet
Created May 7, 2024 00:47
Show Gist options
  • Save malfet/10c000e868964b82318ab1b57e4f34e0 to your computer and use it in GitHub Desktop.
Save malfet/10c000e868964b82318ab1b57e4f34e0 to your computer and use it in GitHub Desktop.
Check whether `nextafter(0.0, 1.0)` is greater than zero on a Metal device
import Metal
// Metal Shading Language source for the `nextafter_pred` kernel, compiled at
// runtime via `device.makeLibrary(source:options:)`.
// Each thread i writes `nextafter(float(i) - 8.0, 1e4)` into data[i] and stores
// whether that value is > 0.0 into pred[i]. For i == 8 the first argument is 0.0,
// so under IEEE-754 data[8] should be the smallest positive (subnormal) float,
// roughly 1.4e-45; pred[8] shows whether the device actually produces a nonzero value.
// NOTE(review): a false pred[8] presumably means subnormals are flushed to zero — confirm.
// The literal is kept byte-exact on purpose: anything added inside the triple
// quotes (including comments or indentation) becomes part of the shader source.
let shader_source = """
#include <metal_stdlib>
using namespace metal;
kernel void nextafter_pred(device float *data [[buffer(0)]],
device bool *pred [[buffer(1)]],
uint thread_index [[thread_position_in_grid]]) {
data[thread_index] = nextafter(float(thread_index) - 8.0, 1e4);
pred[thread_index] = data[thread_index] > 0.0;
}
"""
// Use the first Metal device the system reports (MTLCopyAllDevices is macOS-only).
guard let device = MTLCopyAllDevices().first else { fatalError("No Metal device found") }
print("Using device \(device.name)")

// Disable fast math: it permits optimizations that are not IEEE-754 compliant,
// which would defeat the purpose of probing `nextafter` behavior near zero.
let options = MTLCompileOptions()
options.languageVersion = .version3_1
options.fastMathEnabled = false
let library: MTLLibrary = {
    do {
        return try device.makeLibrary(source: shader_source, options: options)
    } catch {
        fatalError("Can't compile shader: \(error)")
    }
}()
guard let mfunc = library.makeFunction(name: "nextafter_pred") else { fatalError("Can't find function") }
let pipeline: MTLComputePipelineState = {
    do {
        return try device.makeComputePipelineState(function: mfunc)
    } catch {
        fatalError("Can't create compute pipeline state: \(error)")
    }
}()

// One float result and one bool predicate per thread. Shared storage lets the
// CPU read the results directly once the GPU work has finished.
let nelem = 256
guard let dbuf = device.makeBuffer(length: nelem * MemoryLayout<Float>.stride, options: [.storageModeShared]) else { fatalError("Can't alloc") }
guard let pbuf = device.makeBuffer(length: nelem * MemoryLayout<Bool>.stride, options: [.storageModeShared]) else { fatalError("Can't alloc") }
guard let queue = device.makeCommandQueue() else { fatalError("Can't make queue") }
guard let cmdBuffer = queue.makeCommandBuffer() else { fatalError("Can't make command buffer") }
guard let computeEncoder = cmdBuffer.makeComputeCommandEncoder() else { fatalError("Can't make compute encoder") }
computeEncoder.setComputePipelineState(pipeline)
computeEncoder.setBuffer(dbuf, offset: 0, index: 0)
computeEncoder.setBuffer(pbuf, offset: 0, index: 1)
// A threadgroup must not exceed the pipeline's limit. dispatchThreads handles a
// grid size that is not a multiple of the threadgroup size.
let threadsPerGroup = min(nelem, pipeline.maxTotalThreadsPerThreadgroup)
computeEncoder.dispatchThreads(MTLSizeMake(nelem, 1, 1), threadsPerThreadgroup: MTLSizeMake(threadsPerGroup, 1, 1))
computeEncoder.endEncoding()
cmdBuffer.commit()
cmdBuffer.waitUntilCompleted()
// If the GPU work failed, the buffer contents are meaningless; stop instead of printing them.
guard cmdBuffer.status == .completed else {
    fatalError("Command buffer failed: \(String(describing: cmdBuffer.error))")
}

// Metal hands back raw memory, so bind it to the element types before typed reads.
let float_data = dbuf.contents().bindMemory(to: Float.self, capacity: nelem)
let bool_data = pbuf.contents().bindMemory(to: Bool.self, capacity: nelem)
// Threads 0..<16 see inputs -8.0 ... 7.0; thread 8 is the nextafter(0.0, 1e4) case.
for i in 0..<min(16, nelem) {
    print("\(i): \(float_data[i]) >0 is \(bool_data[i])")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment