Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
using Microsoft.ML;
using Microsoft.ML.Training;
using System;
using System.Collections.Generic;
using System.Linq;
using WineQualityClassification.Helpers;
using WineQualityClassification.ModelNamespace;
namespace WineQualityClassification
{
class Program
{
private static MLContext _mlContext = new MLContext();
private static Dictionary<Model, double> _stats = new Dictionary<Model, double>();
private static string _trainingDataLocation = @"Data/winequality_white_train.csv";
private static string _testDataLocation = @"Data/winequality_white_test.csv";
static void Main(string[] args)
{
var classifiers = new List<ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictor>, IPredictor>>()
{
_mlContext.MulticlassClassification.Trainers.LogisticRegression(labelColumn: "Label", featureColumn: "Features"),
_mlContext.MulticlassClassification.Trainers.NaiveBayes(labelColumn: "Label", featureColumn: "Features"),
_mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent(labelColumn: "Label", featureColumn: "Features")
};
classifiers.ForEach(RunAlgorythm);
var bestModel = _stats.Where(x => x.Value == _stats.Max(y => y.Value)).Single().Key;
VisualizeTenPredictionsForTheModel(bestModel);
bestModel.SaveModel();
Console.ReadLine();
}
private static void RunAlgorythm(ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictor>, IPredictor> algorythm)
{
var model = new Model(_mlContext, algorythm);
model.BuildAndFit(_trainingDataLocation);
PrintAndStoreMetrics(model);
}
private static void PrintAndStoreMetrics(Model model)
{
var metrics = model.Evaluate(_testDataLocation);
Console.WriteLine($"*************************************************");
Console.WriteLine($"* Metrics for {model.Name} ");
Console.WriteLine($"*------------------------------------------------");
Console.WriteLine($"* R2 Score: {metrics.AccuracyMacro:0.##}");
Console.WriteLine($"* R2 Score: {metrics.AccuracyMicro:0.##}");
Console.WriteLine($"* Absolute loss: {metrics.LogLoss:#.##}");
Console.WriteLine($"* Squared loss: {metrics.LogLossReduction:#.##}");
Console.WriteLine($"* RMS loss: {metrics.LogLossReduction:#.##}");
Console.WriteLine($"*************************************************");
_stats.Add(model, metrics.AccuracyMacro);
}
private static void VisualizeTenPredictionsForTheModel(Model model)
{
Console.WriteLine($"*************************************************");
Console.WriteLine($"* BEST MODEL IS: {model.Name}!");
Console.WriteLine($"* Here are its predictions: ");
var testData = new WineQualityCsvReader().GetDataFromCsv(_testDataLocation).ToList();
for (int i = 0; i < 10; i++)
{
var prediction = model.Predict(testData[i]);
Console.WriteLine($"*------------------------------------------------");
Console.WriteLine($"* Predicted : {prediction.PredictedLabel}");
Console.WriteLine($"* Actual: {testData[i].Label}");
Console.WriteLine($"*------------------------------------------------");
}
Console.WriteLine($"*************************************************");
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment