Created
September 5, 2018 21:30
-
-
Save elbruno/7034f78c6eaceafc5936c139830841ee to your computer and use it in GitHub Desktop.
MLNetTensorFlow.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.IO; | |
using Microsoft.ML; | |
using Microsoft.ML.Runtime.Api; | |
using Microsoft.ML.Trainers; | |
using Microsoft.ML.Transforms; | |
namespace MLNetConsole10 | |
{ | |
class Program
{
    /// <summary>
    /// Builds an ML.NET learning pipeline that loads the images listed in images/images.tsv,
    /// resizes them to 32x32, extracts their pixels into the "Input" column, runs them through a
    /// frozen TensorFlow model, and trains an SDCA classifier on the model's "Output" column.
    /// Then scores one sample image and prints every class label next to its score.
    /// </summary>
    /// <param name="args">Unused.</param>
    static void Main(string[] args)
    {
        const int imageHeight = 32;
        const int imageWidth = 32;
        const string modelLocation = "cifar_model/frozen_model.pb";
        const string dataFile = "images/images.tsv";
        // Images are loaded from the directory that contains the data file ("images").
        var imageFolder = Path.GetDirectoryName(dataFile);

        var pipeline = new LearningPipeline();
        pipeline.Add(new Microsoft.ML.Data.TextLoader(dataFile).CreateFrom<CifarData>());
        pipeline.Add(new ImageLoader(("ImagePath", "ImageReal"))
        {
            ImageFolder = imageFolder
        });
        pipeline.Add(new ImageResizer(("ImageReal", "ImageCropped"))
        {
            ImageHeight = imageHeight,
            ImageWidth = imageWidth,
            Resizing = ImageResizerTransformResizingKind.IsoCrop
        });
        pipeline.Add(new ImagePixelExtractor(("ImageCropped", "Input"))
        {
            UseAlpha = false,
            InterleaveArgb = true
        });
        pipeline.Add(new TensorFlowScorer()
        {
            ModelFile = modelLocation,
            InputColumns = new[] { "Input" },
            OutputColumn = "Output"
        });
        pipeline.Add(new ColumnConcatenator("Features", "Output"));
        pipeline.Add(new TextToKeyConverter("Label"));
        pipeline.Add(new StochasticDualCoordinateAscentClassifier());

        var model = pipeline.Train<CifarData, CifarPrediction>();

        // TryGetScoreLabelNames reports failure through its return value. Ignoring it would leave
        // scoreLabels null and crash on the first dereference below.
        if (!model.TryGetScoreLabelNames(out var scoreLabels) || scoreLabels == null)
        {
            Console.Error.WriteLine("The trained model does not expose score label names.");
            return;
        }

        // Print the labels in slot order instead of assuming exactly three classes in a fixed
        // banana/hotdog/tomato order.
        Console.WriteLine($"ScoreLabels.Length {scoreLabels.Length}");
        for (var i = 0; i < scoreLabels.Length; i++)
        {
            Console.WriteLine($"Label {i}: {scoreLabels[i]}");
        }

        // NOTE: the ImageLoader stage is configured with ImageFolder, which ML.NET prepends to
        // relative image paths. A relative "images/banana.jpg" would therefore be looked up as
        // images/images/banana.jpg. Path.Combine leaves an absolute path unchanged, so passing
        // the full path resolves to images/banana.jpg either way.
        var prediction = model.Predict(new CifarData()
        {
            ImagePath = Path.GetFullPath(Path.Combine(imageFolder, "banana.jpg"))
        });

        // Score slot i belongs to scoreLabels[i]. Pair them, and never index past either array.
        var scores = prediction.PredictedLabels ?? Array.Empty<float>();
        var count = Math.Min(scoreLabels.Length, scores.Length);
        for (var i = 0; i < count; i++)
        {
            Console.WriteLine($"{scoreLabels[i]}: {scores[i]}");
        }
    }
}
/// <summary>
/// Input row schema. It binds column 0 to the image path and column 1 to the text label, and is
/// used both to load the data file and as the input for single predictions.
/// </summary>
public class CifarData
{
    /// <summary>Path of the image to load (column 0).</summary>
    [Column("0")] public string ImagePath;

    /// <summary>Class label as text (column 1).</summary>
    [Column("1")] public string Label;
}
/// <summary>
/// Prediction output schema. It binds the model's "Score" column to a float array.
/// </summary>
public class CifarPrediction
{
    /// <summary>
    /// Contents of the "Score" column: one value per slot, in the same slot order as the label
    /// names reported for that column.
    /// </summary>
    [ColumnName("Score")] public float[] PredictedLabels;
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment