Skip to content

Instantly share code, notes, and snippets.

@mhamilt
Created August 23, 2019 11:00
Show Gist options
  • Save mhamilt/0cd17da0480e86e4e6a89b841b2250b4 to your computer and use it in GitHub Desktop.
Save mhamilt/0cd17da0480e86e4e6a89b841b2250b4 to your computer and use it in GitHub Desktop.
Metal Compute Shader

Parallel Data Processing with Metal

Build Instructions

Create a Swift Command Line Tool project in Xcode and add a new Metal file named parsum.metal to the target. Then copy and paste the code below into main.swift and parsum.metal, respectively.

Source

main.swift

/*
 * Command line Metal Compute Shader for data processing
 */

import Metal
import Foundation
//------------------------------------------------------------------------------
let count = 10_000_000       // Total number of input elements
let elementsPerSum = 10_000  // Number of input elements each GPU thread adds up
//------------------------------------------------------------------------------
typealias DataType = CInt // Element type; has to be the same as `DataType` in the shader
//------------------------------------------------------------------------------
// Metal setup. Every step stops with a descriptive message instead of a bare
// force-unwrap crash, because these are the usual failure points when the
// project is configured incorrectly (e.g. parsum.metal is not part of the target).
guard let device = MTLCreateSystemDefaultDevice() else {
    fatalError("No Metal device is available on this system")
}
guard let library = device.makeDefaultLibrary() else {
    fatalError("No default Metal library found; make sure parsum.metal is compiled into this target")
}
guard let parsum = library.makeFunction(name: "parsum") else {
    fatalError("Kernel function 'parsum' was not found in the default Metal library")
}
let pipeline: MTLComputePipelineState = {
    do {
        return try device.makeComputePipelineState(function: parsum)
    } catch {
        fatalError("Could not create a compute pipeline state for 'parsum': \(error)")
    }
}()
//------------------------------------------------------------------------------
// Our data, randomly generated (values in 0..<100):
let data = (0..<count).map { _ in DataType.random(in: 0..<100) }
// Scalars handed to the kernel via `setBytes`; they are `var` only so that `&` can be taken.
var dataCount = CUnsignedInt(count)
var elementsPerSumC = CUnsignedInt(elementsPerSum)
// Number of individual results = count / elementsPerSum (rounded up):
let resultsCount = (count + elementsPerSum - 1) / elementsPerSum
//------------------------------------------------------------------------------
// Our data in a buffer (copied):
guard let dataBuffer = device.makeBuffer(bytes: data,
                                         length: MemoryLayout<DataType>.stride * count,
                                         options: []) else {
    fatalError("Could not allocate the Metal buffer for the input data")
}
// A buffer for individual results (zero initialized)
guard let resultsBuffer = device.makeBuffer(length: MemoryLayout<DataType>.stride * resultsCount,
                                            options: []) else {
    fatalError("Could not allocate the Metal buffer for the partial sums")
}
// Our results in convenient form to compute the actual result later:
let pointer = resultsBuffer.contents().bindMemory(to: DataType.self, capacity: resultsCount)
let results = UnsafeBufferPointer<DataType>(start: pointer, count: resultsCount)
//------------------------------------------------------------------------------
guard let queue = device.makeCommandQueue() else {
    fatalError("Could not create a Metal command queue")
}
guard let cmds = queue.makeCommandBuffer() else {
    fatalError("Could not create a Metal command buffer")
}
guard let encoder = cmds.makeComputeCommandEncoder() else {
    fatalError("Could not create a Metal compute command encoder")
}
//------------------------------------------------------------------------------
// Argument indices must match the [[ buffer(n) ]] attributes of the `parsum` kernel.
encoder.setComputePipelineState(pipeline)
encoder.setBuffer(dataBuffer, offset: 0, index: 0)
encoder.setBytes(&dataCount, length: MemoryLayout<CUnsignedInt>.size, index: 1)
encoder.setBuffer(resultsBuffer, offset: 0, index: 2)
encoder.setBytes(&elementsPerSumC, length: MemoryLayout<CUnsignedInt>.size, index: 3)
//------------------------------------------------------------------------------
// We have to calculate the sum `resultsCount` times => amount of threadgroups is `resultsCount` / `threadExecutionWidth` (rounded up) because each threadgroup will process `threadExecutionWidth` threads.
// Rounding up can launch more threads than there are results; the kernel is expected to leave those surplus threads idle.
let threadgroupsPerGrid = MTLSize(width: (resultsCount + pipeline.threadExecutionWidth - 1) / pipeline.threadExecutionWidth, height: 1, depth: 1)

// Here we set that each threadgroup should process `threadExecutionWidth` threads, the only important thing for performance is that this number is a multiple of `threadExecutionWidth` (here 1 times)
let threadsPerThreadgroup = MTLSize(width: pipeline.threadExecutionWidth, height: 1, depth: 1)
//------------------------------------------------------------------------------
encoder.dispatchThreadgroups(threadgroupsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
encoder.endEncoding()
//------------------------------------------------------------------------------
// Timestamps are uptime in nanoseconds. `DispatchTime.uptimeNanoseconds` is used
// instead of `mach_absolute_time()` because the latter returns Mach ticks, which
// only equal nanoseconds on machines whose timebase is 1:1; dividing raw ticks by
// NSEC_PER_SEC reports wrong durations elsewhere (e.g. on Apple Silicon).
var start, end : UInt64
var result : DataType = 0
//------------------------------------------------------------------------------
// GPU run: the measured span covers submitting the work, waiting for it to
// finish, and adding up the partial sums on the CPU.
start = DispatchTime.now().uptimeNanoseconds
cmds.commit()
cmds.waitUntilCompleted()
// A failed command buffer leaves the results buffer in an unknown state, so
// stop here rather than printing a bogus sum.
if let error = cmds.error {
    fatalError("Metal command buffer failed: \(error)")
}
for elem in results {
    result += elem
}

end = DispatchTime.now().uptimeNanoseconds
//------------------------------------------------------------------------------
print("Metal result: \(result), time: \(Double(end - start) / Double(NSEC_PER_SEC))")
//------------------------------------------------------------------------------
// CPU reference: the same sum computed sequentially over the original array.
result = 0

start = DispatchTime.now().uptimeNanoseconds
data.withUnsafeBufferPointer { buffer in
    for elem in buffer {
        result += elem
    }
}
end = DispatchTime.now().uptimeNanoseconds

print("CPU result: \(result), time: \(Double(end - start) / Double(NSEC_PER_SEC))")
//------------------------------------------------------------------------------

parsum.metal

#include <metal_stdlib>
using namespace metal;

typedef unsigned int uint;
typedef int DataType; // Has to be the same as `DataType` on the host side

// Computes partial sums of `data`: the thread with global index `i` adds up the
// elements in [i * elementsPerSum, min((i + 1) * elementsPerSum, dataLength))
// and stores the total in sums[i].
//
//   data            input elements                                 (buffer 0)
//   dataLength      number of elements in `data`                   (buffer 1)
//   sums            one output slot per chunk of `elementsPerSum`  (buffer 2)
//   elementsPerSum  number of elements summed by each thread       (buffer 3)
//
// Threads whose chunk starts at or beyond `dataLength` write nothing, so the
// host may round the grid size up to a whole number of threadgroups.
kernel void parsum(const device DataType* data [[ buffer(0) ]],
                   constant uint& dataLength [[ buffer(1) ]],
                   device DataType* sums [[ buffer(2) ]],
                   constant uint& elementsPerSum [[ buffer(3) ]],
                   
                   const uint tgPos [[ threadgroup_position_in_grid ]],
                   const uint tPerTg [[ threads_per_threadgroup ]],
                   const uint tPos [[ thread_position_in_threadgroup ]]) {
    
    // Global thread index == index of the partial sum this thread produces.
    const uint resultIndex = tgPos * tPerTg + tPos;
    
    // Where the summation should begin
    const uint beginIndex = resultIndex * elementsPerSum;
    
    // Surplus threads (and a degenerate chunk size of zero) have no data to
    // process and no output slot of their own, so they must not touch `sums`.
    if (elementsPerSum == 0 || beginIndex >= dataLength)
        return;
    
    // Clamp the chunk to the end of the data. Written as a subtraction so the
    // end index cannot wrap around for inputs close to the uint limit.
    const uint remaining = dataLength - beginIndex;
    const uint chunkLength = elementsPerSum < remaining ? elementsPerSum : remaining;
    const uint endIndex = beginIndex + chunkLength;
    
    // Accumulate in a thread-local variable and store once: this avoids a
    // device-memory read-modify-write per element and does not depend on
    // `sums` having been zeroed by the host beforehand.
    DataType sum = 0;
    for (uint i = beginIndex; i < endIndex; i++)
        sum += data[i];
    
    sums[resultIndex] = sum;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment