@petergarbers
Created May 12, 2015 17:48
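```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils

// One-vs-rest: train one binary random forest per label, then report every
// label whose model predicts 1.0 for a test point.
def magicMultiLabelClassification(): Unit = {
  val conf = new SparkConf().setAppName("Simple Application").setMaster("local[2]")
  val sc = new SparkContext(conf)
  val data = MLUtils.loadLibSVMFile(sc, "test3.txt")
  val splits = data.randomSplit(Array(0.7, 0.3))
  val (trainingData, testData) = (splits(0), splits(1))
  // Inputs for random forest; each model is binary (this label vs. the rest).
  val numClasses = 2
  val categoricalFeaturesInfo = Map[Int, Int]()
  val numTrees = 3 // Use more in practice.
  val featureSubsetStrategy = "auto" // Let the algorithm choose.
  val impurity = "gini"
  val maxDepth = 4
  val maxBins = 32
  // RDD.distinct() deduplicates on the cluster, so no take(...) hack is needed.
  val labels = trainingData.map(_.label).distinct().collect()
  // Relabel instead of filtering: a model trained only on points of label l sees a
  // single class, so each model gets every point, marked 1.0 (label l) or 0.0 (other).
  val models = labels.map { l =>
    val binary = trainingData.map(p => LabeledPoint(if (p.label == l) 1.0 else 0.0, p.features))
    l -> RandomForest.trainClassifier(binary, numClasses, categoricalFeaturesInfo,
      numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
  }
  // For each test point, keep the labels whose model predicts above 0.5 (i.e. votes 1.0).
  val labelsAndPredictions = testData.map { point =>
    (point.label, models.collect { case (l, m) if m.predict(point.features) > 0.5 => l }.toSeq)
  }
  // RDDs are lazy: an action such as take() is what actually runs the predictions.
  labelsAndPredictions.take(10).foreach(println)
  sc.stop()
}
```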
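The one-vs-rest idea behind this function (one binary model per label, then keep every label whose model predicts above 0.5) can be checked without a Spark cluster. The sketch below is plain Scala and only an illustration under stated assumptions: `OneVsRestSketch`, `Point`, `trainCentroid`, `train` and `predictLabels` are hypothetical names, not Spark APIs, and a simple nearest-centroid scorer stands in for `RandomForest.trainClassifier` purely so the example runs on its own. Each binary model is trained on all points, relabelled 1.0 for its own label and 0.0 for everything else.

```scala
// Plain-Scala sketch of one-vs-rest training and thresholding.
// Hypothetical names throughout; trainCentroid is a stand-in for
// RandomForest.trainClassifier, not a Spark API.
object OneVsRestSketch {
  // Mirrors the shape of MLlib's LabeledPoint: a label plus a feature vector.
  final case class Point(label: Double, features: Vector[Double])

  type BinaryModel = Vector[Double] => Double

  private def mean(vs: Seq[Vector[Double]]): Vector[Double] =
    vs.transpose.map(_.sum / vs.size).toVector

  private def sqDist(a: Vector[Double], b: Vector[Double]): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  // Stand-in binary learner: predicts 1.0 when a point is closer to the mean of
  // the positive examples than to the mean of the negative examples, else 0.0.
  def trainCentroid(data: Seq[Point]): BinaryModel = {
    val (pos, neg) = data.partition(_.label == 1.0)
    require(pos.nonEmpty && neg.nonEmpty, "need both positive and negative examples")
    val posMean = mean(pos.map(_.features))
    val negMean = mean(neg.map(_.features))
    f => if (sqDist(f, posMean) < sqDist(f, negMean)) 1.0 else 0.0
  }

  // One model per distinct label, each trained on ALL points relabelled 1.0 / 0.0.
  def train(data: Seq[Point]): Seq[(Double, BinaryModel)] =
    data.map(_.label).distinct.map { l =>
      l -> trainCentroid(data.map(p => p.copy(label = if (p.label == l) 1.0 else 0.0)))
    }

  // Keep every label whose binary model predicts above 0.5.
  def predictLabels(models: Seq[(Double, BinaryModel)], features: Vector[Double]): Seq[Double] =
    models.collect { case (l, model) if model(features) > 0.5 => l }
}
```

For example, with two training points per label clustered around (0, 0) for label 0.0, (10, 0) for label 1.0 and (0, 10) for label 2.0, `predictLabels(train(data), Vector(9.0, 0.5))` returns `Seq(1.0)`. A point that no binary model claims yields an empty sequence, which is the case the 0.5 threshold leaves for the caller to handle.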