SparkML Naïve Bayes Script with Cross-Validation and Parameter Sweep
########################################
## Title: Spark MLlib Naïve Bayes Classification Script, with Cross-Validation and Parameter Sweep
## Language: PySpark
## Author: Colby T. Ford, Ph.D.
########################################
from pyspark.ml.classification import NaiveBayes
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.mllib.evaluation import BinaryClassificationMetrics
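
# --- Sketch (assumption, not part of the original gist): one way to build the `train`
# --- and `test` DataFrames that the rest of this script expects. The randomly generated
# --- toy rows and the column names ("x1", "x2") are placeholders for your own data.
import random
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler

spark = SparkSession.builder.getOrCreate()
random.seed(42)
toy_rows = [(float(i % 2), float(random.randint(0, 5)), float(random.randint(0, 5)))
            for i in range(200)]
raw = spark.createDataFrame(toy_rows, ["label", "x1", "x2"])
# Naïve Bayes expects non-negative feature values, hence the 0-5 range above
assembler = VectorAssembler(inputCols=["x1", "x2"], outputCol="features")
data = assembler.transform(raw).select("label", "features")
train, test = data.randomSplit([0.8, 0.2], seed=42)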
# Create initial Naïve Bayes model
nb = NaiveBayes(labelCol="label", featuresCol="features")
# Create ParamGrid for Cross Validation
nbparamGrid = (ParamGridBuilder()
               .addGrid(nb.smoothing, [0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
               .build())
# Evaluate model (BinaryClassificationEvaluator's default metric is areaUnderROC)
nbevaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction", metricName="areaUnderROC")
# Create 5-fold CrossValidator
nbcv = CrossValidator(estimator = nb,
estimatorParamMaps = nbparamGrid,
evaluator = nbevaluator,
numFolds = 5)
# Run cross-validation over the parameter grid
# (assumes `train` is a DataFrame with "label" and "features" columns, as prepared above)
nbcvModel = nbcv.fit(train)
print(nbcvModel)
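
# Sketch (assumption, not in the original gist): inspect the results of the parameter
# sweep. `avgMetrics` holds the mean cross-validation metric for each entry in
# nbparamGrid, and `bestModel` is the model refit on the full training set.
for params, metric in zip(nbparamGrid, nbcvModel.avgMetrics):
    print({p.name: v for p, v in params.items()}, "->", metric)
print("Best model:", nbcvModel.bestModel)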
# Use the held-out test set so we can measure performance on unseen data
nbpredictions = nbcvModel.transform(test)
# nbcvModel.transform() uses the best model found during cross-validation
# Evaluate the best model (the evaluator's metric here is areaUnderROC)
print('AUC (evaluator):', nbevaluator.evaluate(nbpredictions))
# BinaryClassificationMetrics expects an RDD of (score, label) pairs
nbmetrics = BinaryClassificationMetrics(nbpredictions.select("prediction", "label").rdd.map(tuple))
print('AUC (RDD metrics):', nbmetrics.areaUnderROC)
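
# Sketch (assumption, not in the original gist): if plain classification accuracy is
# also of interest, the multiclass evaluator reports it from the "prediction" column.
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
accevaluator = MulticlassClassificationEvaluator(labelCol="label",
                                                 predictionCol="prediction",
                                                 metricName="accuracy")
print('Accuracy:', accevaluator.evaluate(nbpredictions))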