Created April 30, 2019 09:42
Save AndyButland/bac042ccbc040227e68908d715663c90 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// <summary>
/// Computes permutation feature importance (PFI) for a trained fast-forest regression model
/// and prints one row per feature, ordered by the magnitude of its mean R-squared change.
/// </summary>
/// <param name="context">The ML.NET context used to run the PFI calculation.</param>
/// <param name="model">
/// The trained model. It is cast to
/// <c>TransformerChain&lt;RegressionPredictionTransformer&lt;FastForestRegressionModelParameters&gt;&gt;</c>,
/// so any other concrete type causes an <see cref="InvalidCastException"/>.
/// </param>
/// <param name="data">The untransformed data to evaluate importance against.</param>
/// <exception cref="ArgumentNullException">Thrown when any argument is null.</exception>
private static void ReportOnFeatureImportance(MLContext context, ITransformer model, IDataView data)
{
    if (context == null)
    {
        throw new ArgumentNullException(nameof(context));
    }

    if (model == null)
    {
        throw new ArgumentNullException(nameof(model));
    }

    if (data == null)
    {
        throw new ArgumentNullException(nameof(data));
    }

    // Need to cast from the ITransformer interface to gain access to the LastTransformer
    // property, which is the prediction transformer the PFI API requires.
    var typedModel = (TransformerChain<RegressionPredictionTransformer<FastForestRegressionModelParameters>>)model;

    // Calculate metrics. Per the ML.NET docs, PFI yields one statistics entry per slot of
    // the model's feature vector.
    var permutationMetrics = context.Regression.PermutationFeatureImportance(
        typedModel.LastTransformer, model.Transform(data), PredictionLabel);

    // Derive display names from the input schema. Skip the label, the identifier columns and
    // hidden columns (superseded by a later column of the same name), none of which are features.
    var columnsToExclude = new HashSet<string>(
        new[] { PredictionLabel, "Code", "Name", "IdPreservationColumn" },
        StringComparer.Ordinal);
    var featureNames = data.Schema.AsEnumerable()
        .Where(column => !column.IsHidden)
        .Select(column => column.Name)
        .Where(name => !columnsToExclude.Contains(name))
        .ToArray();

    // NOTE(review): names are paired with metrics purely by position. That assumes the remaining
    // schema columns appear in the same order, and in the same number, as the feature-vector slots
    // the model was trained on. Confirm against the training pipeline. Report a mismatch instead
    // of crashing with IndexOutOfRangeException or silently dropping metrics.
    if (featureNames.Length != permutationMetrics.Length)
    {
        Console.WriteLine(
            $"Warning: {featureNames.Length} feature names but {permutationMetrics.Length} importance results; " +
            "names may not line up with features.");
    }

    // Zip pairs by index and stops at the shorter sequence, so a mismatch cannot index out of range.
    var results = featureNames
        .Zip(permutationMetrics, (featureName, metrics) => new FeatureImportance
        {
            Name = featureName,
            // NOTE(review): Math.Abs discards the sign of the mean R-squared change, so a feature
            // whose shuffling improved the model ranks alongside one whose shuffling hurt it.
            RSquaredMean = Math.Abs(metrics.RSquared.Mean),
            CorrelationCoefficient = 0 // TBC
        })
        .OrderByDescending(x => x.RSquaredMean)
        .ToList();

    OutputFeatureImportanceResults(results);
}
/// <summary>
/// Writes a "Feature importance:" heading, then a console table with one row per result
/// in enumeration order, followed by a blank line.
/// </summary>
/// <param name="results">The feature-importance rows to display, already in display order.</param>
private static void OutputFeatureImportanceResults(IEnumerable<FeatureImportance> results)
{
    Console.WriteLine("Feature importance:");

    var table = new ConsoleTable("Feature", "R Squared Mean", "Correlation Coefficient");
    foreach (var importance in results)
    {
        // "G4" = general format, up to 4 significant digits; "N2" = number format, 2 decimals.
        var rSquaredText = importance.RSquaredMean.ToString("G4");
        var correlationText = importance.CorrelationCoefficient.ToString("N2");
        table.AddRow(importance.Name, rSquaredText, correlationText);
    }

    table.Write();
    Console.WriteLine();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.