

@xrq-phys
Created December 25, 2020 18:14
A very simple example of a matrix multiplication with ML Compute. (Naming conventions are somewhat violated.)
import Foundation
import MLCompute
import PlaygroundSupport
let iPage = PlaygroundPage.current
iPage.needsIndefiniteExecution = true
/*
 * Apple says MLCMatMulLayer performs a "batch matrix multiplication"
 * but does not make clear what that means.
 * By trial and error, it seems that MLCMatMulLayer broadcasts the GEMM
 * operation along the first axis, i.e. the 2nd and 3rd axes store the
 * matrices to be multiplied and the 1st axis indexes the batch.
 */
// One 2x2 matrix per tensor, with a batch of size 1 along the first axis.
let tA = MLCTensor(shape: [1, 2, 2], dataType: .float32)
let tB = MLCTensor(shape: [1, 2, 2], dataType: .float32)
let tC = MLCTensor(shape: [1, 2, 2], dataType: .float32)
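let bufA: [Float] = [1, 2, 3, 4]
let bufB: [Float] = [1, 2, 3, 4]
let bufC: [Float] = [1, 1, 1, 1]
// MLCTensorData(immutableBytesNoCopy:length:) does not copy the bytes, so they
// must stay valid until execution finishes. UnsafeRawPointer(bufA) yields a
// pointer that is only guaranteed valid during that initializer call (Swift
// warns about a dangling pointer), so copy each array into manually allocated
// storage instead. The storage is intentionally never freed in this playground.
func makeTensorData(_ values: [Float]) -> MLCTensorData {
    let ptr = UnsafeMutablePointer<Float>.allocate(capacity: values.count)
    ptr.initialize(from: values, count: values.count)
    return MLCTensorData(immutableBytesNoCopy: ptr,
                         length: values.count * MemoryLayout<Float>.size)
}
let datA = makeTensorData(bufA)
let datB = makeTensorData(bufB)
let datC = makeTensorData(bufC)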
let iGraph = MLCGraph()
// tAB = A * B (batched matrix multiplication)
let tAB = iGraph.node(with: MLCMatMulLayer(descriptor: MLCMatMulDescriptor())!,
                      sources: [tA, tB])
// Graph output = A * B + C (element-wise addition)
iGraph.node(with: MLCArithmeticLayer(operation: .add), sources: [tAB!, tC])
let iPlan = MLCInferenceGraph(graphObjects: [iGraph])
iPlan.addInputs(["A": tA, "B": tB, "C": tC])
iPlan.compile(options: .debugLayers, device: MLCDevice())
iPlan.execute(inputsData: ["A": datA, "B": datB, "C": datC],
              batchSize: 0,
              options: []) { (r, e, time) in
    print("Error: \(String(describing: e))")
    print("Result: \(String(describing: r))")
    // r is nil if execution failed; skip the readback instead of crashing on r!.
    if let r = r {
        // Copy the 1x2x2 result (4 floats) back into host memory.
        let count = 4
        let bufO = UnsafeMutableRawPointer.allocate(byteCount: count * MemoryLayout<Float>.size,
                                                    alignment: MemoryLayout<Float>.alignment)
        r.copyDataFromDeviceMemory(toBytes: bufO,
                                   length: count * MemoryLayout<Float>.size,
                                   synchronizeWithDevice: false)
        let outArray = bufO.bindMemory(to: Float.self, capacity: count)
        print(Array(UnsafeBufferPointer(start: outArray, count: count)))
        bufO.deallocate()
    }
    iPage.finishExecution()
}
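To check the "batch" reading described in the comment above, one can compare the playground's output with a plain CPU computation. The sketch below is a hypothetical helper, `batchedMatMulAdd`, which is not part of ML Compute. It assumes row-major storage and that the first axis is the batch axis, and computes `A[b] * B[b] + C[b]` for every batch index `b`. For the inputs `bufA`, `bufB`, `bufC` it returns `[8.0, 11.0, 16.0, 23.0]`; if ML Compute prints the same values, that supports the interpretation.

```swift
// Hypothetical CPU reference (not an ML Compute API), assuming row-major
// storage and the first axis as the batch axis:
// out[p] = a[p] (m x k) * b[p] (k x n) + c[p] (m x n) for each batch index p.
func batchedMatMulAdd(_ a: [Float], _ b: [Float], _ c: [Float],
                      batch: Int, m: Int, k: Int, n: Int) -> [Float] {
    precondition(a.count == batch * m * k, "a has the wrong size")
    precondition(b.count == batch * k * n, "b has the wrong size")
    precondition(c.count == batch * m * n, "c has the wrong size")
    var out = [Float](repeating: 0, count: batch * m * n)
    for p in 0..<batch {
        for i in 0..<m {
            for j in 0..<n {
                // Dot product of row i of a[p] with column j of b[p].
                var acc: Float = 0
                for l in 0..<k {
                    acc += a[p * m * k + i * k + l] * b[p * k * n + l * n + j]
                }
                out[p * m * n + i * n + j] = acc + c[p * m * n + i * n + j]
            }
        }
    }
    return out
}

// Same inputs as the playground: [[1, 2], [3, 4]] * [[1, 2], [3, 4]] + ones.
print(batchedMatMulAdd([1, 2, 3, 4], [1, 2, 3, 4], [1, 1, 1, 1],
                       batch: 1, m: 2, k: 2, n: 2))
// [8.0, 11.0, 16.0, 23.0]

// A batch of two: each matrix pair is multiplied independently.
// Batch 0: [[1, 2], [3, 4]] * identity; batch 1: [[5, 6], [7, 8]] * swap matrix.
print(batchedMatMulAdd([1, 2, 3, 4, 5, 6, 7, 8],
                       [1, 0, 0, 1, 0, 1, 1, 0],
                       [Float](repeating: 0, count: 8),
                       batch: 2, m: 2, k: 2, n: 2))
// [1.0, 2.0, 3.0, 4.0, 6.0, 5.0, 8.0, 7.0]
```

To probe the batch semantics of `MLCMatMulLayer` further, the playground's tensor shapes could be changed to `[2, 2, 2]` with the batch-of-two inputs above, and its printout compared with the second result.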