Skip to content

Instantly share code, notes, and snippets.

@JacopoMangiavacchi
Created April 21, 2020 19:10
Show Gist options
  • Save JacopoMangiavacchi/fc0ebf6074a0b4ddd2c8ee1993e36245 to your computer and use it in GitHub Desktop.
Save JacopoMangiavacchi/fc0ebf6074a0b4ddd2c8ee1993e36245 to your computer and use it in GitHub Desktop.
/// Builds a Core ML training batch from the bundled `mnist_train.csv` resource.
///
/// Each CSV row is expected to be a digit label followed by 784 (28 × 28) pixel
/// values. Pixels are divided by 255 and written row-major into a `[1, 28, 28]`
/// float32 array under the `"image"` key; the label goes into a one-element int32
/// array under the `"output_true"` key.
///
/// Rows that cannot be parsed — e.g. a header line, a blank line, a row with fewer
/// than 785 columns, or a non-numeric field — are skipped (and their count logged)
/// instead of crashing on a force-unwrap or an out-of-range index.
///
/// - Note: `readLine()` only reads standard input, so this function redirects the
///   process's `stdin` to the CSV file via `freopen` and does not restore it.
/// - Returns: A batch provider with one entry per valid row, or an empty provider
///   if the file could not be opened.
func prepareBatchProvider() -> MLBatchProvider {
    let pixelCount = 28 * 28
    var featureProviders = [MLFeatureProvider]()
    var skippedRows = 0

    // The CSV is bundled with the app, so a missing resource is a packaging error.
    let trainFilePath = Bundle.main.url(forResource: "mnist_train", withExtension: "csv")!
    // Bail out on failure: continuing would make readLine() consume the real stdin.
    guard freopen(trainFilePath.path, "r", stdin) != nil else {
        print("error opening file")
        return MLArrayBatchProvider(array: featureProviders)
    }

    while let rawLine = readLine() {
        let fields = rawLine.split(separator: ",")
        // Column 0 is the label; columns 1...784 are pixels. Extra columns are ignored.
        guard fields.count > pixelCount,
              let label = Int(String(fields[0])) else {
            skippedRows += 1
            continue
        }
        let pixels = fields[1...pixelCount].compactMap { Float(String($0)) }
        guard pixels.count == pixelCount else {
            skippedRows += 1
            continue
        }

        let imageMultiArr = try! MLMultiArray(shape: [1, 28, 28], dataType: .float32)
        let outputMultiArr = try! MLMultiArray(shape: [1], dataType: .int32)
        // Linear index i == r * 28 + c, i.e. row-major over the 28 × 28 image.
        for (i, pixel) in pixels.enumerated() {
            imageMultiArr[i] = NSNumber(value: pixel / Float(255.0))
        }
        outputMultiArr[0] = NSNumber(value: label)

        let dataPointFeatures: [String: MLFeatureValue] = [
            "image": MLFeatureValue(multiArray: imageMultiArr),
            "output_true": MLFeatureValue(multiArray: outputMultiArr),
        ]
        if let provider = try? MLDictionaryFeatureProvider(dictionary: dataPointFeatures) {
            featureProviders.append(provider)
        }
    }

    if skippedRows > 0 {
        print("skipped \(skippedRows) malformed row(s)")
    }
    return MLArrayBatchProvider(array: featureProviders)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment