Skip to content

Instantly share code, notes, and snippets.

@colbyford
Last active September 23, 2022 16:42
Show Gist options
  • Save colbyford/f1f621cf45c6a62a9269348352f6609f to your computer and use it in GitHub Desktop.
Save colbyford/f1f621cf45c6a62a9269348352f6609f to your computer and use it in GitHub Desktop.
SparkML Random Forest Regression Script with Cross-Validation and Parameter Sweep
########################################
## Title: Spark MLlib Random Forest Regression Script, with Cross-Validation and Parameter Sweep
## Language: PySpark
## Author: Colby T. Ford, Ph.D.
########################################
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import RegressionEvaluator
# NOTE(review): assumes `train` and `test` are Spark DataFrames created earlier
# (not shown here), each with a numeric "label" column and a vector "features"
# column -- confirm against the data-preparation step.

# Shared seed so the forest's random sampling and the cross-validator's fold
# assignment are the same on every run, making the reported RMSE reproducible.
RF_SEED = 42

# Create an initial RandomForest model.
rf = RandomForestRegressor(labelCol="label", featuresCol="features", seed=RF_SEED)

# Score models by root-mean-squared error (lower is better).
rfevaluator = RegressionEvaluator(predictionCol="prediction", labelCol="label", metricName="rmse")

# Create ParamGrid for Cross Validation: 3 x 3 x 3 = 27 combinations, i.e.
# 135 model fits at 5 folds. A wider sweep (maxDepth [2, 5, 10, 20, 30],
# maxBins [10, 20, 40, 80, 100], numTrees [5, 20, 50, 100, 500]) gives 125
# combinations, i.e. 625 fits, so it is much more expensive.
# NOTE(review): Spark requires maxBins >= the number of categories of every
# categorical feature, so maxBins=5 will fail if "features" carries categorical
# metadata with more than 5 levels -- confirm against the feature pipeline.
rfparamGrid = (ParamGridBuilder()
    .addGrid(rf.maxDepth, [2, 5, 10])
    .addGrid(rf.maxBins, [5, 10, 20])
    .addGrid(rf.numTrees, [5, 20, 50])
    .build())

# Create 5-fold CrossValidator; the seed fixes the random fold assignment.
rfcv = CrossValidator(estimator=rf,
                      estimatorParamMaps=rfparamGrid,
                      evaluator=rfevaluator,
                      numFolds=5,
                      seed=RF_SEED)

# Run cross validations. After the sweep, CrossValidator refits the
# best-scoring parameter combination on all of `train`.
rfcvModel = rfcv.fit(train)

# Report which combination won. avgMetrics[i] is the mean cross-validated
# metric for rfparamGrid[i], so pick the best index in whichever direction
# the evaluator prefers (for rmse, smaller is better).
rfcvMetrics = rfcvModel.avgMetrics
rfchoose = max if rfevaluator.isLargerBetter() else min
rfbestIndex = rfchoose(range(len(rfcvMetrics)), key=rfcvMetrics.__getitem__)
rfbestParams = {param.name: value for param, value in rfparamGrid[rfbestIndex].items()}
print('Best cross-validated {}: {}'.format(rfevaluator.getMetricName(), rfcvMetrics[rfbestIndex]))
print('Best params:', rfbestParams)

# Score the held-out test set to estimate error (RMSE) on new data.
# transform() uses rfcvModel.bestModel, i.e. the refit best combination.
rfpredictions = rfcvModel.transform(test)
print('RMSE:', rfevaluator.evaluate(rfpredictions))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment