Last active: December 11, 2017 04:18
-
-
Save eric-maynard/57e80329e88c6bd30c3fcaf89b49200a to your computer and use it in GitHub Desktop.
A simple example of a cross-validated Random Forest model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.ml._ | |
import org.apache.spark.ml.tuning._ | |
import org.apache.spark.ml.evaluation._ | |
import org.apache.spark.mllib.regression._ | |
import org.apache.spark.mllib.linalg.DenseVector | |
import org.apache.spark.sql._ | |
import org.apache.spark.ml.classification._ | |
import org.apache.spark.ml.feature._ | |
import sqlContext.implicits._ | |
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics | |
//case classes to use: | |
/** A labeled feature vector: `label` is the class (0.0/1.0) and `features`
  * holds the 34 raw feature values read from the parquet source.
  * Made `final` per convention; case classes should not be extended. */
final case class Point(label: Double, features: DenseVector)
/** Pairs the ground-truth label (`actual`) with the model's output
  * (`estimate`) for one test row, both in indexed-label space.
  * Made `final` per convention; case classes should not be extended. */
final case class Prediction(actual: Double, estimate: Double)
//read in our label data:
val parquetDir = "/data/parquet/joinWithLabelTimeT_v2"
val rawData = sqlContext.read.parquet(parquetDir)
// Column 34 holds the integer label; columns 0-33 are the double features.
// Array.tabulate replaces the original 34 hand-written row.getDouble(i)
// calls with one expression; `new` is unnecessary for a case class.
// NOTE(review): assumes columns 0-33 really are non-null doubles in the
// parquet schema — getDouble throws on any other type; confirm upstream.
val data = rawData.rdd.map { row =>
  Point(row.getInt(34).toDouble, new DenseVector(Array.tabulate(34)(row.getDouble)))
}.toDF
//prepare training and test data (80/20 random split):
val Array(training, test) = data.randomSplit(Array(0.8, 0.2))
// cache: training is traversed repeatedly by StringIndexer.fit and CV below
training.cache()
//index the labels:
// NOTE(review): StringIndexer assigns index 0 to the most frequent value,
// so "label_index" may flip 0/1 relative to the raw label — downstream
// metrics must be read in index space; confirm against label distribution.
val indexer = new StringIndexer().setInputCol("label").setOutputCol("label_index").fit(training)
// BUG FIX: the original Normalizer had no input/output columns configured —
// Normalizer's inputCol has no default, so the pipeline fails at fit time,
// and the classifier read the raw "features" column so the normalizer was
// dead weight. Wire features -> normFeatures -> classifier explicitly.
val normalizer = new Normalizer().setInputCol("features").setOutputCol("normFeatures")
//create a random forest classifier over the normalized features:
val model = new RandomForestClassifier().setLabelCol("label_index").setFeaturesCol("normFeatures")
// Stage order matters: index labels, normalize features, then train.
val pipeline = new Pipeline().setStages(Array(indexer, normalizer, model))
//build a parameter grid: 2 x 2 x 2 = 8 candidate models per fold
val grid = new ParamGridBuilder()
  .addGrid(model.numTrees, Array(5, 15))
  .addGrid(model.maxBins, Array(32, 64))
  .addGrid(model.maxDepth, Array(10, 15))
  .build()
//build a cross validator:
// FIX: evaluate against the indexed label column the classifier trains on;
// the evaluator's default labelCol is "label", which can disagree with
// "label_index" whenever StringIndexer reorders the classes.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator().setLabelCol("label_index"))
  .setEstimatorParamMaps(grid)
  .setNumFolds(3)
//train the model (fits all grid candidates across 3 folds, keeps the best):
val cvModel = cv.fit(training)
//calculate more metrics:
// FIX: removed a stray leading '+' (diff artifact) that made the precision
// line a syntax error, and transform the test set once instead of twice.
// NOTE(review): "prediction" is in StringIndexer index space while "label"
// is raw; if the indexer flipped the classes these metrics are inverted —
// consider selecting "label_index" instead. Confirm before trusting output.
val scored = cvModel.transform(test).select("prediction", "label").cache()
val bcPredictions = scored.rdd.map(r => (r.getDouble(0), r.getDouble(1)))
val metrics = new BinaryClassificationMetrics(bcPredictions)
val auROC = metrics.areaUnderROC
println("Area under ROC = " + auROC)
// actual = label (col 1), estimate = prediction (col 0)
val predictions = scored.rdd.map(r => Prediction(r.getDouble(1), r.getDouble(0)))
// NOTE(review): this "precision" is really overall accuracy
// (actual == estimate over ALL rows), kept as-is to preserve the output.
val precision = predictions.filter(p => p.actual == p.estimate).count.toDouble / predictions.count
// recall = TP / all actual positives; NaN if the test split has no positives
val recall = predictions.filter(p => p.estimate == 1d && p.actual == 1d).count.toDouble / predictions.filter(p => p.actual == 1d).count
println(s"Precision\t=\t${precision}\nRecall\t\t=\t${recall}\nF1\t\t=\t${(2 * (precision * recall) / (precision + recall))}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.