Created May 12, 2015 17:48
-
-
Save petergarbers/0dc133d68918861cd879 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * One-vs-rest multi-label classification sketch using MLlib random forests.
 *
 * Loads the LibSVM file "test3.txt" and splits it 70/30 into training and test sets. It then
 * trains one binary random forest per distinct training label: for label `l`, points labelled
 * `l` become 1.0 and every other point becomes 0.0. Finally, each test point is scored by every
 * per-label model.
 *
 * The SparkContext created here is always stopped before returning, even if loading or
 * training throws.
 */
def magicMultiLabelClassification(): Unit = {
  val conf = new SparkConf().setAppName("Simple Application").setMaster("local[2]")
  val sc = new SparkContext(conf)
  try {
    val data = MLUtils.loadLibSVMFile(sc, "test3.txt")
    val splitSeed = 42L // fixed seed so the train/test split is reproducible across runs
    val splits = data.randomSplit(Array(0.7, 0.3), splitSeed)
    val (trainingData, testData) = (splits(0), splits(1))
    // The training set is scanned once per label below, so keep it in memory.
    trainingData.cache()

    // Random forest hyper-parameters, shared by every per-label model.
    val numClasses = 2 // each model is binary: "is label l" (1.0) vs. "is not" (0.0)
    val categoricalFeaturesInfo = Map[Int, Int]() // empty: no feature declared categorical
    val numTrees = 3 // Use more in practice.
    val featureSubsetStrategy = "auto" // Let the algorithm choose.
    val impurity = "gini"
    val maxDepth = 4
    val maxBins = 32

    // Every distinct label in the whole training set (not just a prefix of it), sorted so
    // that models(i) always corresponds to labels(i).
    val labels = trainingData.map(_.label).distinct().collect().sorted

    val models = labels.map { positive =>
      // Relabel rather than filter: a training set that contains only the positive class
      // gives the forest nothing to discriminate against, and raw labels outside {0, 1}
      // are not valid inputs when numClasses = 2.
      val binarized = trainingData.map { p =>
        p.copy(label = if (p.label == positive) 1.0 else 0.0)
      }
      RandomForest.trainClassifier(binarized, numClasses, categoricalFeaturesInfo,
        numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
    }

    // For each test point: its true label, paired with the labels whose model predicted 1.0.
    // This RDD is lazy; any action on it must run inside this try block, before sc.stop().
    val labelsAndPredictions = testData.map { point =>
      val predictedLabels = labels.zip(models).collect {
        case (label, model) if model.predict(point.features) == 1.0 => label
      }
      (point.label, predictedLabels)
    }
  } finally {
    sc.stop()
  }
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.