@gcsfred
Created November 15, 2018 12:38
OneHotEncoder
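from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import OneHotEncoder
# ...

def one_hot_encode(_df, input_column, output_column):
    # Map each distinct string in input_column to a numeric index (most frequent label -> 0.0).
    # handleInvalid='skip' drops rows whose label cannot be indexed instead of raising an error.
    indexer = StringIndexer(inputCol=input_column, outputCol=input_column + "_indexed", handleInvalid='skip')
    _model = indexer.fit(_df)
    _td = _model.transform(_df)
    # Turn each index into a sparse binary vector; dropLast=True omits the last category,
    # so n categories yield vectors of length n - 1.
    encoder = OneHotEncoder(inputCol=input_column + "_indexed", outputCol=output_column, dropLast=True)
    # Spark 3.x: OneHotEncoder is an Estimator and must be fit before transforming.
    # (In Spark 2.x it was a Transformer, and encoder.transform(_td) worked directly.)
    _df2 = encoder.fit(_td).transform(_td)
    return _df2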
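To see what the two stages compute without starting Spark, here is a plain-Python sketch of the same mapping. It assumes StringIndexer's default `frequencyDesc` ordering with ties broken alphabetically (the Spark 3.x behaviour), and it uses dense lists for readability where Spark would produce `SparseVector`s. The helper names `string_index`, `one_hot` and `index_skipping_unseen` are hypothetical and exist only for this illustration.

```python
from collections import Counter

def string_index(values):
    """Mimic StringIndexer: most frequent label gets 0.0; ties are sorted alphabetically."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(i) for i, label in enumerate(ordered)}

def index_skipping_unseen(mapping, values):
    """Mimic handleInvalid='skip': drop labels the indexer never saw."""
    return [mapping[v] for v in values if v in mapping]

def one_hot(index, num_categories, drop_last=True):
    """Mimic OneHotEncoder: with drop_last, the last category encodes as all zeros."""
    size = num_categories - 1 if drop_last else num_categories
    vec = [0.0] * size
    if int(index) < size:
        vec[int(index)] = 1.0
    return vec

colors = ["red", "blue", "red", "green", "red", "blue"]
mapping = string_index(colors)  # red (3 rows) -> 0.0, blue (2) -> 1.0, green (1) -> 2.0
encoded = [one_hot(mapping[c], len(mapping)) for c in colors]
for color, vec in zip(colors, encoded):
    print(color, mapping[color], vec)
```

With three categories and `dropLast=True`, each vector has length 2, and the least frequent label (`green`) becomes the all-zeros reference category. Dropping one column this way avoids perfectly collinear features in linear models.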