Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JacopoMangiavacchi/00207a942dcc1200bdc4cd1a350d95c5 to your computer and use it in GitHub Desktop.
Save JacopoMangiavacchi/00207a942dcc1200bdc4cd1a350d95c5 to your computer and use it in GitHub Desktop.
/// Evaluates `graph` on the test set and logs the classification accuracy.
///
/// Builds and compiles `inferenceGraph` from `graph` (with `inputTensor` bound to the
/// `"image"` input), then feeds `testDataX` through it in batches of `batchSize` images.
/// For every sample the argmax of the output row is compared with the decoded one-hot
/// label from `testDataY`. Trailing samples that do not fill a whole batch are skipped,
/// and the accuracy is computed only over samples from batches that completed successfully.
///
/// - Parameter log: Receives status messages, failure reports and the final accuracy line.
private func evaluateGraph(log: (String) -> Void) {
    guard let testX = testDataX, let testY = testDataY else {
        log("Test data is not loaded; skipping evaluation")
        return
    }
    guard imageSize > 0, batchSize > 0 else {
        log("imageSize and batchSize must be positive; skipping evaluation")
        return
    }

    let testingSample = testX.count / imageSize
    let testingBatches = testingSample / batchSize
    guard testingBatches > 0 else {
        log("Not enough test samples (\(testingSample)) for one batch of \(batchSize)")
        return
    }
    // Every evaluated sample needs a full one-hot label row.
    guard testY.count >= testingBatches * batchSize * numberOfClasses else {
        log("Test labels are shorter than the test images; skipping evaluation")
        return
    }

    inferenceGraph = MLCInferenceGraph(graphObjects: [graph])
    guard inferenceGraph.addInputs(["image": inputTensor]),
          inferenceGraph.compile(options: [], device: device) else {
        log("Failed to build the inference graph")
        return
    }

    let inputCount = batchSize * imageSize
    let outputCount = batchSize * numberOfClasses
    var match = 0
    var completedBatches = 0
    var lastError: Error?

    // TESTING LOOP FOR A FULL EPOCH ON TESTING DATA
    for batch in 0..<testingBatches {
        // `immutableBytesNoCopy` borrows testX's storage, and the buffer pointer is only valid
        // inside `withUnsafeBufferPointer`'s closure, so the graph must execute in here.
        // `.synchronous` is relied on to finish (and run the completion handler) before
        // `execute` returns, as the original code also assumed.
        let dispatched = testX.withUnsafeBufferPointer { pointer -> Bool in
            let xData = MLCTensorData(immutableBytesNoCopy: pointer.baseAddress!.advanced(by: batch * inputCount),
                                      length: inputCount * MemoryLayout<Float>.size)
            return inferenceGraph.execute(inputsData: ["image": xData],
                                          batchSize: batchSize,
                                          options: [.synchronous]) { [self] result, error, _ in
                guard let result = result, error == nil else {
                    lastError = error
                    return
                }
                // Copy the batch output into Swift-managed memory (freed automatically).
                var scores = [Float](repeating: 0, count: outputCount)
                let copied = scores.withUnsafeMutableBytes { raw -> Bool in
                    guard let base = raw.baseAddress else { return false }
                    return result.copyDataFromDeviceMemory(toBytes: base,
                                                           length: raw.count,
                                                           synchronizeWithDevice: false)
                }
                guard copied else { return }
                completedBatches += 1

                for i in 0..<batchSize {
                    let scoreStart = i * numberOfClasses
                    let labelStart = (batch * batchSize + i) * numberOfClasses
                    let prediction = argmaxDecoding(Array(scores[scoreStart..<(scoreStart + numberOfClasses)]))
                    let label = oneHotDecoding(Array(testY[labelStart..<(labelStart + numberOfClasses)]))
                    if prediction == label {
                        match += 1
                    }
                }
            }
        }
        if !dispatched {
            // Counted as failed below via `completedBatches`.
            continue
        }
    }

    let failedBatches = testingBatches - completedBatches
    if failedBatches > 0 {
        log("\(failedBatches) of \(testingBatches) test batches failed: \(String(describing: lastError))")
    }
    guard completedBatches > 0 else {
        log("No test batches completed; accuracy unavailable")
        return
    }
    // `match` only counts samples from completed batches, so divide by exactly that many;
    // scale to a percentage so the value agrees with the "%" in the message.
    let accuracy = Float(match) / Float(completedBatches * batchSize) * 100
    log("Test Accuracy = \(accuracy) %")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment