Skip to content

Instantly share code, notes, and snippets.

@AndyButland
Created April 30, 2019 09:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AndyButland/bac042ccbc040227e68908d715663c90 to your computer and use it in GitHub Desktop.
Save AndyButland/bac042ccbc040227e68908d715663c90 to your computer and use it in GitHub Desktop.
/// <summary>
/// Runs permutation feature importance (PFI) for a trained fast-forest regression model and prints
/// the magnitude of each feature's R² change as a console table, largest first.
/// </summary>
/// <param name="context">ML.NET context whose regression catalog runs the PFI evaluation.</param>
/// <param name="model">
/// Trained model. It must be a <c>TransformerChain</c> whose last transformer is a
/// <c>RegressionPredictionTransformer&lt;FastForestRegressionModelParameters&gt;</c>.
/// </param>
/// <param name="data">
/// Input data. It is transformed by <paramref name="model"/> for PFI, and its schema supplies the
/// feature names.
/// </param>
/// <exception cref="InvalidCastException"><paramref name="model"/> is not of the expected chain type.</exception>
/// <exception cref="InvalidOperationException">
/// The number of candidate feature columns in <paramref name="data"/> differs from the number of PFI
/// results, so names cannot be paired with metrics by position.
/// </exception>
private static void ReportOnFeatureImportance(MLContext context, ITransformer model, IDataView data)
{
    // The prediction transformer PFI needs is only reachable through LastTransformer on the concrete
    // chain type, so cast away from the ITransformer interface.
    var typedModel = (TransformerChain<RegressionPredictionTransformer<FastForestRegressionModelParameters>>)model;

    // One metrics entry is returned per slot of the model's feature vector.
    // PredictionLabel (defined elsewhere in this class) is passed as the label column name.
    var permutationMetrics = context.Regression.PermutationFeatureImportance(
        typedModel.LastTransformer, model.Transform(data), PredictionLabel);

    // Candidate feature names = every input column except the label and known non-feature columns.
    // NOTE(review): this assumes the pipeline concatenates exactly these columns, in schema order and
    // one slot each, into the feature vector. Confirm against the pipeline definition.
    var columnsToExclude = new[] { PredictionLabel, "Code", "Name", "IdPreservationColumn" };
    var featureNames = data.Schema.AsEnumerable()
        .Select(column => column.Name)
        .Where(name => !columnsToExclude.Contains(name))
        .ToArray();

    // Names and metrics are paired by index. On a count mismatch we would either run off the end of
    // the metrics or silently misattribute/drop results, so fail with a clear message instead.
    if (featureNames.Length != permutationMetrics.Length)
    {
        throw new InvalidOperationException(
            $"Found {featureNames.Length} candidate feature columns but {permutationMetrics.Length} " +
            "permutation feature importance results; cannot pair feature names with metrics.");
    }

    var results = featureNames
        .Select((name, index) => new FeatureImportance
        {
            Name = name,
            // PFI reports a signed change in R²; take the magnitude so the ranking ignores direction.
            RSquaredMean = Math.Abs(permutationMetrics[index].RSquared.Mean),
            CorrelationCoefficient = 0 // TBC
        })
        .OrderByDescending(x => x.RSquaredMean);

    OutputFeatureImportanceResults(results);
}
/// <summary>
/// Writes a "Feature importance:" heading and then a three-column console table
/// (feature, R² mean, correlation coefficient), followed by a blank line.
/// </summary>
/// <param name="results">Rows to print, emitted in the order they are enumerated.</param>
private static void OutputFeatureImportanceResults(IEnumerable<FeatureImportance> results)
{
    Console.WriteLine("Feature importance:");

    var table = new ConsoleTable("Feature", "R Squared Mean", "Correlation Coefficient");
    foreach (var feature in results)
    {
        // G4 = general format with up to four significant digits.
        // N2 = number format with two decimal places.
        // Both use the current culture, which suits console output meant for a human reader.
        var featureName = feature.Name;
        var rSquaredText = feature.RSquaredMean.ToString("G4");
        var correlationText = feature.CorrelationCoefficient.ToString("N2");
        table.AddRow(featureName, rSquaredText, correlationText);
    }

    table.Write();
    Console.WriteLine();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment