Skip to content

Instantly share code, notes, and snippets.

@fdeantoni
Created May 13, 2015 05:06
Show Gist options
  • Save fdeantoni/92d91b636544e0fc95af to your computer and use it in GitHub Desktop.
Save fdeantoni/92d91b636544e0fc95af to your computer and use it in GitHub Desktop.
Spark Gradient Boosted Tree
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.util.MLUtils
/**
 * Minimal Spark MLlib example: trains a gradient-boosted trees classifier on a
 * LibSVM-formatted sample file and prints its test error and learned model.
 */
object GBT {

  /**
   * Entry point. Builds a local Spark context, runs [[sample]] and always
   * stops the context afterwards.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false) // skip loading external settings
      .setMaster("local[4]") // run locally with four worker threads
      .setAppName("firstSparkApp")
      .set("spark.logConf", "true")
      .set("spark.driver.host", "localhost")
    val sc = new SparkContext(conf)
    // Stop the context even when training/evaluation throws, so the local
    // Spark runtime is shut down cleanly.
    try sample(sc)
    finally sc.stop()
  }

  /**
   * Loads the sample data set, trains a gradient-boosted trees classification
   * model on a random 70% split and reports the misclassification rate on the
   * remaining 30%.
   *
   * @param sc active Spark context used to load the data; not stopped here
   */
  def sample(sc: SparkContext): Unit = {
    // Load and parse the data file.
    val data = MLUtils.loadLibSVMFile(sc, "src/main/resources/sample_libsvm_data.txt")

    // Split data into training/test sets (unseeded, so the split varies per run).
    val splits = data.randomSplit(Array(0.7, 0.3))
    val (trainingData, testData) = (splits(0), splits(1))

    // Train a GradientBoostedTrees model.
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 3 // Note: Use more in practice
    val model = GradientBoostedTrees.train(trainingData, boostingStrategy)

    // Evaluate model on test instances and compute test error: each
    // misclassified point contributes 1.0, so the mean is the error rate
    // (not the accuracy).
    val testErr = testData.map { point =>
      val prediction = model.predict(point.features)
      if (point.label != prediction) 1.0 else 0.0
    }.mean()

    println(s"Test Error = $testErr")
    println(s"Learned GBT model:\n${model.toDebugString}")
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment