Skip to content

Instantly share code, notes, and snippets.

@colbyford
Last active September 23, 2022 16:38
Show Gist options
  • Save colbyford/f488ab3770f9da56f036fe8adbe2a9e5 to your computer and use it in GitHub Desktop.
Save colbyford/f488ab3770f9da56f036fe8adbe2a9e5 to your computer and use it in GitHub Desktop.
SparkML Logistic Regression Classification Script with Cross-Validation and Parameter Sweep
########################################
## Title: Spark MLlib Logistic Regression Classification Script, with Cross-Validation and Parameter Sweep
## Language: PySpark
## Author: Colby T. Ford, Ph.D.
########################################
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.mllib.evaluation import BinaryClassificationMetrics
#from mmlspark import ComputeModelStatistics
# Create initial LogisticRegression model.
# NOTE: maxIter=10 is only the starting value; the parameter grid below
# sweeps maxIter as well, so each CV candidate overrides it.
lr = LogisticRegression(labelCol="label", featuresCol="features", maxIter=10)

# Create ParamGrid for Cross Validation. Every combination is tried:
# 5 x 5 x 5 = 125 settings, each fit once per fold.
lrparamGrid = (ParamGridBuilder()
               .addGrid(lr.regParam, [0.01, 0.1, 0.5, 1.0, 2.0])
               .addGrid(lr.elasticNetParam, [0.0, 0.25, 0.5, 0.75, 1.0])
               .addGrid(lr.maxIter, [1, 5, 10, 20, 50])
               .build())

# Evaluator used both for model selection during CV and for scoring the test
# set. It reports area under the ROC curve (the evaluator's default metric,
# made explicit here) -- it does NOT compute accuracy.
lrevaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction",
                                            labelCol="label",
                                            metricName="areaUnderROC")

# Create 5-fold CrossValidator.
lrcv = CrossValidator(estimator=lr,
                      estimatorParamMaps=lrparamGrid,
                      evaluator=lrevaluator,
                      numFolds=5)

# Run cross validations.
# NOTE(review): `train` / `test` are assumed to be DataFrames defined earlier
# with "features" and "label" columns -- confirm against the caller.
lrcvModel = lrcv.fit(train)
print(lrcvModel)

# avgMetrics[i] is the mean CV metric for lrparamGrid[i]; areaUnderROC is
# larger-is-better, so the best setting is the argmax.
best_idx = max(range(len(lrcvModel.avgMetrics)),
               key=lambda i: lrcvModel.avgMetrics[i])
print('Best CV AUC:', lrcvModel.avgMetrics[best_idx])
print('Best params:',
      {param.name: value for param, value in lrparamGrid[best_idx].items()})

# Use the held-out test set to measure performance on new data.
# transform() on the CrossValidatorModel applies the best model found by CV.
lrpredictions = lrcvModel.transform(test)

# Accuracy: fraction of test rows whose predicted class equals the label.
n_test = lrpredictions.count()
n_correct = lrpredictions.filter(
    lrpredictions["prediction"] == lrpredictions["label"]).count()
print('Accuracy:', n_correct / n_test if n_test else float('nan'))

# AUC from the ml evaluator (areaUnderROC computed on rawPrediction).
print('AUC:', lrevaluator.evaluate(lrpredictions))

# mllib BinaryClassificationMetrics expects an RDD of (score, label) float
# pairs, in that order. Use P(class 1) as the continuous score; hard 0/1
# predictions would collapse the ROC curve to a single point.
# NOTE(review): assumes a binary 0/1 label -- confirm for this dataset.
score_and_labels = (lrpredictions.select("probability", "label").rdd
                    .map(lambda row: (float(row["probability"][1]),
                                      float(row["label"]))))
lrmetrics = BinaryClassificationMetrics(score_and_labels)
print('AUC (mllib):', lrmetrics.areaUnderROC)
print('AUPR:', lrmetrics.areaUnderPR)

# Optional (requires mmlspark / SynapseML, see the commented import above):
# ComputeModelStatistics().transform(lrpredictions)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment