Decision Tree
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import scala.Tuple2;

// trainingData and testData are assumed to be JavaRDD<LabeledPoint> defined elsewhere.
// An empty categoricalFeaturesInfo map means every feature is treated as continuous.
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
int numClasses = 4;
String impurity = "gini";
int maxDepth = 9;
int maxBins = 32;
// Train the decision tree classifier
final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);
// Evaluate the model on the test instances and compute the test error
JavaPairRDD<Double, Double> predictionAndLabel = testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
Double testErrDT = 1.0 * predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count() / testData.count();
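
The test error above is the fraction of test instances whose predicted label differs from the true label. As a minimal sketch of that same arithmetic without Spark, the class below computes it over plain lists. The class name `ErrorRateSketch`, the method `errorRate`, and the sample values are hypothetical and used only for illustration; they are not part of the original snippet or of MLlib.

```java
import java.util.List;

public class ErrorRateSketch {

    /** Returns the fraction of positions where the prediction differs from the label. */
    public static double errorRate(List<Double> predictions, List<Double> labels) {
        if (predictions.isEmpty() || predictions.size() != labels.size()) {
            throw new IllegalArgumentException("lists must be non-empty and of equal length");
        }
        long wrong = 0;
        for (int i = 0; i < predictions.size(); i++) {
            // Double.equals compares values exactly, as the Spark filter above does.
            if (!predictions.get(i).equals(labels.get(i))) {
                wrong++;
            }
        }
        // Cast before dividing to avoid integer division.
        return (double) wrong / predictions.size();
    }

    public static void main(String[] args) {
        // Made-up sample data: one of the four predictions is wrong.
        List<Double> predictions = List.of(0.0, 1.0, 2.0, 3.0);
        List<Double> labels = List.of(0.0, 1.0, 1.0, 3.0);
        System.out.println(errorRate(predictions, labels)); // prints 0.25
    }
}
```

Exact equality is safe here because class labels are whole numbers stored as doubles; it would not be appropriate for comparing regression outputs.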