Skip to content

Instantly share code, notes, and snippets.

@AbhiOnGithub
Created July 28, 2018 19:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AbhiOnGithub/5dae94f4bedfeca418ec3769eecfd5a4 to your computer and use it in GitHub Desktop.
Save AbhiOnGithub/5dae94f4bedfeca418ec3769eecfd5a4 to your computer and use it in GitHub Desktop.
using System;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Models;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
namespace SentimentAnalysis
{
    /// <summary>
    /// Console sample that trains a binary sentiment classifier with the ML.NET
    /// <c>LearningPipeline</c> API, evaluates it against a separate test file and prints
    /// predictions for the sample sentences in <c>TestSentimentData.Sentiments</c>.
    /// </summary>
    internal static class Program
    {
        // Directory that holds the application binaries; the data files and the saved model are
        // resolved relative to it. AppContext.BaseDirectory is used instead of
        // Path.GetDirectoryName(Environment.GetCommandLineArgs()[0]) because the latter depends on how
        // the process was launched: a bare file name (e.g. started through PATH) yields an empty
        // string, which silently redirects every path to the current working directory.
        private static string AppPath => AppContext.BaseDirectory;

        // Training set, expected next to the executable.
        private static string TrainDataPath => Path.Combine(AppPath, "sentiment-imdb-train.txt");

        // Test set used by Evaluate, expected next to the executable.
        private static string TestDataPath => Path.Combine(AppPath, "sentiment-yelp-test.txt");

        // Location the trained model is written to by TrainAsync.
        private static string ModelPath => Path.Combine(AppPath, "SentimentModel.zip");

        /// <summary>
        /// Entry point: trains the model, reports evaluation metrics, prints one prediction per
        /// sample sentence and then waits for a line of console input before exiting.
        /// </summary>
        /// <param name="args">Command-line arguments (unused).</param>
        private static async Task Main(string[] args)
        {
            // STEP 1: Create a model
            var model = await TrainAsync();

            // STEP 2: Test accuracy
            Evaluate(model);

            // STEP 3: Make a prediction
            var predictions = model.Predict(TestSentimentData.Sentiments);

            // Pair every input sentence with the prediction produced for it (same order).
            var sentimentsAndPredictions =
                TestSentimentData.Sentiments.Zip(predictions, (sentiment, prediction) => (sentiment, prediction));

            foreach (var (sentiment, prediction) in sentimentsAndPredictions)
            {
                Console.WriteLine(
                    $"Sentiment: {sentiment.SentimentText} | Prediction: {(prediction.Sentiment ? "Positive" : "Negative")} sentiment");
            }

            // Keep the console window open until the user presses Enter.
            Console.ReadLine();
        }

        /// <summary>
        /// Builds the learning pipeline, trains it on <see cref="TrainDataPath"/> and saves the
        /// resulting model to <see cref="ModelPath"/>.
        /// </summary>
        /// <returns>The trained prediction model.</returns>
        public static async Task<PredictionModel<SentimentData, SentimentPrediction>> TrainAsync()
        {
            // LearningPipeline holds all steps of the learning process: data, transforms, learners.
            var pipeline = new LearningPipeline();

            // The TextLoader loads a dataset. The schema of the dataset is specified by passing a class
            // containing all the column names and their types.
            pipeline.Add(new TextLoader(TrainDataPath).CreateFrom<SentimentData>());

            // TextFeaturizer is a transform that featurizes the "SentimentText" input column
            // into the "Features" column.
            pipeline.Add(new TextFeaturizer("Features", "SentimentText"));

            // FastTreeBinaryClassifier is the algorithm used to train the model.
            // Three hyperparameters are set here to tune decision tree performance.
            pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 });

            Console.WriteLine("=============== Training model ===============");

            // The pipeline is trained on the dataset that has been loaded and transformed.
            var model = pipeline.Train<SentimentData, SentimentPrediction>();

            // Save the model as a .zip file.
            await model.WriteAsync(ModelPath);

            Console.WriteLine("=============== End training ===============");
            Console.WriteLine($"The model is saved to {ModelPath}");

            return model;
        }

        /// <summary>
        /// Runs the model against the data in <see cref="TestDataPath"/> and prints the accuracy,
        /// AUC and F1 score to the console.
        /// </summary>
        /// <param name="model">The trained model to evaluate.</param>
        private static void Evaluate(PredictionModel<SentimentData, SentimentPrediction> model)
        {
            // To evaluate how well the model predicts values, the model is run against a new set
            // of data (test data) that was not involved in training.
            var testData = new TextLoader(TestDataPath).CreateFrom<SentimentData>();

            // BinaryClassificationEvaluator performs evaluation for Binary Classification type of ML problems.
            var evaluator = new BinaryClassificationEvaluator();

            Console.WriteLine("=============== Evaluating model ===============");

            var metrics = evaluator.Evaluate(model, testData);

            // BinaryClassificationMetrics contains the overall metrics computed by binary classification evaluators.
            //
            // The Accuracy metric gets the accuracy of a classifier, which is the proportion
            // of correct predictions in the test set.
            //
            // The Auc metric gets the area under the ROC curve.
            // The area under the ROC curve is equal to the probability that the classifier ranks
            // a randomly chosen positive instance higher than a randomly chosen negative one
            // (assuming 'positive' ranks higher than 'negative').
            //
            // The F1Score metric gets the classifier's F1 score.
            // The F1 score is the harmonic mean of precision and recall:
            // 2 * precision * recall / (precision + recall).
            Console.WriteLine($"Accuracy: {metrics.Accuracy:P2}");
            Console.WriteLine($"Auc: {metrics.Auc:P2}");
            Console.WriteLine($"F1Score: {metrics.F1Score:P2}");
            Console.WriteLine("=============== End evaluating ===============");
            Console.WriteLine();
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment