Created
August 22, 2023 20:06
-
-
Save dnutiu/007c25bd42d6f3037078ed88aee41537 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Foundation | |
import Vision | |
print("Hello, World!")

// Image to classify. The first command-line argument, when present, overrides the
// default path so the tool can be pointed at other images without recompiling.
let defaultImagePath = "/Users/dnutiu/Pictures/Jul 15 2023/540A6766.jpg"
let imagePath = CommandLine.arguments.dropFirst().first ?? defaultImagePath

// Load the `resnet151` Core ML model class (generated/declared outside this file).
// A top-level `try` terminates the script with the thrown error if loading fails.
let config = MLModelConfiguration()
let imageTagger = try resnet151(configuration: config)
// Get the underlying model instance.
let imageTaggerModel = imageTagger.model
// Create a Vision instance using the image classifier's model instance.
// The creation error is reported rather than discarded (previously `try?`), so the
// failure message says *why* the model could not be wrapped.
let imageClassifierVisionModel: VNCoreMLModel = {
    do {
        return try VNCoreMLModel(for: imageTaggerModel)
    } catch {
        fatalError("App failed to create a `VNCoreMLModel` instance: \(error)")
    }
}()
/// Writes each classification label to standard output, one per line,
/// with an empty line before and after the list.
///
/// - Parameter results: `(identifier, confidence)` pairs; only the identifier is printed.
func printResultsInConsole(results: [(String, Float)]) {
    print()
    for (label, _) in results {
        print(label)
    }
    print()
}
/// Completion handler for vision requests.
///
/// Collects every `VNClassificationObservation` in `request.results` whose confidence
/// is above the threshold and prints the labels via `printResultsInConsole(results:)`.
/// Observations of any other type are skipped instead of crashing the app.
///
/// The signature must stay `(VNRequest, Error?) -> Void` because this function is
/// passed directly as a `VNCoreMLRequest` completion handler.
///
/// - Parameters:
///   - request: The finished Vision request whose `results` are read.
///   - error: The error Vision reported for the request, if any.
func visionRequestHandler(request: VNRequest, error: Error?) {
    guard let observations = request.results else {
        // Surface the underlying Vision error instead of discarding it.
        let reason = error.map { String(describing: $0) } ?? "no error reported"
        fatalError("Unable to read observations: \(reason)")
    }

    // Keep observations whose confidence is greater than this value.
    // NOTE(review): Vision classification confidences are normally in [0, 1], so -0.5
    // keeps every observation — confirm the intended cut-off (0.5?).
    let minimumConfidence: Float = -0.5

    // Format results by looking for the `VNClassificationObservation` type; other
    // observation types are dropped by the conditional cast.
    let results = observations
        .compactMap { $0 as? VNClassificationObservation }
        .filter { $0.confidence > minimumConfidence }
        .map { ($0.identifier, $0.confidence) }

    printResultsInConsole(results: results)
}
// Create an image classification request whose results are handled by `visionRequestHandler`.
let imageClassificationRequest = VNCoreMLRequest(model: imageClassifierVisionModel, completionHandler: visionRequestHandler)
imageClassificationRequest.imageCropAndScaleOption = .centerCrop
// Create a request handler for the image file on disk.
let imageRequestHandler = VNImageRequestHandler(url: URL(fileURLWithPath: imagePath))
// Perform the classification request.
// NOTE(review): the script relies on `perform(_:)` finishing the requests (and running the
// completion handler) before it returns, as Apple documents; the program ends right after.
// Failures (missing/unreadable image, model errors) were previously swallowed by `try?`,
// making the script exit successfully with no output; report them and exit non-zero instead.
do {
    try imageRequestHandler.perform([imageClassificationRequest])
} catch {
    FileHandle.standardError.write(Data("Image classification failed for \(imagePath): \(error)\n".utf8))
    exit(EXIT_FAILURE)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment