// MLNetTensorFlow.cs
using System; | |
using System.IO; | |
using Microsoft.ML; | |
using Microsoft.ML.Runtime.Api; | |
using Microsoft.ML.Trainers; | |
using Microsoft.ML.Transforms; | |
namespace MLNetConsole10
{
    /// <summary>
    /// Console demo: runs images through a frozen TensorFlow CIFAR model, trains an
    /// SDCA classifier on the model's output, and scores a sample image.
    /// </summary>
    class Program
    {
        // Target size for the resize step. NOTE(review): presumably matches the input size
        // the frozen CIFAR graph expects — confirm against the model.
        private const int InputImageHeight = 32;
        private const int InputImageWidth = 32;

        private const string ModelLocation = "cifar_model/frozen_model.pb";
        private const string DataFile = "images/images.tsv";

        // Sample image scored after training; it sits in the same folder as DataFile.
        private const string SampleImageFileName = "banana.jpg";

        static void Main(string[] args)
        {
            // ImageLoader is pointed at the data file's folder, so image paths are
            // resolved against it (presumably by prefixing this folder).
            string imageFolder = Path.GetDirectoryName(DataFile);

            LearningPipeline pipeline = BuildPipeline(imageFolder);
            var model = pipeline.Train<CifarData, CifarPrediction>();

            // The Try* call reports failure through its return value; on failure the out
            // array cannot be relied on, so fall back to index-only output instead of crashing.
            if (!model.TryGetScoreLabelNames(out var scoreLabels) || scoreLabels == null)
            {
                Console.WriteLine("Score label names are unavailable; scores are listed by index.");
                scoreLabels = Array.Empty<string>();
            }

            Console.WriteLine($"ScoreLabels.Length {scoreLabels.Length}");
            for (int i = 0; i < scoreLabels.Length; i++)
            {
                Console.WriteLine($"[{i}] {scoreLabels[i]}");
            }

            var prediction = model.Predict(new CifarData()
            {
                // Pass an absolute path. If ImageLoader combines ImageFolder with ImagePath
                // (presumably via Path.Combine), a rooted path survives unchanged. A relative
                // "images/banana.jpg" would instead become "images/images/banana.jpg".
                ImagePath = Path.GetFullPath(Path.Combine(imageFolder, SampleImageFileName))
            });

            PrintScores(scoreLabels, prediction.PredictedLabels);
        }

        /// <summary>
        /// Builds the pipeline: load rows, load/resize images, extract pixels,
        /// score with TensorFlow, then train an SDCA classifier on the TensorFlow output.
        /// </summary>
        /// <param name="imageFolder">Folder that ImageLoader resolves image paths against.</param>
        /// <returns>The untrained pipeline.</returns>
        private static LearningPipeline BuildPipeline(string imageFolder)
        {
            var pipeline = new LearningPipeline();

            // Rows are mapped onto CifarData via its [Column] attributes.
            pipeline.Add(new Microsoft.ML.Data.TextLoader(DataFile).CreateFrom<CifarData>());

            // ImagePath -> ImageReal: load the image files.
            pipeline.Add(new ImageLoader(("ImagePath", "ImageReal"))
            {
                ImageFolder = imageFolder
            });

            // ImageReal -> ImageCropped: resize with IsoCrop to the target dimensions.
            pipeline.Add(new ImageResizer(("ImageReal", "ImageCropped"))
            {
                ImageHeight = InputImageHeight,
                ImageWidth = InputImageWidth,
                Resizing = ImageResizerTransformResizingKind.IsoCrop
            });

            // ImageCropped -> Input: pixel values without the alpha channel, interleaved.
            pipeline.Add(new ImagePixelExtractor(("ImageCropped", "Input"))
            {
                UseAlpha = false,
                InterleaveArgb = true
            });

            // Input -> Output: run the frozen TensorFlow graph.
            pipeline.Add(new TensorFlowScorer()
            {
                ModelFile = ModelLocation,
                InputColumns = new[] { "Input" },
                OutputColumn = "Output"
            });

            // Features is built from the TensorFlow output (presumably the default feature
            // column the trainer reads — confirm for your ML.NET version).
            pipeline.Add(new ColumnConcatenator("Features", "Output"));

            // Map the text Label column to keys for the classifier.
            pipeline.Add(new TextToKeyConverter("Label"));
            pipeline.Add(new StochasticDualCoordinateAscentClassifier());

            return pipeline;
        }

        /// <summary>
        /// Prints one line per score, labelled with the matching label name when one exists
        /// and with "#index" otherwise.
        /// NOTE(review): assumes scores[i] corresponds to labels[i] — confirm the ordering
        /// reported by TryGetScoreLabelNames matches the Score column.
        /// </summary>
        /// <param name="labels">Label names; may be shorter than <paramref name="scores"/>.</param>
        /// <param name="scores">Per-class scores; null prints a notice instead.</param>
        private static void PrintScores(string[] labels, float[] scores)
        {
            if (scores == null)
            {
                Console.WriteLine("The model produced no scores.");
                return;
            }

            for (int i = 0; i < scores.Length; i++)
            {
                string label = i < labels.Length ? labels[i] : $"#{i}";
                Console.WriteLine($"{label}: {scores[i]}");
            }
        }
    }

    /// <summary>
    /// One input row of the data file: column 0 holds the image path, column 1 the label.
    /// </summary>
    public class CifarData
    {
        /// <summary>Image path (data-file column 0); consumed by the ImageLoader stage.</summary>
        [Column("0")]
        public string ImagePath;

        /// <summary>Label text (data-file column 1); converted to keys before training.</summary>
        [Column("1")]
        public string Label;
    }

    /// <summary>
    /// Prediction output of the trained model.
    /// </summary>
    public class CifarPrediction
    {
        /// <summary>
        /// Values of the "Score" column. Despite the field name, these are per-class scores,
        /// not label names.
        /// </summary>
        [ColumnName("Score")]
        public float[] PredictedLabels;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment