SparkML Random Forest Classification Script with Cross-Validation and Parameter Sweep
########################################
## Title: Spark MLlib Random Forest Classification Script, with Cross-Validation and Parameter Sweep
## Language: PySpark
## Author: Colby T. Ford, Ph.D.
########################################
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.mllib.evaluation import BinaryClassificationMetrics
#from mmlspark import ComputeModelStatistics
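# The code below assumes pre-built `train` and `test` DataFrames, each with a
# "label" column and an assembled "features" vector column. A minimal sketch of
# preparing them from a hypothetical raw DataFrame `df` (column names are
# placeholders, not part of the original script):
# from pyspark.ml.feature import VectorAssembler
# assembler = VectorAssembler(inputCols=["col1", "col2", "col3"], outputCol="features")
# data = assembler.transform(df).select("label", "features")
# train, test = data.randomSplit([0.8, 0.2], seed=42)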
# Create an initial RandomForest model.
rf = RandomForestClassifier(labelCol="label", featuresCol="features")
# Evaluate model
rfevaluator = BinaryClassificationEvaluator()
# Create ParamGrid for Cross Validation
rfparamGrid = (ParamGridBuilder()
               .addGrid(rf.maxDepth, [2, 5, 10, 20, 30])
               .addGrid(rf.maxBins, [10, 20, 40, 80, 100])
               .addGrid(rf.numTrees, [5, 20, 50, 100, 500])
               .build())
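# Note: 5 x 5 x 5 = 125 parameter combinations; with 5-fold cross-validation
# that means 625 model fits, so trim the grid for large datasets if needed.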
# Create 5-fold CrossValidator
rfcv = CrossValidator(estimator=rf,
                      estimatorParamMaps=rfparamGrid,
                      evaluator=rfevaluator,
                      numFolds=5)
# Run cross validations.
rfcvModel = rfcv.fit(train)
print(rfcvModel)
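# Optionally inspect the winning model from the sweep (standard CrossValidatorModel
# and RandomForestClassificationModel attributes, shown here as a sketch):
# print(rfcvModel.bestModel.extractParamMap())   # hyperparameters of the best model
# print(rfcvModel.avgMetrics)                    # mean evaluator metric per grid point
# print(rfcvModel.bestModel.featureImportances)  # per-feature importance vector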
# Use the test set so we can measure the model's performance on new, unseen data
rfpredictions = rfcvModel.transform(test)
# cvModel uses the best model found from the Cross Validation
# Evaluate best model
# BinaryClassificationEvaluator's default metric is areaUnderROC (computed from rawPrediction)
print('AUC (evaluator):', rfevaluator.evaluate(rfpredictions))
# BinaryClassificationMetrics expects an RDD of (score, label) pairs
print('AUC (RDD metrics):', BinaryClassificationMetrics(rfpredictions.select('prediction', 'label').rdd.map(tuple)).areaUnderROC)
#ComputeModelStatistics().transform(rfpredictions)
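# Optionally persist the best model for later scoring (path is a placeholder):
# rfcvModel.bestModel.write().overwrite().save("/path/to/rf_best_model")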