Created
July 28, 2018 18:00
-
-
Save AbhiOnGithub/e0ba411ad7c74275cea37b4e0f3d4440 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.IO; | |
using System.Threading.Tasks; | |
using Microsoft.ML; | |
using Microsoft.ML.Data; | |
using Microsoft.ML.Models; | |
using Microsoft.ML.Trainers; | |
using Microsoft.ML.Transforms; | |
namespace FarePredictor
{
    /// <summary>
    /// Console entry point that trains a taxi-fare regression model with the ML.NET
    /// <c>LearningPipeline</c> API, saves it to disk, evaluates it against a held-out
    /// test CSV, and prints a sample prediction.
    /// </summary>
    class Program
    {
        // All paths are resolved against the process's current working directory
        // (Environment.CurrentDirectory), not the executable's folder, so the program
        // must be started from the directory that holds the CSV files.
        static readonly string _datapath = Path.Combine(Environment.CurrentDirectory, "taxi-fare-train.csv");
        static readonly string _testdatapath = Path.Combine(Environment.CurrentDirectory, "taxi-fare-test.csv");
        static readonly string _modelpath = Path.Combine(Environment.CurrentDirectory, "Model.zip");

        /// <summary>
        /// Trains and saves the model, prints its evaluation metrics, then predicts the
        /// fare of <c>TestTrips.Trip1</c>.
        /// </summary>
        /// <param name="args">Command-line arguments (unused).</param>
        /// <exception cref="FileNotFoundException">
        /// The training or test CSV file is missing from the current directory.
        /// </exception>
        static async Task Main(string[] args)
        {
            Console.WriteLine("Building Fare Predictor!");

            // Fail fast: confirm the evaluation data exists before spending time on training.
            EnsureFileExists(_testdatapath);

            PredictionModel<TaxiTrip, TaxiTripFarePrediction> model = await TrainAsync();
            Evaluate(model);

            TaxiTripFarePrediction prediction = model.Predict(TestTrips.Trip1);

            // NOTE(review): 29.5 is presumably the recorded fare for TestTrips.Trip1
            // (defined elsewhere); keep this literal in sync with that sample.
            Console.WriteLine($"Predicted fare: {prediction.FareAmount}, actual fare: 29.5");
        }

        /// <summary>
        /// Builds and trains the fare-prediction pipeline on the training CSV, then writes
        /// the trained model to <c>Model.zip</c> in the current directory.
        /// </summary>
        /// <returns>The trained model.</returns>
        /// <exception cref="FileNotFoundException">The training CSV file does not exist.</exception>
        public static async Task<PredictionModel<TaxiTrip, TaxiTripFarePrediction>> TrainAsync()
        {
            EnsureFileExists(_datapath);

            var pipeline = new LearningPipeline
            {
                // Comma-separated file with a header row, mapped onto TaxiTrip's fields.
                new TextLoader(_datapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ','),

                // Copy the target (FareAmount) into "Label", the column the trainer uses as its target by default.
                new ColumnCopier(("FareAmount", "Label")),

                // One-hot encode the categorical columns; the outputs keep the same column names.
                new CategoricalOneHotVectorizer(
                    "VendorId",
                    "RateCode",
                    "PaymentType"),

                // Combine the inputs into the single "Features" vector the trainer consumes.
                new ColumnConcatenator(
                    "Features",
                    "VendorId",
                    "RateCode",
                    "PassengerCount",
                    "TripDistance",
                    "PaymentType"),

                // FastTree (boosted decision tree) regression trainer.
                new FastTreeRegressor()
            };

            // Training itself is synchronous; only persisting the model is awaited.
            PredictionModel<TaxiTrip, TaxiTripFarePrediction> model = pipeline.Train<TaxiTrip, TaxiTripFarePrediction>();
            await model.WriteAsync(_modelpath);
            return model;
        }

        /// <summary>
        /// Scores <paramref name="model"/> against the test CSV and prints the
        /// root-mean-squared error and R-squared reported by the regression evaluator.
        /// </summary>
        /// <param name="model">The trained model to evaluate.</param>
        /// <exception cref="FileNotFoundException">The test CSV file does not exist.</exception>
        private static void Evaluate(PredictionModel<TaxiTrip, TaxiTripFarePrediction> model)
        {
            EnsureFileExists(_testdatapath);

            // Load the test data with the same schema and CSV options used for training.
            var testData = new TextLoader(_testdatapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ',');
            var evaluator = new RegressionEvaluator();
            RegressionMetrics metrics = evaluator.Evaluate(model, testData);

            Console.WriteLine($"Rms = {metrics.Rms}");
            Console.WriteLine($"RSquared = {metrics.RSquared}");
        }

        /// <summary>
        /// Throws a descriptive <see cref="FileNotFoundException"/> when
        /// <paramref name="path"/> does not name an existing file. A missing CSV is then
        /// reported up front rather than failing later inside the ML.NET pipeline.
        /// </summary>
        /// <param name="path">Full path of the required input file.</param>
        /// <exception cref="FileNotFoundException">The file does not exist.</exception>
        private static void EnsureFileExists(string path)
        {
            if (!File.Exists(path))
            {
                throw new FileNotFoundException($"Required data file was not found: '{path}'.", path);
            }
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment