@colbyford
Last active January 10, 2024 02:36
SparkML Linear Regression Script with Cross-Validation and Parameter Sweep
########################################
## Title: Spark MLlib Linear Regression Script, with Cross-Validation and Parameter Sweep
## Language: PySpark
## Author: Colby T. Ford, Ph.D.
########################################
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import RegressionEvaluator
# Create initial LinearRegression model
lr = LinearRegression(labelCol="label", featuresCol="features")
# Create ParamGrid for Cross Validation: 6 x 5 x 5 = 150 parameter combinations
# (the commented-out lines are smaller alternative grids that need fewer fits)
lrparamGrid = (ParamGridBuilder()
               .addGrid(lr.regParam, [0.001, 0.01, 0.1, 0.5, 1.0, 2.0])
               # .addGrid(lr.regParam, [0.01, 0.1, 0.5])
               .addGrid(lr.elasticNetParam, [0.0, 0.25, 0.5, 0.75, 1.0])
               # .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])
               .addGrid(lr.maxIter, [1, 5, 10, 20, 50])
               # .addGrid(lr.maxIter, [1, 5, 10])
               .build())
# Define the evaluator: RMSE (root mean squared error) of prediction vs. label; lower is better
lrevaluator = RegressionEvaluator(predictionCol="prediction", labelCol="label", metricName="rmse")
# Create 5-fold CrossValidator
lrcv = CrossValidator(estimator=lr,
                      estimatorParamMaps=lrparamGrid,
                      evaluator=lrevaluator,
                      numFolds=5)
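# Run cross-validation on the training DataFrame `train` (assumed to be defined already,
# with a "features" vector column and a numeric "label" column)
lrcvModel = lrcv.fit(train)
# Prints the model's identifier; lrcvModel.avgMetrics holds the mean RMSE for each parameter map
print(lrcvModel)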
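# Get summary statistics of the best model, which CrossValidator refits on all of `train`.
# Standard errors and p-values are only available from the "normal" solver; if the best model
# used an L1 penalty (elasticNetParam > 0 with regParam > 0), these calls may raise an exception.
lrcvSummary = lrcvModel.bestModel.summary
print("Coefficient Standard Errors: " + str(lrcvSummary.coefficientStandardErrors))
print("P Values: " + str(lrcvSummary.pValues))  # Last element is the intercept (fitIntercept=True by default)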
# Score the held-out `test` DataFrame (assumed to be defined already) to measure the model's error on new data.
# lrcvModel.transform() uses the best model found by cross-validation.
lrpredictions = lrcvModel.transform(test)
# Evaluate the best model on the test set
print('RMSE:', lrevaluator.evaluate(lrpredictions))
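To make the cross-validation step concrete, here is a minimal pure-Python sketch of what `CrossValidator` does with an RMSE evaluator: it scores every parameter map on each of the k folds, averages the RMSE per map (these averages correspond to `avgMetrics` on the fitted `CrossValidatorModel`), and keeps the map with the lowest average, since lower RMSE is better. This is not Spark's implementation. The toy one-feature ridge model (`fit_ridge_1d`), the helper `cross_validate`, the deterministic round-robin fold assignment (Spark assigns rows to folds randomly, controlled by `seed`), and the sample data are all hypothetical stand-ins. The sketch also counts model fits: the 6 x 5 x 5 grid in the script gives 150 parameter maps, so 5-fold cross-validation trains 750 models before the single refit of the best map on all of `train`.

```python
import itertools
import math


def rmse(labels, predictions):
    """Root mean squared error, as RegressionEvaluator computes for metricName="rmse"."""
    n = len(labels)
    return math.sqrt(sum((y - p) ** 2 for y, p in zip(labels, predictions)) / n)


def fit_ridge_1d(xs, ys, reg):
    """Hypothetical toy model: one-feature ridge regression without intercept.

    Minimizes sum((y - w*x)^2) + reg * w^2; its penalty scaling is not Spark's regParam scaling.
    """
    return sum(x * y for x, y in zip(xs, ys)) / (sum(x * x for x in xs) + reg)


def cross_validate(xs, ys, reg_grid, num_folds=5):
    """Return (best_reg, avg_metrics): mean validation RMSE per grid value over num_folds folds."""
    # Round-robin fold assignment for reproducibility (Spark uses a random split instead).
    folds = [i % num_folds for i in range(len(xs))]
    avg_metrics = []
    for reg in reg_grid:
        fold_scores = []
        for k in range(num_folds):
            train_idx = [i for i, f in enumerate(folds) if f != k]
            valid_idx = [i for i, f in enumerate(folds) if f == k]
            w = fit_ridge_1d([xs[i] for i in train_idx], [ys[i] for i in train_idx], reg)
            preds = [w * xs[i] for i in valid_idx]
            fold_scores.append(rmse([ys[i] for i in valid_idx], preds))
        avg_metrics.append(sum(fold_scores) / num_folds)
    # RMSE: smaller is better, so pick the index of the lowest averaged metric.
    best_index = min(range(len(reg_grid)), key=lambda j: avg_metrics[j])
    return reg_grid[best_index], avg_metrics


# Size of the grid used in the script: regParam x elasticNetParam x maxIter.
reg_params = [0.001, 0.01, 0.1, 0.5, 1.0, 2.0]
elastic_net_params = [0.0, 0.25, 0.5, 0.75, 1.0]
max_iters = [1, 5, 10, 20, 50]
num_maps = len(list(itertools.product(reg_params, elastic_net_params, max_iters)))
num_fits = num_maps * 5 + 1  # 5 folds per map, plus one refit of the best map

# Hypothetical toy data: y is roughly 2x with small deterministic noise.
xs = [float(i) for i in range(1, 21)]
ys = [2.0 * x + (0.5 if i % 2 else -0.5) for i, x in enumerate(xs)]
best_reg, avg_metrics = cross_validate(xs, ys, reg_params)

print(num_maps, num_fits)  # 150 751
print(best_reg, [round(m, 3) for m in avg_metrics])
```

The fit count is the main practical takeaway: the full sweep is roughly 150 times more expensive than a single fit, which is why the smaller commented-out grids can be useful while experimenting.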
@shubham12tomar

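Correction for line 39 (the lrcvSummary line), for the case where the estimator passed to CrossValidator is a Pipeline whose last stage is the LinearRegression: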
lrcvSummary = lrcvModel.bestModel.stages[-1].summary
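The two variants can be reconciled. In the script above, the `CrossValidator` estimator is the bare `LinearRegression`, so `bestModel` is a `LinearRegressionModel` and `.summary` works directly. `.stages[-1]` is only needed when the estimator is a `Pipeline` (for example, a feature-assembly stage followed by the regression), because `bestModel` is then a `PipelineModel`. The sketch below uses two hypothetical helpers, `regression_summary` and `label_p_values`. It handles both cases by duck typing and pairs each p-value with a feature name, treating the last value as the intercept (true when `fitIntercept=True`, the default). The `SimpleNamespace` stubs and their numbers are placeholders that only exercise the logic without Spark. With real models you would pass `lrcvModel.bestModel` and your own list of feature names.

```python
from types import SimpleNamespace


def regression_summary(best_model):
    """Return the training summary of a fitted model, unwrapping a PipelineModel if needed.

    Hypothetical helper: a PipelineModel exposes `.stages`, and the regression is assumed
    to be the last stage. A bare LinearRegressionModel has `.summary` directly.
    """
    model = best_model.stages[-1] if hasattr(best_model, "stages") else best_model
    return model.summary


def label_p_values(summary, feature_names):
    """Pair p-values with feature names; the final p-value belongs to the intercept."""
    names = list(feature_names) + ["intercept"]
    if len(names) != len(summary.pValues):
        raise ValueError("expected one p-value per feature plus the intercept")
    return dict(zip(names, summary.pValues))


# Stand-ins for Spark objects (placeholder values, not real results).
fake_summary = SimpleNamespace(pValues=[0.01, 0.2, 0.5])
bare_model = SimpleNamespace(summary=fake_summary)
pipeline_model = SimpleNamespace(stages=["assembler", bare_model])

print(label_p_values(regression_summary(bare_model), ["x1", "x2"]))
# {'x1': 0.01, 'x2': 0.2, 'intercept': 0.5}
print(regression_summary(pipeline_model) is fake_summary)  # True
```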
