Created
April 30, 2019 09:37
-
-
Save AndyButland/3b7f99e363b26a9ceb8e32a1bedba28e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Program entry point: creates an ML context, loads the data via
// GetTrainAndTestData, and trains a model on the returned TrainSet.
// NOTE(review): the rest of the body is elided ("...") in this excerpt, so
// whatever happens after training is not visible here.
public static void Main(string[] args) | |
{ | |
// The MLContext instance is passed to every helper called below.
var context = new MLContext(); | |
// The result exposes a TrainSet member (used on the next line); its other
// members, if any, are not shown in this excerpt.
var data = GetTrainAndTestData(context, InputDataPath); | |
var model = TrainModel(context, data.TrainSet); | |
... | |
} | |
/// <summary>
/// Builds the regression training pipeline, fits it to <paramref name="data"/>,
/// saves the fitted model via <c>SaveModelAsFile</c>, and returns it.
/// </summary>
/// <param name="context">ML context used to create the transforms and the trainer.</param>
/// <param name="data">Data view the pipeline is fitted on.</param>
/// <returns>The fitted model.</returns>
private static ITransformer TrainModel(MLContext context, IDataView data)
{
    // Stage 1: copy the column we want to predict into the expected "Label" column.
    var labelStage = context.Transforms.CopyColumns(
        outputColumnName: "Label",
        inputColumnName: PredictionLabel);

    // Stage 2: replace missing "Population" values using the Mean replacement mode.
    var populationOptions = new MissingValueReplacingEstimator.ColumnOptions(
        "Population",
        replacementMode: MissingValueReplacingEstimator.ColumnOptions.ReplacementMode.Mean);
    var missingValueStage = context.Transforms.ReplaceMissingValues(populationOptions);

    // Stage 3: combine all feature columns into the expected "Features" column.
    var featureStage = context.Transforms.Concatenate("Features", FeatureColumns);

    // Stage 4: the FastForest regression trainer.
    var regressionStage = context.Regression.Trainers.FastForest();

    // Chain the stages in the order above and fit the whole pipeline.
    var pipeline = labelStage
        .Append(missingValueStage)
        .Append(featureStage)
        .Append(regressionStage);

    var trainedModel = pipeline.Fit(data);

    // Persist the fitted model before handing it back to the caller.
    SaveModelAsFile(context, trainedModel);
    return trainedModel;
}
/// <summary>
/// Writes <paramref name="model"/> to the file at <c>ModelPath</c> and prints
/// the destination to the console, followed by a blank line.
/// </summary>
/// <param name="context">ML context whose <c>Model.Save</c> performs the serialization.</param>
/// <param name="model">The model to write out.</param>
private static void SaveModelAsFile(MLContext context, ITransformer model)
{
    // FileMode.Create creates the file, or overwrites it if it already exists.
    using (FileStream output = File.Open(ModelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
    {
        context.Model.Save(model, output);
    }

    // Report the destination only after the stream has been closed.
    Console.WriteLine("The model is saved to {0}", ModelPath);
    Console.WriteLine();
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment