Skip to content

Instantly share code, notes, and snippets.

@colbyford
Last active September 23, 2022 16:41
Show Gist options
  • Save colbyford/83978917799dbcab6293521a60f29e94 to your computer and use it in GitHub Desktop.
SparkML Data Preparation Steps for Binary Classification Models
########################################
## Title: Spark MLlib Classification Data Prep Script
## Language: PySpark
## Author: Colby T. Ford, Ph.D.
########################################
from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder, OneHotEncoderEstimator, StringIndexer, VectorAssembler
# ---------------------------------------------------------------------
# Imperative (non-pipeline) data prep: index + one-hot encode each
# categorical column, index the label, assemble a feature vector, and
# split into train/test sets.
# NOTE(review): `dataset` must already be a Spark DataFrame in scope.
# ---------------------------------------------------------------------
label = "dependentvar"
categoricalColumns = ["col1",
                      "col2"]
numericalColumns = ["num1",
                    "num2"]

# Names of the one-hot output columns derived below, e.g. "col1classVec".
categoricalColumnsclassVec = [c + "classVec" for c in categoricalColumns]

for categoricalColumn in categoricalColumns:
    # Category indexing with StringIndexer; "skip" drops rows whose
    # category is invalid/unseen instead of raising at transform time.
    stringIndexer = StringIndexer(inputCol=categoricalColumn, outputCol=categoricalColumn + "Index").setHandleInvalid("skip")
    dataset = stringIndexer.fit(dataset).transform(dataset)
    # Use OneHotEncoder to convert the category index into a binary SparseVector.
    # NOTE(review): transform-without-fit is Spark 2.x OneHotEncoder behavior;
    # in Spark 3.x OneHotEncoder is an estimator and must be fit first — confirm
    # the target Spark version.
    encoder = OneHotEncoder(inputCol=categoricalColumn + "Index", outputCol=categoricalColumn + "classVec")
    dataset = encoder.transform(dataset)

# Convert the label into label indices using the StringIndexer.
label_stringIndexer = StringIndexer(inputCol=label, outputCol="label").setHandleInvalid("skip")
dataset = label_stringIndexer.fit(dataset).transform(dataset)

# Transform all features into a single vector column using VectorAssembler.
assemblerInputs = categoricalColumnsclassVec + numericalColumns
print(assemblerInputs)
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")
# na.drop(): VectorAssembler cannot handle nulls in its input columns.
dataset = assembler.transform(dataset.na.drop())

# Keep only the columns the model needs.
selectedcols = ["label", "features"]
dataset = dataset.select(selectedcols)
dataset.show()

# Split data into train and test sets (75/25, fixed seed for reproducibility).
train, test = dataset.randomSplit([0.75, 0.25], seed=1337)
display(train)
################# PIPELINE VERSION #######################
# Same preparation expressed as a single Pipeline: stages are only
# declared here and all execute at once when the pipeline is fit.
from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder, OneHotEncoderEstimator, StringIndexer, VectorAssembler

label = "dependentvar"
categoricalColumns = ["col1",
                      "col2"]
numericalColumns = ["num1",
                    "num2"]

# Names of the one-hot output columns derived below, e.g. "col1classVec".
categoricalColumnsclassVec = [c + "classVec" for c in categoricalColumns]

stages = []
for categoricalColumn in categoricalColumns:
    print(categoricalColumn)
    # Category indexing with StringIndexer; "skip" drops rows whose
    # category is invalid/unseen instead of raising at transform time.
    stringIndexer = StringIndexer(inputCol=categoricalColumn, outputCol=categoricalColumn + "Index").setHandleInvalid("skip")
    # Use OneHotEncoder to convert the category index into a binary SparseVector.
    encoder = OneHotEncoder(inputCol=categoricalColumn + "Index", outputCol=categoricalColumn + "classVec")
    # Add stages. These are not run here, but will run all at once later on.
    stages += [stringIndexer, encoder]

# Convert the label into label indices using the StringIndexer.
label_stringIndexer = StringIndexer(inputCol=label, outputCol="label").setHandleInvalid("skip")
stages += [label_stringIndexer]

# Transform all features into a single vector column using VectorAssembler.
assemblerInputs = categoricalColumnsclassVec + numericalColumns
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")
stages += [assembler]

# Fit the whole preparation pipeline in one pass, then apply it.
# NOTE(review): `dataset` must already be a Spark DataFrame in scope.
prepPipeline = Pipeline().setStages(stages)
pipelineModel = prepPipeline.fit(dataset)
dataset = pipelineModel.transform(dataset)
## Persist the fitted transformation pipeline, then reload and re-apply it.
# Hoist the storage path so save, listing, and load all agree on it.
pipeline_path = "/mnt/<YOURMOUNTEDSTORAGE>/pipeline"
pipelineModel.save(pipeline_path)
display(dbutils.fs.ls(pipeline_path))

## Read the transformation pipeline back in and transform the data with it.
from pyspark.ml import PipelineModel
pipelineModel = PipelineModel.load(pipeline_path)
dataset = pipelineModel.transform(dataset)
display(dataset)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment