Skip to content

Instantly share code, notes, and snippets.

@elbruno
Created September 5, 2018 21:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save elbruno/7034f78c6eaceafc5936c139830841ee to your computer and use it in GitHub Desktop.
Save elbruno/7034f78c6eaceafc5936c139830841ee to your computer and use it in GitHub Desktop.
MLNetTensorFlow.cs
using System;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
namespace MLNetConsole10
{
/// <summary>
/// Console sample that trains a multiclass classifier on top of the output of a
/// frozen TensorFlow model and then scores a single image.
/// </summary>
class Program
{
    // Target size the images are resized to before pixel extraction.
    private const int InputImageHeight = 32;
    private const int InputImageWidth = 32;

    // Frozen TensorFlow graph used as the featurizer.
    private const string ModelLocation = "cifar_model/frozen_model.pb";

    // Tab-separated training file; its directory doubles as the image folder.
    private const string DataFile = "images/images.tsv";

    // Prefixes printed in front of the first score-label slots. They are fixed
    // text: the order of the labels the model actually learned may differ.
    private static readonly string[] ExpectedLabels = { "banana", "hotdog", "tomato" };

    /// <summary>
    /// Trains the model, prints the score label names (when available) and then
    /// prints the scores predicted for a sample image.
    /// </summary>
    /// <param name="args">Command-line arguments; not used.</param>
    static void Main(string[] args)
    {
        var model = BuildPipeline().Train<CifarData, CifarPrediction>();

        // TryGetScoreLabelNames follows the Try-pattern: only touch the out
        // value when the call reports success.
        if (model.TryGetScoreLabelNames(out var scoreLabels) && scoreLabels != null)
        {
            Console.WriteLine($"ScoreLabels.Length {scoreLabels.Length}");

            // Never index past what the model really produced.
            var labelCount = Math.Min(ExpectedLabels.Length, scoreLabels.Length);
            for (var i = 0; i < labelCount; i++)
            {
                Console.WriteLine($"{ExpectedLabels[i]} {scoreLabels[i]}");
            }
        }
        else
        {
            Console.Error.WriteLine("Score label names are not available for this model.");
        }

        // NOTE(review): ImageFolder is "images" (the directory of DataFile) and this
        // path also starts with "images/" - confirm how the image loader combines
        // the two before relying on this sample path.
        var prediction = model.Predict(new CifarData()
        {
            ImagePath = "images/banana.jpg"
        });

        // Print at most as many scores as the original three fixed slots,
        // without assuming the score vector has that many entries.
        var scores = prediction.PredictedLabels;
        var scoreCount = scores == null ? 0 : Math.Min(ExpectedLabels.Length, scores.Length);
        for (var i = 0; i < scoreCount; i++)
        {
            Console.WriteLine($"{scores[i]}");
        }
    }

    /// <summary>
    /// Assembles the learning pipeline: text loader, image loading, resizing,
    /// pixel extraction, TensorFlow scoring, feature concatenation, label
    /// conversion and the SDCA classifier, in that order.
    /// </summary>
    /// <returns>The configured, not yet trained, pipeline.</returns>
    private static LearningPipeline BuildPipeline()
    {
        var imageFolder = Path.GetDirectoryName(DataFile);
        var pipeline = new LearningPipeline();

        // Rows are read from the data file using the CifarData schema.
        pipeline.Add(new Microsoft.ML.Data.TextLoader(DataFile).CreateFrom<CifarData>());

        // The column pairs below chain ImagePath -> ImageReal -> ImageCropped -> Input.
        pipeline.Add(new ImageLoader(("ImagePath", "ImageReal"))
        {
            ImageFolder = imageFolder
        });
        pipeline.Add(new ImageResizer(("ImageReal", "ImageCropped"))
        {
            ImageHeight = InputImageHeight,
            ImageWidth = InputImageWidth,
            Resizing = ImageResizerTransformResizingKind.IsoCrop
        });
        pipeline.Add(new ImagePixelExtractor(("ImageCropped", "Input"))
        {
            UseAlpha = false,
            InterleaveArgb = true
        });

        // The TensorFlow model consumes "Input" and writes "Output", which is
        // then exposed to the trainer as "Features".
        pipeline.Add(new TensorFlowScorer()
        {
            ModelFile = ModelLocation,
            InputColumns = new[] { "Input" },
            OutputColumn = "Output"
        });
        pipeline.Add(new ColumnConcatenator("Features", "Output"));
        pipeline.Add(new TextToKeyConverter("Label"));
        pipeline.Add(new StochasticDualCoordinateAscentClassifier());

        return pipeline;
    }
}
/// <summary>
/// Input row schema used both for loading the training file and as the
/// argument passed to <c>Predict</c>.
/// </summary>
public class CifarData
{
/// <summary>Path of the image; bound to column "0" of the input.</summary>
[Column("0")]
public string ImagePath;
/// <summary>Text label of the image; bound to column "1" of the input.</summary>
[Column("1")]
public string Label;
}
/// <summary>
/// Prediction output schema returned by the trained model.
/// </summary>
public class CifarPrediction
{
/// <summary>
/// Values of the "Score" column. Despite the field name these are floats,
/// presumably one score per class - confirm against the trainer's output.
/// </summary>
[ColumnName("Score")]
public float[] PredictedLabels;
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment