Skip to content

Instantly share code, notes, and snippets.

@paul-brebner
Created September 27, 2017 06:43
Show Gist options
  • Save paul-brebner/71174cae87887a94ad4707e9a8f1741c to your computer and use it in GitHub Desktop.
Simple Spark MLLib Decision Tree Example (RDD)
package spark1;
import scala.Tuple2;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.SparkContext;
/**
 * Simple Spark MLlib decision-tree classification example using the RDD API.
 *
 * <p>Loads a LibSVM-format dataset, prints the class balance, trains a binary
 * decision-tree classifier on a 60/40 train/test split, and reports
 * threshold-based precision, recall, and F-measure on the held-out test set.
 */
public class DecisionTreeBlog2 {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("Java Decision Tree Classification Example");
        conf.setMaster("local");
        SparkContext sc = new SparkContext(conf);
        try {
            String path = "WillTheMonolithReact.txt";
            // Cache here: this RDD is traversed more than once (count, filter, split).
            JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD().cache();
            // count() is an action, so it triggers the first evaluation of the lineage.
            long n = data.count();
            System.out.println("RDD size = " + n);
            // LabeledPoint is a tuple of (label, features). Count only the positive
            // (label == 1.0) examples; for a binary 0/1-labeled set the negatives are
            // just the remainder, which saves a second full Spark job.
            Function<LabeledPoint, Boolean> label1 = row -> (row.label() == 1.0);
            double pos = (double) data.filter(label1).count();
            double neg = n - pos;
            System.out.println("pos examples = " + pos);
            System.out.println("neg examples = " + neg);
            System.out.println("probability of positive example = " + pos / (double) n);
            System.out.println("probability of negative example = " + neg / (double) n);
            // Split into 60% training data and 40% test data; 11 is the RNG seed.
            JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L);
            JavaRDD<LabeledPoint> trainingData = splits[0].cache(); // reused during training
            JavaRDD<LabeledPoint> testData = splits[1];
            // DecisionTree learning parameters. An empty categoricalFeaturesInfo map
            // indicates that all features are continuous.
            Integer numClasses = 2;
            Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
            String impurity = "gini"; // or "entropy"
            Integer maxDepth = 5;
            Integer maxBins = 32;
            // Train the model (org.apache.spark.mllib.tree.DecisionTree).
            DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
                    categoricalFeaturesInfo, impurity, maxDepth, maxBins);
            System.out.println("Learned classification tree model:\n" + model.toDebugString());
            // For every test example p, produce a (predicted label, actual label) pair,
            // e.g. (1.0,0.0) (0.0,0.0) (0.0,0.0) (0.0,1.0).
            // Cached because BinaryClassificationMetrics evaluates it several times below.
            JavaPairRDD<Object, Object> predictionAndLabels = testData.mapToPair(p ->
                    new Tuple2<>(model.predict(p.features()), p.label())).cache();
            BinaryClassificationMetrics metrics =
                    new BinaryClassificationMetrics(predictionAndLabels.rdd());
            // Decision trees emit hard 0/1 predictions rather than probabilities,
            // so only the 0.0 and 1.0 thresholds appear.
            System.out.println("Thresholds: " + metrics.thresholds());
            JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD();
            System.out.println("Precision by threshold: " + precision.collect());
            JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD();
            System.out.println("Recall by threshold: " + recall.collect());
            JavaRDD<Tuple2<Object, Object>> f = metrics.fMeasureByThreshold().toJavaRDD();
            System.out.println("F by threshold: " + f.collect());
        } finally {
            // Release Spark resources even if training or evaluation fails.
            sc.stop();
        }
    }
}
@paul-brebner
Copy link
Author

paul-brebner commented Jun 19, 2019 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment