@colbyford
Last active September 23, 2022 16:40
SparkML Decision Tree Classification Script with Cross-Validation and Parameter Sweep
########################################
## Title: Spark MLlib Decision Tree Classification Script, with Cross-Validation and Parameter Sweep
## Language: PySpark
## Author: Colby T. Ford, Ph.D.
########################################
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.mllib.evaluation import BinaryClassificationMetrics
#from mmlspark import ComputeModelStatistics
# Create initial Decision Tree Model
dt = DecisionTreeClassifier(labelCol="label", featuresCol="features", maxDepth=2)
# Create ParamGrid for Cross Validation
dtparamGrid = (ParamGridBuilder()
               .addGrid(dt.maxDepth, [2, 5, 10, 20, 30])
               .addGrid(dt.maxBins, [10, 20, 40, 80, 100])
               .build())
# Create the evaluator (BinaryClassificationEvaluator defaults to the areaUnderROC metric)
dtevaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction")
# Create 5-fold CrossValidator
dtcv = CrossValidator(estimator=dt,
                      estimatorParamMaps=dtparamGrid,
                      evaluator=dtevaluator,
                      numFolds=5)
# Run cross validations
dtcvModel = dtcv.fit(train)
print(dtcvModel)
# Use test set here so we can measure the accuracy of our model on new data
dtpredictions = dtcvModel.transform(test)
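# dtcvModel.transform() uses the best model found during cross-validation
# Evaluate best model
# The evaluator's default metric is areaUnderROC, so this value is an AUC, not an accuracy
print('AUC (rawPrediction):', dtevaluator.evaluate(dtpredictions))
# BinaryClassificationMetrics expects (score, label) pairs of floats, in that order
scoreAndLabels = dtpredictions.select('prediction', 'label').rdd.map(lambda r: (float(r[0]), float(r[1])))
print('AUC (hard predictions):', BinaryClassificationMetrics(scoreAndLabels).areaUnderROC)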
#ComputeModelStatistics().transform(dtpredictions)
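
A rough sketch of what the sweep above costs and how its results could be ranked, in plain Python so it runs without a cluster. The grid values and fold count are copied from the script; `rank_param_maps` and the variable names are hypothetical helpers, not part of Spark.

```python
from itertools import product

# Grid values and fold count copied from the script above.
max_depth_values = [2, 5, 10, 20, 30]
max_bins_values = [10, 20, 40, 80, 100]
num_folds = 5

# The parameter grid is the Cartesian product of the value lists.
param_maps = [{"maxDepth": d, "maxBins": b}
              for d, b in product(max_depth_values, max_bins_values)]

# One fit per (param map, fold), plus one refit of the winner on the full training set.
total_fits = len(param_maps) * num_folds + 1


def rank_param_maps(param_maps, avg_metrics, larger_is_better=True):
    """Pair each param map with its mean cross-validated metric, best first.

    Hypothetical helper: avg_metrics must be in the same order as param_maps.
    """
    paired = list(zip(param_maps, avg_metrics))
    return sorted(paired, key=lambda pair: pair[1], reverse=larger_is_better)


print(len(param_maps), "param maps,", total_fits, "model fits")
```

With a fitted model, `dtcvModel.avgMetrics` holds the mean metric per parameter map in the same order as `dtparamGrid`, so something like `rank_param_maps([{p.name: v for p, v in m.items()} for m in dtparamGrid], dtcvModel.avgMetrics)` should list the combinations from best to worst AUC.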
@ed765super

train and test come about from a randomSplit([0.7, 0.3])?

yep!
