Created
September 1, 2020 14:30
-
-
Save elbruno/19245363d81fd08db14b6f63d150842d to your computer and use it in GitHub Desktop.
MLNetAutoMLRanking.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Linq; | |
using System.Net.Http.Headers; | |
using Microsoft.ML; | |
using Microsoft.ML.AutoML; | |
using Microsoft.ML.Data; | |
namespace ConsoleApp1 | |
{ | |
/// <summary>
/// Console sample that runs an ML.NET AutoML ranking experiment, reports the
/// runs it produced, evaluates the best model, saves it, and scores two
/// hand-built samples.
/// </summary>
public class Program
{
    // Input files are read as tab-separated text without a header row.
    private const string TrainDataPath = @"data\train.txt";
    private const string TestDataPath = @"data\test.txt";

    // Destination of the serialized best model.
    private const string ModelPath = @"Model.zip";

    // Column names shared by the experiment configuration and the evaluation.
    private const string LabelColumnName = "Label";
    private const string GroupColumnName = "GroupId";

    // Time budget handed to AutoML, in seconds.
    private const uint ExperimentTime = 600;

    /// <summary>
    /// Entry point: brackets <see cref="Run"/> with start/end messages.
    /// </summary>
    /// <param name="args">Command-line arguments (unused).</param>
    private static void Main(string[] args)
    {
        Console.WriteLine("Start ...");
        Run();
        Console.WriteLine("End");
    }

    /// <summary>
    /// Executes the whole sample: load data, run the AutoML ranking experiment,
    /// print a summary of every run, evaluate and save the best model, then
    /// score two in-memory samples and wait for a key press.
    /// </summary>
    public static void Run()
    {
        var mlContext = new MLContext();

        // STEP 1: Load data
        var trainDataView = mlContext.Data.LoadFromTextFile<SearchData>(TrainDataPath, hasHeader: false, separatorChar: '\t');
        var testDataView = mlContext.Data.LoadFromTextFile<SearchData>(TestDataPath, hasHeader: false, separatorChar: '\t');

        // STEP 2: Run AutoML experiment
        Console.WriteLine($"Running AutoML ranking experiment for {ExperimentTime} seconds...");
        var experimentSettings = new RankingExperimentSettings { MaxExperimentTimeInSeconds = ExperimentTime };
        var columnInformation = new ColumnInformation
        {
            LabelColumnName = LabelColumnName,
            GroupIdColumnName = GroupColumnName
        };
        // NOTE(review): the test set is also what the experiment receives as its
        // second data view, so the metrics in STEP 4 are not from unseen data.
        var experimentResult = mlContext.Auto()
            .CreateRankingExperiment(experimentSettings)
            .Execute(trainDataView, testDataView, columnInformation);

        // STEP 3: Print a summary of every run and the best model's trainer
        var bestRun = experimentResult.BestRun;
        // Materialize once so the run details are not enumerated twice.
        var runDetails = experimentResult.RunDetails.ToList();
        Console.WriteLine($"=====================================================");
        Console.WriteLine($"Total models produced: {runDetails.Count}");
        var runNumber = 0;
        foreach (var runDetail in runDetails)
        {
            runNumber++;
            Console.WriteLine($" {runNumber} - TrainerName: {runDetail.TrainerName}");
            Console.WriteLine($" Runtime In Seconds: {runDetail.RuntimeInSeconds}");
            Console.WriteLine($"");
        }
        Console.WriteLine($"");
        Console.WriteLine($"=====================================================");
        Console.WriteLine($"Best model's trainer: {bestRun.TrainerName}");

        // STEP 4: Evaluate test data
        var testDataViewWithBestScore = bestRun.Model.Transform(testDataView);
        // Pass the group column explicitly so evaluation uses the same column
        // configuration as the experiment instead of the API default.
        var testMetrics = mlContext.Ranking.Evaluate(
            testDataViewWithBestScore,
            labelColumnName: LabelColumnName,
            rowGroupColumnName: GroupColumnName);
        Console.WriteLine($"Metrics of best model on test data --");
        PrintMetrics(testMetrics);

        // STEP 5: Save the best model for later deployment and inferencing
        mlContext.Model.Save(bestRun.Model, trainDataView.Schema, ModelPath);

        // STEP 6: Create prediction engine from the best trained model.
        // The engine is disposable, so scope it with a using statement.
        using (var predictionEngine = mlContext.Model.CreatePredictionEngine<SearchData, SearchDataPrediction>(bestRun.Model))
        {
            // STEP 7: Build test samples and get their predictions.
            // The ranking score is reported; a ranking model's output is its
            // Score column rather than a predicted label.
            var testPage = new SearchData
            {
                GroupId = "1",
                Features = 9,
                Label = 1
            };
            var prediction = predictionEngine.Predict(testPage);
            Console.WriteLine($"Predicted rating for: {prediction.Score}");

            // New Page
            testPage = new SearchData
            {
                GroupId = "2",
                Features = 2,
                Label = 9
            };
            prediction = predictionEngine.Predict(testPage);
            Console.WriteLine($"Predicted: {prediction.Score}");
        }

        Console.WriteLine("Press any key to continue...");
        Console.ReadKey();
    }

    /// <summary>
    /// Writes the NDCG and DCG series of <paramref name="metrics"/> to the
    /// console, each value followed by " - ".
    /// </summary>
    /// <param name="metrics">Ranking metrics to print; when null, " No metrics" is printed instead.</param>
    private static void PrintMetrics(RankingMetrics metrics)
    {
        if (metrics is null)
        {
            Console.WriteLine($" No metrics");
            return;
        }

        var ndcg = string.Concat(metrics.NormalizedDiscountedCumulativeGains.Select(gain => gain + " - "));
        var dcg = string.Concat(metrics.DiscountedCumulativeGains.Select(gain => gain + " - "));
        Console.WriteLine($" Normalized Discounted Cumulative Gains: {ndcg}");
        Console.WriteLine($" Discounted Cumulative Gains: {dcg}");
    }
}
/// <summary>
/// One input row for the ranking experiment. Each member is bound to a
/// column of the loaded text file by its <c>LoadColumn</c> index.
/// </summary>
internal class SearchData
{
    /// <summary>Value read from file column 0; its name matches the group column the experiment is configured with.</summary>
    [LoadColumn(0)] public string GroupId;

    /// <summary>Value read from file column 1 (a single numeric feature).</summary>
    [LoadColumn(1)] public float Features;

    /// <summary>Value read from file column 2; its name matches the label column the experiment is configured with.</summary>
    [LoadColumn(2)] public float Label;
}
/// <summary>
/// Output row read back from the prediction engine.
/// </summary>
internal class SearchDataPrediction
{
    /// <summary>
    /// Bound to the model output column named "PredictedLabel".
    /// NOTE(review): ranking models are expected to emit only "Score" — confirm
    /// this column is actually present, otherwise this field stays at its default.
    /// </summary>
    [ColumnName("PredictedLabel")] public float Prediction;

    /// <summary>Bound by name to the model's "Score" output column.</summary>
    public float Score { get; set; }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment