/// <summary>
/// Console entry point that trains two binary sentiment classifiers on the same
/// training file (one on plain text features, one that additionally uses word
/// embeddings) and prints evaluation metrics for both on a separate test file.
/// </summary>
internal static class Program
{
    /// <summary>Model trained by <see cref="TrainModel"/>.</summary>
    private static PredictionModel<SentimentData, SentimentPrediction> _model;

    /// <summary>Model trained by <see cref="TrainModelWordEmbeddings"/>.</summary>
    private static PredictionModel<SentimentData, SentimentPrediction> _modelWordEmbeddings;

    /// <summary>
    /// Directory the application runs from; all data paths are resolved against it.
    /// </summary>
    /// <remarks>
    /// Uses <see cref="AppContext.BaseDirectory"/> instead of the directory of the first
    /// command-line argument: that argument is host dependent and may be a bare file
    /// name or empty, which would make the dataset paths resolve against the current
    /// working directory or fail outright.
    /// </remarks>
    private static string AppPath => AppContext.BaseDirectory;

    /// <summary>Path of the training data file (datasets/sentiment-imdb-train.txt).</summary>
    private static string TrainDataPath => Path.Combine(AppPath, "datasets", "sentiment-imdb-train.txt");

    /// <summary>Path of the test data file (datasets/sentiment-yelp-test.txt).</summary>
    private static string TestDataPath => Path.Combine(AppPath, "datasets", "sentiment-yelp-test.txt");

    /// <summary>
    /// Path of the model archive (SentimentModel.zip). Not referenced by any other
    /// member of this class as shown here.
    /// </summary>
    private static string ModelPath => Path.Combine(AppPath, "SentimentModel.zip");

    /// <summary>
    /// Trains both models, evaluates each of them, then waits for a line of console
    /// input before exiting.
    /// </summary>
    /// <param name="args">Command-line arguments; not used.</param>
    private static void Main(string[] args)
    {
        TrainModel();
        TrainModelWordEmbeddings();

        Evaluate(_model, "normal");
        Evaluate(_modelWordEmbeddings, "using WordEmbeddings");

        // Keep the console window open until the user presses Enter.
        Console.ReadLine();
    }

    /// <summary>
    /// Trains the baseline model: text featurization of the "SentimentText" column into
    /// "Features", followed by the shared FastTree classifier. The result is stored in
    /// <see cref="_model"/>.
    /// </summary>
    public static void TrainModel()
    {
        var pipeline = new LearningPipeline();
        pipeline.Add(new TextLoader(TrainDataPath).CreateFrom<SentimentData>());
        pipeline.Add(new TextFeaturizer("Features", "SentimentText"));
        pipeline.Add(CreateTrainer());

        Console.WriteLine("=============== Training model ===============");
        _model = pipeline.Train<SentimentData, SentimentPrediction>();
        Console.WriteLine("=============== End training ===============");
    }

    /// <summary>
    /// Trains the word-embedding model: text featurization into "FeaturesA" (with token
    /// output enabled), word embeddings into "FeaturesB", both concatenated into
    /// "Features", followed by the shared FastTree classifier. The result is stored in
    /// <see cref="_modelWordEmbeddings"/>.
    /// </summary>
    public static void TrainModelWordEmbeddings()
    {
        var pipeline = new LearningPipeline();
        pipeline.Add(new TextLoader(TrainDataPath).CreateFrom<SentimentData>());
        pipeline.Add(new TextFeaturizer("FeaturesA", "SentimentText") { OutputTokens = true });

        // NOTE(review): "FeaturesA_TransformedText" is presumably the token column the
        // featurizer emits because OutputTokens is set - confirm against the ML.NET docs.
        pipeline.Add(new WordEmbeddings(("FeaturesA_TransformedText", "FeaturesB")));
        pipeline.Add(new ColumnConcatenator("Features", "FeaturesA", "FeaturesB"));
        pipeline.Add(CreateTrainer());

        Console.WriteLine("=============== Training model with Word Embeddings ===============");
        _modelWordEmbeddings = pipeline.Train<SentimentData, SentimentPrediction>();
        Console.WriteLine("=============== End training ===============");
    }

    /// <summary>
    /// Evaluates <paramref name="model"/> against the test data file and writes its
    /// accuracy, AUC and F1 score to the console, each formatted as a percentage with
    /// two decimals.
    /// </summary>
    /// <param name="model">The trained model to evaluate.</param>
    /// <param name="name">Label printed in the evaluation banner to identify the model.</param>
    private static void Evaluate(PredictionModel<SentimentData, SentimentPrediction> model, string name)
    {
        var testData = new TextLoader(TestDataPath).CreateFrom<SentimentData>();
        var evaluator = new BinaryClassificationEvaluator();

        Console.WriteLine($"=============== Evaluating model {name} ===============");
        var metrics = evaluator.Evaluate(model, testData);

        Console.WriteLine($"Accuracy: {metrics.Accuracy:P2}");
        Console.WriteLine($"Auc: {metrics.Auc:P2}");
        Console.WriteLine($"F1Score: {metrics.F1Score:P2}");
        Console.WriteLine("=============== End evaluating ===============");
        Console.WriteLine();
    }

    /// <summary>
    /// Creates the classifier used by both training pipelines. Keeping the settings in
    /// one place guarantees the two models are trained with identical hyper-parameters,
    /// so their metrics differ only because of the features.
    /// </summary>
    private static FastTreeBinaryClassifier CreateTrainer()
    {
        return new FastTreeBinaryClassifier { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 };
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment