@colbyford
Last active September 23, 2022 16:42
SparkML Decision Tree Regression Script with Cross-Validation and Parameter Sweep
########################################
## Title: Spark MLlib Decision Tree Regression Script, with Cross-Validation and Parameter Sweep
## Language: PySpark
## Author: Colby T. Ford, Ph.D.
########################################
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import RegressionEvaluator
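# Assumes `train` and `test` are existing Spark DataFrames with a numeric
# "label" column and a vector-typed "features" column (e.g. built with VectorAssembler).
# Create the base decision tree estimator
dt = DecisionTreeRegressor(labelCol="label", featuresCol="features")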
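# Build the parameter grid for the sweep: 5 maxDepth x 5 maxBins = 25 combinations
# (Spark only accepts maxDepth values from 0 to 30)
dtparamGrid = (ParamGridBuilder()
               .addGrid(dt.maxDepth, [2, 5, 10, 20, 30])
               #.addGrid(dt.maxDepth, [2, 5, 10])   # smaller grid for quicker runs
               .addGrid(dt.maxBins, [10, 20, 40, 80, 100])
               #.addGrid(dt.maxBins, [10, 20])      # smaller grid for quicker runs
               .build())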
# Evaluator that scores predictions by root-mean-square error (lower is better)
dtevaluator = RegressionEvaluator(predictionCol="prediction", labelCol="label", metricName="rmse")
# Create a 5-fold CrossValidator: every parameter combination is trained and scored on each fold
dtcv = CrossValidator(estimator=dt,
                      estimatorParamMaps=dtparamGrid,
                      evaluator=dtevaluator,
                      numFolds=5)
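# Run the cross-validation; the best combination is then refit on all of `train`
dtcvModel = dtcv.fit(train)
print(dtcvModel)
# Inspect the winning tree and the mean cross-validated RMSE of each combination
print('Best tree depth:', dtcvModel.bestModel.depth, '| nodes:', dtcvModel.bestModel.numNodes)
print('Average CV RMSE per parameter combination:', dtcvModel.avgMetrics)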
# dtcvModel.transform() applies the best model found during cross-validation.
# Score it on the held-out test set to estimate the error on unseen data.
dtpredictions = dtcvModel.transform(test)
print('RMSE:', dtevaluator.evaluate(dtpredictions))
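A note on cost: the grid crosses 5 `maxDepth` values with 5 `maxBins` values, so a 5-fold `CrossValidator` trains 25 × 5 = 125 trees, and PySpark then refits the winning combination once more on the full training set. The plain-Python sketch below only does that arithmetic; `count_fits` is a hypothetical helper, not part of Spark. It also shows how much the smaller, commented-out grids would save.

```python
from math import prod

def count_fits(grid_axes, num_folds):
    """Return (parameter combinations, total model fits) for a k-fold grid search.

    Hypothetical helper: assumes one fit per combination per fold, plus one
    final refit of the best combination on the full training set.
    """
    n_combos = prod(len(values) for values in grid_axes)
    return n_combos, n_combos * num_folds + 1

full_grid = [[2, 5, 10, 20, 30], [10, 20, 40, 80, 100]]  # maxDepth, maxBins
small_grid = [[2, 5, 10], [10, 20]]                       # the commented-out alternatives

print(count_fits(full_grid, num_folds=5))   # (25, 126)
print(count_fits(small_grid, num_folds=5))  # (6, 31)
```

If the full sweep is too slow, `CrossValidator` also accepts a `parallelism` argument (Spark 2.3+) that fits several parameter combinations concurrently.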
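With `metricName="rmse"`, `RegressionEvaluator` reports the root-mean-square error between the `label` and `prediction` columns, sqrt(mean((label - prediction)^2)), in the same units as the label. Because smaller is better for RMSE, the evaluator's `isLargerBetter()` returns `False`, so `CrossValidator` keeps the combination with the lowest average RMSE. A minimal plain-Python sketch of the formula (the function name `rmse` and the sample values are illustrative only):

```python
import math

def rmse(labels, predictions):
    """Root-mean-square error: sqrt(mean((label - prediction) ** 2))."""
    if not labels or len(labels) != len(predictions):
        raise ValueError("labels and predictions must be non-empty and equal in length")
    squared_errors = [(y - p) ** 2 for y, p in zip(labels, predictions)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))

# Errors of 1, 0 and -2 give sqrt((1 + 0 + 4) / 3), roughly 1.291
print(rmse([3.0, 5.0, 2.0], [2.0, 5.0, 4.0]))
```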
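Conceptually, `dtcv.fit(train)` runs a loop like the one below: for each parameter combination it trains on k - 1 folds, scores the held-out fold, averages the k scores (the values exposed as `dtcvModel.avgMetrics`), picks the combination with the lowest average, and refits it on all of the data. This is a plain-Python sketch under stated assumptions, not Spark's implementation: a hypothetical 1-D k-nearest-neighbours averager (`fit_knn`, tuned over `k`) stands in for `DecisionTreeRegressor`, folds are assigned round-robin (Spark assigns rows to folds with a seeded random column), and the data is a noiseless toy line y = 0.5x.

```python
import math

def rmse(labels, predictions):
    return math.sqrt(sum((y - p) ** 2 for y, p in zip(labels, predictions)) / len(labels))

def fit_knn(xs, ys, k):
    """Hypothetical stand-in model: predict the mean label of the k nearest training x values."""
    points = list(zip(xs, ys))
    def predict(x):
        nearest = sorted(points, key=lambda p: abs(p[0] - x))[:k]
        return sum(y for _, y in nearest) / len(nearest)
    return predict

def cross_validate(xs, ys, fit, param_maps, num_folds):
    """k-fold grid search that keeps the parameters with the lowest mean RMSE."""
    n = len(xs)
    # Round-robin fold assignment (an assumption; Spark uses random assignment)
    folds = [[i for i in range(n) if i % num_folds == f] for f in range(num_folds)]
    avg_metrics = []
    for params in param_maps:
        fold_scores = []
        for held_out in folds:
            held = set(held_out)
            train_idx = [i for i in range(n) if i not in held]
            model = fit([xs[i] for i in train_idx], [ys[i] for i in train_idx], **params)
            preds = [model(xs[i]) for i in held_out]
            fold_scores.append(rmse([ys[i] for i in held_out], preds))
        avg_metrics.append(sum(fold_scores) / num_folds)
    # Lower RMSE is better, so select the minimum average
    best_index = min(range(len(param_maps)), key=lambda j: avg_metrics[j])
    best_model = fit(xs, ys, **param_maps[best_index])  # refit on all of the data
    return best_model, param_maps[best_index], avg_metrics

xs = [float(x) for x in range(20)]
ys = [0.5 * x for x in xs]
grid = [{"k": 1}, {"k": 2}, {"k": 4}]
best_model, best_params, avg_metrics = cross_validate(xs, ys, fit_knn, grid, num_folds=5)
print(best_params, avg_metrics)
```

The design point carried over from `CrossValidator` is that the held-out folds are used only to choose parameters; the returned model is retrained on everything, which is why the script still needs a separate `test` set for an unbiased error estimate.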