@malfet
Created January 9, 2024 02:29
Swift example that runs matrix multiplication on MPS
import MetalPerformanceShadersGraph

// Build a graph that multiplies two constant batched matrices filled with ones.
let graph = MPSGraph()
let x = graph.constant(1, shape: [32, 4096, 40], dataType: .float32)
let y = graph.constant(1, shape: [32, 40, 4096], dataType: .float32)
// Batched product: z has shape [32, 4096, 4096] in float32.
let z = graph.matrixMultiplication(primary: x, secondary: y, name: nil)

// Result storage: a 16384-byte buffer described as a [64, 64] int32 tensor.
let device = MTLCreateSystemDefaultDevice()!
let buf = device.makeBuffer(length: 16384)!
let td = MPSGraphTensorData(buf, shape: [64, 64], dataType: .int32)

// Encode the graph into a command buffer and submit it to the GPU.
let cmdBuf = MPSCommandBuffer(from: device.makeCommandQueue()!)
graph.encode(to: cmdBuf, feeds: [:], targetOperations: nil, resultsDictionary: [z:td], executionDescriptor: nil)
cmdBuf.commit()
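As a rough check on the sizes involved, the sketch below works out the shape and byte count of the product `z` and compares it with the 16384-byte buffer passed in `resultsDictionary`. The helper `byteCount` and the names `xShape`, `yShape`, `zShape`, `needed`, and `provided` are hypothetical and not part of the gist; the calculation assumes a dense layout with no padding.

```swift
// Hypothetical helper (not from the gist): bytes needed for a dense tensor.
func byteCount(shape: [Int], bytesPerElement: Int) -> Int {
    // Multiply all dimensions together, starting from the element size.
    return shape.reduce(bytesPerElement, *)
}

// Batched matmul: [b, m, k] x [b, k, n] -> [b, m, n]
let xShape = [32, 4096, 40]
let yShape = [32, 40, 4096]
let zShape = [xShape[0], xShape[1], yShape[2]]

// Bytes a float32 result of shape zShape would occupy.
let needed = byteCount(shape: zShape, bytesPerElement: MemoryLayout<Float>.size)
// Bytes described by the [64, 64] int32 tensor data in the example.
let provided = byteCount(shape: [64, 64], bytesPerElement: MemoryLayout<Int32>.size)

print(zShape, needed, provided)
// prints "[32, 4096, 4096] 2147483648 16384"
```

Under these assumptions the float32 product occupies 2^31 bytes (2 GiB), while the result tensor data in the example describes 16384 bytes.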