Skip to content

Instantly share code, notes, and snippets.

@AbhiOnGithub
Created July 28, 2018 18:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AbhiOnGithub/e0ba411ad7c74275cea37b4e0f3d4440 to your computer and use it in GitHub Desktop.
Save AbhiOnGithub/e0ba411ad7c74275cea37b4e0f3d4440 to your computer and use it in GitHub Desktop.
using System;
using System.IO;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Models;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
namespace FarePredictor
{
    /// <summary>
    /// Console entry point for a taxi-fare regression sample built on the ML.NET
    /// <c>LearningPipeline</c> API. It trains a model, saves it, evaluates it
    /// against a separate test file, and prints one sample prediction.
    /// </summary>
    internal class Program
    {
        // Every file location is resolved against the process working directory
        // (Environment.CurrentDirectory), not the assembly's own folder.
        private static readonly string _datapath = Path.Combine(Environment.CurrentDirectory, "taxi-fare-train.csv");
        private static readonly string _testdatapath = Path.Combine(Environment.CurrentDirectory, "taxi-fare-test.csv");
        private static readonly string _modelpath = Path.Combine(Environment.CurrentDirectory, "Model.zip");

        /// <summary>
        /// Trains and saves the model, reports its test-set metrics, and then
        /// prints the model's prediction for <c>TestTrips.Trip1</c>.
        /// </summary>
        /// <param name="args">Command-line arguments (not used).</param>
        private static async Task Main(string[] args)
        {
            Console.WriteLine("Building Fare Predictor!");

            var trainedModel = await TrainAsync();
            Evaluate(trainedModel);

            var sample = trainedModel.Predict(TestTrips.Trip1);

            // NOTE(review): 29.5 is hard-coded. It is presumably the known fare of
            // TestTrips.Trip1 (defined elsewhere), so keep the two in sync.
            Console.WriteLine("Predicted fare: {0}, actual fare: 29.5", sample.FareAmount);
        }

        /// <summary>
        /// Builds the training pipeline and trains a FastTree regression model on
        /// the training CSV. The trained model is written to <c>Model.zip</c>
        /// before it is returned.
        /// </summary>
        /// <returns>The trained prediction model.</returns>
        public static async Task<PredictionModel<TaxiTrip, TaxiTripFarePrediction>> TrainAsync()
        {
            var pipeline = new LearningPipeline();

            // Read TaxiTrip rows from the comma-separated training file; the first row is a header.
            pipeline.Add(new TextLoader(_datapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ','));

            // Copy the regression target into "Label", presumably the column the trainer reads by default.
            pipeline.Add(new ColumnCopier(("FareAmount", "Label")));

            // One-hot encode the categorical columns.
            pipeline.Add(new CategoricalOneHotVectorizer(
                "VendorId",
                "RateCode",
                "PaymentType"));

            // Gather the model inputs into a single "Features" column.
            pipeline.Add(new ColumnConcatenator(
                "Features",
                "VendorId",
                "RateCode",
                "PassengerCount",
                "TripDistance",
                "PaymentType"));

            pipeline.Add(new FastTreeRegressor());

            // Training runs synchronously; only saving the model to disk is awaited.
            var trained = pipeline.Train<TaxiTrip, TaxiTripFarePrediction>();
            await trained.WriteAsync(_modelpath);
            return trained;
        }

        /// <summary>
        /// Scores <paramref name="model"/> against the test CSV and prints the
        /// resulting Rms and RSquared metrics to the console.
        /// </summary>
        /// <param name="model">The trained model to evaluate.</param>
        private static void Evaluate(PredictionModel<TaxiTrip, TaxiTripFarePrediction> model)
        {
            // The test file uses the same layout as the training file: a header row and comma separators.
            var testSet = new TextLoader(_testdatapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ',');
            var metrics = new RegressionEvaluator().Evaluate(model, testSet);

            Console.WriteLine($"Rms = {metrics.Rms}");
            Console.WriteLine($"RSquared = {metrics.RSquared}");
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment