Skip to content

Instantly share code, notes, and snippets.

@elbruno
Created September 1, 2020 14:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save elbruno/19245363d81fd08db14b6f63d150842d to your computer and use it in GitHub Desktop.
MLNetAutoMLRanking.cs
using System;
using System.IO;
using System.Linq;
using System.Net.Http.Headers;
using Microsoft.ML;
using Microsoft.ML.AutoML;
using Microsoft.ML.Data;
namespace ConsoleApp1
{
/// <summary>
/// Console sample that runs an ML.NET AutoML ranking experiment, lists the models it
/// produced, evaluates the best one, saves it to disk and scores two hand-built rows.
/// </summary>
public class Program
{
    // Built with Path.Combine so the relative paths also resolve on Linux/macOS, where a
    // literal "data\train.txt" would be treated as a single file name.
    private static readonly string TrainDataPath = Path.Combine("data", "train.txt");
    private static readonly string TestDataPath = Path.Combine("data", "test.txt");
    private const string ModelPath = "Model.zip";
    private const string LabelColumnName = "Label";
    private const string GroupColumnName = "GroupId";

    // Upper bound, in seconds, for the AutoML search (RankingExperimentSettings.MaxExperimentTimeInSeconds).
    private const uint ExperimentTime = 600;

    static void Main(string[] args)
    {
        Console.WriteLine("Start ...");
        Run();
        Console.WriteLine("End");
    }

    /// <summary>
    /// Loads the train/test files, runs the AutoML ranking experiment, prints a summary of
    /// every run, evaluates and saves the best model, then scores two sample rows.
    /// </summary>
    public static void Run()
    {
        var mlContext = new MLContext();

        // STEP 1: Load data (tab-separated, no header row).
        var trainDataView = mlContext.Data.LoadFromTextFile<SearchData>(TrainDataPath, hasHeader: false, separatorChar: '\t');
        var testDataView = mlContext.Data.LoadFromTextFile<SearchData>(TestDataPath, hasHeader: false, separatorChar: '\t');

        // STEP 2: Run AutoML experiment
        Console.WriteLine($"Running AutoML ranking experiment for {ExperimentTime} seconds...");
        var experimentResult = mlContext.Auto()
            .CreateRankingExperiment(new RankingExperimentSettings { MaxExperimentTimeInSeconds = ExperimentTime })
            .Execute(trainDataView, testDataView,
                new ColumnInformation
                {
                    LabelColumnName = LabelColumnName,
                    GroupIdColumnName = GroupColumnName
                });

        // STEP 3: List every model AutoML produced.
        // RunDetails is an IEnumerable; materialize it once rather than enumerating it twice.
        var runDetails = experimentResult.RunDetails.ToList();
        Console.WriteLine("=====================================================");
        Console.WriteLine($"Total models produced: {runDetails.Count}");
        for (var i = 0; i < runDetails.Count; i++)
        {
            var runDetail = runDetails[i];
            Console.WriteLine($" {i + 1} - TrainerName: {runDetail.TrainerName}");
            Console.WriteLine($" Runtime In Seconds: {runDetail.RuntimeInSeconds}");
            Console.WriteLine();
        }
        Console.WriteLine();
        Console.WriteLine("=====================================================");

        var bestRun = experimentResult.BestRun;
        Console.WriteLine($"Best model's trainer: {bestRun.TrainerName}");

        // STEP 4: Evaluate test data.
        // NOTE(review): testDataView was also passed to Execute as the validation set used to
        // choose BestRun, so these metrics are likely optimistic; consider a separate holdout file.
        var testDataViewWithBestScore = bestRun.Model.Transform(testDataView);
        var testMetrics = mlContext.Ranking.Evaluate(
            testDataViewWithBestScore,
            labelColumnName: LabelColumnName,
            rowGroupColumnName: GroupColumnName);
        Console.WriteLine("Metrics of best model on test data --");
        PrintMetrics(testMetrics);

        // STEP 5: Save the best model for later deployment and inferencing.
        mlContext.Model.Save(bestRun.Model, trainDataView.Schema, ModelPath);

        // STEP 6: Create a prediction engine from the best trained model (disposed when done).
        using (var predictionEngine = mlContext.Model.CreatePredictionEngine<SearchData, SearchDataPrediction>(bestRun.Model))
        {
            // STEP 7: Score a couple of hand-built rows.
            var samples = new[]
            {
                new SearchData { GroupId = "1", Features = 9, Label = 1 },
                new SearchData { GroupId = "2", Features = 2, Label = 9 },
            };

            foreach (var sample in samples)
            {
                var prediction = predictionEngine.Predict(sample);
                // Ranking models output a relevance Score (there is no PredictedLabel column).
                Console.WriteLine($"Predicted score for group {sample.GroupId}: {prediction.Score}");
            }
        }

        // Console.ReadKey throws InvalidOperationException when stdin is redirected (CI, pipes).
        if (!Console.IsInputRedirected)
        {
            Console.WriteLine("Press any key to continue...");
            Console.ReadKey();
        }
    }

    /// <summary>
    /// Prints the NDCG and DCG values of <paramref name="metrics"/>, or " No metrics" when null.
    /// </summary>
    /// <param name="metrics">Ranking metrics to print; may be null.</param>
    private static void PrintMetrics(RankingMetrics metrics)
    {
        if (metrics is null)
        {
            Console.WriteLine(" No metrics");
            return;
        }

        // string.Join avoids the trailing " - " separator the previous Aggregate produced.
        var ndcg = string.Join(" - ", metrics.NormalizedDiscountedCumulativeGains);
        var dcg = string.Join(" - ", metrics.DiscountedCumulativeGains);
        Console.WriteLine($" Normalized Discounted Cumulative Gains: {ndcg}");
        Console.WriteLine($" Discounted Cumulative Gains: {dcg}");
    }
}
/// <summary>
/// Input row for the ranking model; members are bound to source columns by index
/// through <c>LoadColumn</c>.
/// </summary>
internal class SearchData
{
    /// <summary>Group identifier read from column 0.</summary>
    [LoadColumn(0)] public string GroupId;

    /// <summary>Single numeric feature read from column 1.</summary>
    [LoadColumn(1)] public float Features;

    /// <summary>Relevance label read from column 2.</summary>
    [LoadColumn(2)] public float Label;
}
/// <summary>
/// Output row read back from the trained ranking model by the prediction engine.
/// </summary>
class SearchDataPrediction
{
    /// <summary>
    /// Relevance score produced by the ranking model (bound to the "Score" output column).
    /// </summary>
    public float Score { get; set; }

    /// <summary>
    /// Backward-compatible alias for <see cref="Score"/>. Ranking trainers emit a "Score"
    /// column but no "PredictedLabel" column, so the former <c>[ColumnName("PredictedLabel")]</c>
    /// mapping made prediction-engine creation fail on a missing column. <c>[NoColumn]</c>
    /// keeps this member out of the ML.NET schema.
    /// </summary>
    [NoColumn]
    public float Prediction
    {
        get => Score;
        set => Score = value;
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment