@eric-maynard
Last active December 11, 2017 04:18
A simple example of a cross-validated Random Forest model
//NOTE: written for the Spark 1.x shell (sqlContext, mllib vectors); on Spark 2.x use spark.read
//and org.apache.spark.ml.linalg.DenseVector instead.
import org.apache.spark.ml._
import org.apache.spark.ml.classification._
import org.apache.spark.ml.evaluation._
import org.apache.spark.ml.feature._
import org.apache.spark.ml.tuning._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.sql._
import sqlContext.implicits._
//case classes to use:
case class Point(label: Double, features: DenseVector)
case class Prediction(actual: Double, estimate: Double)
//read in our labeled data:
val parquetDir = "/data/parquet/joinWithLabelTimeT_v2"
val rawData = sqlContext.read.parquet(parquetDir)
val data = rawData.rdd.map(row => Point(row.getInt(34).toDouble, new DenseVector((0 until 34).map(row.getDouble).toArray))).toDF
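//optional sanity check: confirm the schema and the class balance before training (plain
//DataFrame calls; nothing here is specific to this dataset):
data.printSchema()
data.groupBy("label").count().show()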
//prepare training and test data:
val Array(training, test) = data.randomSplit(Array(0.8, 0.2))
training.cache()
//index the labels:
val indexer = new StringIndexer().setInputCol("label").setOutputCol("label_index").fit(training)
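//optional: print the fitted label mapping -- entry i of this array is the original label that
//was assigned index i, which is the value the classifier will actually predict:
println(indexer.labels.mkString(", "))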
//normalize the feature vectors (the Normalizer needs explicit input and output columns):
val normalizer = new Normalizer().setInputCol("features").setOutputCol("norm_features")
//create a random forest classifier that trains on the normalized features and indexed labels:
val model = new RandomForestClassifier().setLabelCol("label_index").setFeaturesCol("norm_features")
val pipeline = new Pipeline().setStages(Array(indexer, normalizer, model))
//build a parameter grid:
val grid = new ParamGridBuilder().addGrid(model.numTrees, Array(5, 15)).addGrid(model.maxBins, Array(32, 64)).addGrid(model.maxDepth, Array(10, 15)).build()
//build a cross validator (the evaluator must score against the same indexed label the forest predicts):
val cv = new CrossValidator().setEstimator(pipeline).setEvaluator(new BinaryClassificationEvaluator().setLabelCol("label_index")).setEstimatorParamMaps(grid).setNumFolds(3)
//train the model:
val cvModel = cv.fit(training)
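//optionally, inspect how each parameter combination fared and which one won; a sketch using
//the standard CrossValidatorModel API (avgMetrics holds the mean evaluator score per ParamMap),
//and the bestForest name is just illustrative:
cvModel.getEstimatorParamMaps.zip(cvModel.avgMetrics).foreach { case (params, score) =>
  println(s"mean AUC = ${score} for ${params}")
}
val bestForest = cvModel.bestModel.asInstanceOf[PipelineModel].stages.last.asInstanceOf[RandomForestClassificationModel]
println(bestForest.extractParamMap())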
//calculate more metrics (the model predicts in the StringIndexer's label space, so compare against label_index):
val scored = cvModel.transform(test).select("prediction", "label_index").cache()
val bcPredictions = scored.rdd.map(r => (r.getDouble(0), r.getDouble(1)))
val metrics = new BinaryClassificationMetrics(bcPredictions)
val auROC = metrics.areaUnderROC
println("Area under ROC = " + auROC)
val predictions = scored.rdd.map(r => Prediction(r.getDouble(1), r.getDouble(0)))
//accuracy, precision, and recall (this assumes the positive class was indexed to 1.0):
val accuracy = predictions.filter(p => p.actual == p.estimate).count.toDouble / predictions.count
val precision = predictions.filter(p => p.estimate == 1d && p.actual == 1d).count.toDouble / predictions.filter(p => p.estimate == 1d).count
val recall = predictions.filter(p => p.estimate == 1d && p.actual == 1d).count.toDouble / predictions.filter(p => p.actual == 1d).count
println(s"Accuracy\t=\t${accuracy}\nPrecision\t=\t${precision}\nRecall\t\t=\t${recall}\nF1\t\t=\t${(2 * (precision * recall) / (precision + recall))}")