Skip to content

Instantly share code, notes, and snippets.

@lqdev
Created December 14, 2020 01:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lqdev/0c4dc9eea93b7b8541f31ddd429afb53 to your computer and use it in GitHub Desktop.
Code from F# Advent 2020 Image Classifier
#r "nuget:Microsoft.ML"
#r "nuget:Microsoft.ML.Vision"
#r "nuget:Microsoft.ML.ImageAnalytics"
#r "nuget:SciSharp.TensorFlow.Redist"
open System
open System.IO
open Microsoft.ML
open Microsoft.ML.Data
open Microsoft.ML.Vision
/// One input row for the pipeline: the path of an image file plus its class label.
/// `Label` is the class name that the pipeline key-encodes into "LabelKey" for training.
/// CLIMutable compiles the record with a default constructor and property setters,
/// which ML.NET's row mapping presumably relies on (LoadFromEnumerable / CreateEnumerable).
[<CLIMutable>]
type ImageData =
    { ImagePath: string
      Label: string }
/// One output row read back from the scored data: the image that was classified and
/// the predicted class name. The name is a string because the pipeline maps the
/// key-typed "PredictedLabel" column back to its original value.
/// CLIMutable gives the record the default constructor/setters used when ML.NET
/// materializes rows (presumably required by CreateEnumerable).
[<CLIMutable>]
type ImagePrediction =
    { ImagePath: string
      PredictedLabel: string }
/// Load every .jpg/.png file under `path` (searched recursively) as an ImageData row.
///
/// path: root folder to scan.
/// useDirectoryAsLabel: when true, an image's label is the name of the folder that
///   directly contains it; when false, the label is the leading run of letters in
///   the file name (e.g. "cat123.jpg" -> "cat", "7cat.jpg" -> "").
///
/// Returns an ImageData array. Extension matching ignores case, so ".JPG"/".PNG"
/// files are included as well.
let loadImagesFromDirectory (path:string) (useDirectoryAsLabel:bool) =
    // Case-insensitive so images saved with upper-case extensions are not silently skipped.
    let isImage (file: string) =
        let ext = Path.GetExtension(file)
        String.Equals(ext, ".jpg", StringComparison.OrdinalIgnoreCase)
        || String.Equals(ext, ".png", StringComparison.OrdinalIgnoreCase)

    let labelOf (file: string) =
        if useDirectoryAsLabel then
            Directory.GetParent(file).Name
        else
            // Keep the letters before the first non-letter character. A name made only
            // of letters has no cut point, so the whole name is kept (no out-of-range index).
            let name = Path.GetFileName(file)
            let cut =
                name
                |> Seq.tryFindIndex (fun c -> not (Char.IsLetter c))
                |> Option.defaultValue name.Length
            name.Substring(0, cut)

    Directory.GetFiles(path, "*", searchOption = SearchOption.AllDirectories)
    |> Array.filter isImage
    |> Array.map (fun file -> ({ ImagePath = file; Label = labelOf file } : ImageData))
// Shared ML.NET context: used below for data loading, transforms, the trainer and model I/O.
let ctx = MLContext()

// Load the training images. Passing `true` labels each image with the name of its
// parent folder (assumes the dataset has one sub-folder per class — confirm).
let imageData =
    let trainDir = "C:/Datasets/fsadvent2020/Train"
    loadImagesFromDirectory trainDir true

// Create an IDataView over the in-memory training rows.
let imageIdv = ctx.Data.LoadFromEnumerable<ImageData>(imageData)
// Image classifier settings:
//  - features come from the "Image" column and the target from "LabelKey"
//    (both produced by the pipeline's LoadRawImageBytes / MapValueToKey stages);
//  - TestOnTrainSet is enabled and the ResNet v2 101 architecture is used;
//  - every metrics report from the trainer is printed to stdout.
let classifierOptions =
    ImageClassificationTrainer.Options(
        FeatureColumnName = "Image",
        LabelColumnName = "LabelKey",
        TestOnTrainSet = true,
        Arch = ImageClassificationTrainer.Architecture.ResnetV2101,
        MetricsCallback =
            Action<ImageClassificationTrainer.ImageClassificationMetrics>(fun metrics ->
                printfn "%s" (metrics.ToString())))
// Training / consumption pipeline, in order:
//  1. read the raw bytes of the file named in "ImagePath" into an "Image" column;
//  2. map the string "Label" to a key-typed "LabelKey" column for the trainer;
//  3. train the image classifier configured by classifierOptions;
//  4. map the key-typed "PredictedLabel" back to its original string value.
let pipeline =
    let readImageBytes = ctx.Transforms.LoadRawImageBytes("Image", null, "ImagePath")
    let encodeLabel = ctx.Transforms.Conversion.MapValueToKey("LabelKey", "Label")
    let classifier = ctx.MulticlassClassification.Trainers.ImageClassification(classifierOptions)
    let decodePrediction = ctx.Transforms.Conversion.MapKeyToValue("PredictedLabel")
    EstimatorChain()
        .Append(readImageBytes)
        .Append(encodeLabel)
        .Append(classifier)
        .Append(decodePrediction)
// Train: fit every pipeline stage on the training rows.
let model = imageIdv |> pipeline.Fit

// (Optional) Persist the trained model together with the input schema it was trained on.
ctx.Model.Save(model, inputSchema = imageIdv.Schema, filePath = "fsadvent2020-model.zip")

// To reuse a saved model instead of retraining, load it back (yields the model and its input schema):
//let (model,schema) = ctx.Model.Load("fsadvent2020-model.zip")
// Load test images. Keep only .jpg/.png files (case-insensitive), as the training
// loader does, so stray files in the folder (e.g. Thumbs.db) never reach the image
// decoder. Test images have no known label; Label is set to "" because the pipeline's
// MapValueToKey stage reads the Label column, so it must exist.
let testImages =
    Directory.GetFiles("C:/Datasets/fsadvent2020/Test")
    |> Array.filter (fun file ->
        let ext = Path.GetExtension(file)
        String.Equals(ext, ".jpg", StringComparison.OrdinalIgnoreCase)
        || String.Equals(ext, ".png", StringComparison.OrdinalIgnoreCase))
    |> Array.map (fun file -> ({ ImagePath = file; Label = "" } : ImageData))

// Create IDataView for test images
let testImageIdv = ctx.Data.LoadFromEnumerable<ImageData>(testImages)

// Make predictions by running the trained pipeline over the test rows.
let predictionIdv = model.Transform(testImageIdv)

// Display predictions. reuseRowObject = false gives every row its own record instance.
let predictions = ctx.Data.CreateEnumerable<ImagePrediction>(predictionIdv, reuseRowObject = false)
predictions
|> Seq.iter (fun pred ->
    printfn "%s is %s" (Path.GetFileNameWithoutExtension(pred.ImagePath)) pred.PredictedLabel)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment