Skip to content

Instantly share code, notes, and snippets.

@szilard
Last active August 29, 2015 14:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save szilard/b9e1e7b5dec502b9f8c6 to your computer and use it in GitHub Desktop.
Save szilard/b9e1e7b5dec502b9f8c6 to your computer and use it in GitHub Desktop.
Training random forest in Spark / MLlib
spark-1.3.0-bin-hadoop2.4/bin/spark-shell --driver-memory 100G --executor-memory 100G
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
// Dataset size tag embedded in the CSV file names (e.g. "train-1hot-1m.csv").
val size = "1"

/**
 * Reads a comma-separated file of numeric values into an RDD of labeled points.
 *
 * Every field of every line is parsed as a Double; column 0 becomes the label and the
 * remaining columns become a dense feature vector. Train and test data share this
 * single loader so both are parsed identically.
 *
 * @param path location of the CSV file, handed to `sc.textFile`
 * @return one LabeledPoint per input line
 */
def loadLabeledPoints(path: String): org.apache.spark.rdd.RDD[LabeledPoint] =
  sc.textFile(path).map { line =>
    val values = line.split(',').map(_.toDouble)
    LabeledPoint(values(0), Vectors.dense(values.drop(1)))
  }

val d_train = loadLabeledPoints(s"train-1hot-${size}m.csv")
val d_test = loadLabeledPoints(s"test-1hot-${size}m.csv")
// Mark both datasets for in-memory persistence. cache() is lazy: nothing is read
// until an action runs on the RDD.
d_train.cache()
d_test.cache()

// count() is an action, so it forces the test set to be read and parsed.
d_test.count()

// Time how long reading and parsing the training set takes (seconds).
// A distinct name is used for the start timestamp so the script can also be run
// through `:paste` or compiled as one unit without a duplicate-definition error.
val cacheStart = System.nanoTime
d_train.count()
(System.nanoTime - cacheStart) / 1e9
// Random forest hyper-parameters.
val numClasses = 2
// Empty map: no feature is declared categorical, so all are treated as continuous.
val categoricalFeaturesInfo = Map.empty[Int, Int]
val numTrees = 10
val featureSubsetStrategy = "sqrt"
val impurity = "gini"
val maxDepth = 30 // default 5 too little, but higher values slow (and max is 30)
val maxBins = 100

// Train the forest and report the elapsed wall-clock time in seconds.
// A distinct name is used for the start timestamp so the script can also be run
// through `:paste` or compiled as one unit without a duplicate-definition error.
val trainStart = System.nanoTime
val model = RandomForest.trainClassifier(
  d_train,
  numClasses,
  categoricalFeaturesInfo,
  numTrees,
  featureSubsetStrategy,
  impurity,
  maxDepth,
  maxBins
)
(System.nanoTime - trainStart) / 1e9
// Pair every test point with a score and its true label.
// model.predict(point.features) is not used because it returns a hard 0/1 value;
// the score is instead the fraction of trees whose own prediction is positive.
val scoreAndLabels = d_test.map { point =>
  val positiveVotes = model.trees.count(tree => tree.predict(point.features) > 0)
  (positiveVotes.toDouble / model.numTrees, point.label)
}

// Area under the ROC curve computed from the (score, label) pairs.
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
metrics.areaUnderROC()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment