Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save JacopoMangiavacchi/13c3f9d6354cb121dbfd2e7a75767c88 to your computer and use it in GitHub Desktop.
Save JacopoMangiavacchi/13c3f9d6354cb121dbfd2e7a75767c88 to your computer and use it in GitHub Desktop.
import Foundation
import CoreML
/// Generates a noisy linear dataset `y = slope * x + intercept` for training.
/// - Parameters:
///   - sampleSize: Number of (x, y) points; x values are evenly spaced in [0, 1).
///   - slope: Coefficient of x (default 2.0, matching the original hard-coded `a`).
///   - intercept: Constant term (default 1.5, matching the original hard-coded `b`).
///   - noiseAmplitude: Uniform noise is drawn from [-noiseAmplitude/2, noiseAmplitude/2)
///     (default 0.1, matching the original `* 0.1` factor).
/// - Returns: Parallel arrays (X, Y), each of length `sampleSize`.
func generateData(sampleSize: Int = 100,
                  slope: Float = 2.0,
                  intercept: Float = 1.5,
                  noiseAmplitude: Float = 0.1) -> ([Float], [Float]) {
    var X = [Float]()
    var Y = [Float]()
    // Final count is known up front; avoid repeated reallocation.
    X.reserveCapacity(sampleSize)
    Y.reserveCapacity(sampleSize)
    for i in 0..<sampleSize {
        let x: Float = Float(i) / Float(sampleSize)
        // Zero-centered uniform noise in [-noiseAmplitude/2, noiseAmplitude/2).
        let noise: Float = (Float.random(in: 0..<1) - 0.5) * noiseAmplitude
        X.append(x)
        Y.append(slope * x + intercept + noise)
    }
    return (X, Y)
}
/// Builds the training batch for the Core ML update task from the synthetic
/// linear dataset produced by `generateData()`.
///
/// Each data point becomes an `MLDictionaryFeatureProvider` with the model's
/// input feature ("dense_input") and the true output feature ("output_true").
/// - Returns: An `MLArrayBatchProvider` containing one provider per sample.
func prepareTrainingBatch() -> MLBatchProvider {
    var featureProviders = [MLFeatureProvider]()
    let inputName = "dense_input"
    let outputName = "output_true"
    let (X, Y) = generateData()
    for (x, y) in zip(X, Y) {
        // BUG FIX: the original reused one MLMultiArray for both features.
        // MLMultiArray is a reference type and MLFeatureValue(multiArray:)
        // keeps a reference to it (it does not copy), so mutating the array
        // after creating `inputValue` made both features hold the y value.
        // Allocate a separate backing array per feature instead.
        let inputArr = try! MLMultiArray(shape: [1], dataType: .double)
        inputArr[0] = NSNumber(value: x)
        let inputValue = MLFeatureValue(multiArray: inputArr)
        let outputArr = try! MLMultiArray(shape: [1], dataType: .double)
        outputArr[0] = NSNumber(value: y)
        let outputValue = MLFeatureValue(multiArray: outputArr)
        let dataPointFeatures: [String: MLFeatureValue] = [inputName: inputValue,
                                                           outputName: outputValue]
        if let provider = try? MLDictionaryFeatureProvider(dictionary: dataPointFeatures) {
            featureProviders.append(provider)
        }
    }
    return MLArrayBatchProvider(array: featureProviders)
}
/// Starts an on-device update (retraining) of the compiled Core ML model at
/// `url` using the batch from `prepareTrainingBatch()`, and on success writes
/// the updated model to `retrainedCoreMLFilePath` (a global declared elsewhere
/// in this file). Training runs asynchronously: `resume()` returns immediately
/// and the handlers below fire on a background context.
func train(url: URL) {
let configuration = MLModelConfiguration()
configuration.computeUnits = .all
// Run 100 passes over the training batch.
configuration.parameters = [.epochs : 100]
// Progress handler: logs loss after each mini batch and each epoch.
let progressHandler = { (context: MLUpdateContext) in
switch context.event {
case .trainingBegin:
print("Training begin")
case .miniBatchEnd:
// Force-casts assume Core ML always populates these metrics for this event.
let batchIndex = context.metrics[.miniBatchIndex] as! Int
let batchLoss = context.metrics[.lossValue] as! Double
print("Mini batch \(batchIndex), loss: \(batchLoss)")
case .epochEnd:
let epochIndex = context.metrics[.epochIndex] as! Int
let trainLoss = context.metrics[.lossValue] as! Double
print("Epoch \(epochIndex) end with loss \(trainLoss)")
default:
print("Unknown event")
}
}
// Completion handler: logs the final state and, if training completed,
// persists the retrained model to disk.
let completionHandler = { (context: MLUpdateContext) in
print("Training completed with state \(context.task.state.rawValue)")
// NOTE(review): this prints even on success (it logs "nil" then).
print("CoreML Error: \(context.task.error.debugDescription)")
if context.task.state != .completed {
print("Failed")
return
}
let trainLoss = context.metrics[.lossValue] as! Double
print("Final loss: \(trainLoss)")
let updatedModel = context.model
let updatedModelURL = URL(fileURLWithPath: retrainedCoreMLFilePath)
// try! crashes the process if the write fails (bad path, permissions, ...).
try! updatedModel.write(to: updatedModelURL)
print("Model Trained!")
// The script's main flow blocks on readLine(); this prompt tells the user
// training is done and they can press return to continue.
print("Press return to continue..")
}
let handlers = MLUpdateProgressHandlers(
forEvents: [.trainingBegin, .miniBatchEnd, .epochEnd],
progressHandler: progressHandler,
completionHandler: completionHandler)
// try! crashes if the model at `url` cannot be loaded or is not updatable.
let updateTask = try! MLUpdateTask(forModelAt: url,
trainingData: prepareTrainingBatch(),
configuration: configuration,
progressHandlers: handlers)
updateTask.resume()
}
// Script entry flow: kick off asynchronous training, then block until the
// user confirms training has finished.
// NOTE(review): `compiledModelUrl` and `inferenceCoreML` are defined elsewhere
// in the original file (not visible in this chunk).
train(url: compiledModelUrl)
// Block on stdin so the process stays alive while the asynchronous training
// task runs; the user presses return after "Model Trained!" is printed.
let _ = readLine()
// Load the retrained model back from disk and run a single inference at x = 1.0
// (the true line is y = 2x + 1.5, so the prediction should be near 3.5).
let retrainedModel = try! MLModel(contentsOf: URL(fileURLWithPath: retrainedCoreMLFilePath))
let prediction = inferenceCoreML(model: retrainedModel, x: 1.0)
print(prediction)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment