Skip to content

Instantly share code, notes, and snippets.

@JacopoMangiavacchi
Last active January 15, 2021 07:14
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JacopoMangiavacchi/ac6127f3a6c669112ce6c13b88eea350 to your computer and use it in GitHub Desktop.
MLCompute Playground
import Foundation
import MLCompute
import PlaygroundSupport
let page = PlaygroundPage.current
// Keep the playground alive until the asynchronous graph execution calls finishExecution().
page.needsIndefiniteExecution = true

// MARK: - Helpers

/// Number of Float elements in each input vector and in the result.
let elementCount = 6
/// Byte size of one vector of `elementCount` Floats.
let vectorByteCount = elementCount * MemoryLayout<Float>.stride

/// Creates a float32 `MLCTensor` shaped `[elementCount, 1]`.
func makeColumnTensor() -> MLCTensor {
    guard let descriptor = MLCTensorDescriptor(shape: [elementCount, 1], dataType: .float32) else {
        fatalError("Could not create a float32 tensor descriptor with shape [\(elementCount), 1]")
    }
    return MLCTensor(descriptor: descriptor)
}

/// Copies `values` into manually allocated memory and wraps that memory in an `MLCTensorData`.
///
/// `MLCTensorData(immutableBytesNoCopy:length:)` keeps the pointer it is given without copying
/// the bytes, so that memory must stay valid until execution has finished. `UnsafeRawPointer(array)`
/// only yields a pointer that is valid during that one initializer call (a dangling pointer).
/// That is why the values are copied into storage whose lifetime the caller controls.
///
/// - Parameter values: The Floats to hand to MLCompute.
/// - Returns: The tensor data and its backing storage. The caller owns `storage` and must
///   `deallocate()` it once MLCompute no longer reads from it.
func makeTensorData(_ values: [Float]) -> (data: MLCTensorData, storage: UnsafeMutableRawPointer) {
    let length = values.count * MemoryLayout<Float>.stride
    let storage = UnsafeMutableRawPointer.allocate(byteCount: length, alignment: MemoryLayout<Float>.alignment)
    values.withUnsafeBytes { source in
        UnsafeMutableRawBufferPointer(start: storage, count: length).copyMemory(from: source)
    }
    return (MLCTensorData(immutableBytesNoCopy: storage, length: length), storage)
}

// MARK: - Inputs

let tensor1 = makeColumnTensor()
let tensor2 = makeColumnTensor()
let tensor3 = makeColumnTensor()

let buffer1: [Float] = [1, 2, 3, 4, 5, 6]
let buffer2: [Float] = [7, 8, 9, 10, 11, 12]
let buffer3: [Float] = [6, 5, 4, 3, 2, 1]

let input1 = makeTensorData(buffer1)
let input2 = makeTensorData(buffer2)
let input3 = makeTensorData(buffer3)

// MARK: - Graph: (tensor1 + tensor2) + tensor3

let g = MLCGraph()
guard let tensor1plus2 = g.node(with: MLCArithmeticLayer(operation: .add), sources: [tensor1, tensor2]) else {
    fatalError("Could not add the tensor1 + tensor2 node to the graph")
}
guard g.node(with: MLCArithmeticLayer(operation: .add), sources: [tensor1plus2, tensor3]) != nil else {
    fatalError("Could not add the (tensor1 + tensor2) + tensor3 node to the graph")
}

let i = MLCInferenceGraph(graphObjects: [g])
guard i.addInputs(["data1": tensor1, "data2": tensor2, "data3": tensor3]) else {
    fatalError("Could not register the inference graph inputs")
}
guard i.compile(options: .debugLayers, device: MLCDevice()) else {
    fatalError("Could not compile the inference graph")
}

// MARK: - Execution

// Backing memory of the no-copy input tensor data; released once execution has completed.
let inputStorage = [input1.storage, input2.storage, input3.storage]

// NOTE(review): batchSize 0 is kept as-is; presumably it lets MLCompute take the batch size
// from the tensor shapes. Confirm against the MLCompute documentation.
i.execute(inputsData: ["data1": input1.data, "data2": input2.data, "data3": input3.data],
          batchSize: 0,
          options: []) { result, error, _ in
    defer {
        // The inputs were wrapped without copying, so their storage is freed only here.
        inputStorage.forEach { $0.deallocate() }
        page.finishExecution()
    }

    print("Error: \(String(describing: error))")
    print("Result: \(String(describing: result))")

    // No result tensor means there is nothing to copy back; any error was printed above.
    guard let result = result else { return }

    let outputStorage = UnsafeMutableRawPointer.allocate(byteCount: vectorByteCount,
                                                         alignment: MemoryLayout<Float>.alignment)
    defer { outputStorage.deallocate() }
    result.copyDataFromDeviceMemory(toBytes: outputStorage, length: vectorByteCount, synchronizeWithDevice: false)

    let floats = outputStorage.bindMemory(to: Float.self, capacity: elementCount)
    // For the inputs above this should print [14.0, 15.0, 16.0, 17.0, 18.0, 19.0].
    print(Array(UnsafeBufferPointer(start: floats, count: elementCount)))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment