@karenyyng
Created September 12, 2015 22:48
PySpark RandomForestClassifier doesn't return `rawPrediction` column

Example using the ML pipeline API

Logistic regression works with BinaryClassificationEvaluator

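```python
from __future__ import print_function

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.mllib.linalg import Vectors

# `sqlContext` is the SQLContext created by the PySpark shell (Spark 1.x).
dataset = sqlContext.createDataFrame(
    [(Vectors.dense([0.0]), 0.0),
     (Vectors.dense([0.4]), 1.0),
     (Vectors.dense([0.5]), 0.0),
     (Vectors.dense([0.6]), 1.0),
     (Vectors.dense([1.0]), 1.0)] * 10,
    ["features", "label"])

# Default metric is area under ROC, computed from the rawPrediction column.
evaluator = BinaryClassificationEvaluator()
lr = LogisticRegression()
# Two candidate models: maxIter = 0 and maxIter = 1.
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                    evaluator=evaluator)
cvModel = cv.fit(dataset)
transformed_dataset = cvModel.transform(dataset)

print(evaluator.evaluate(transformed_dataset))
# outputs: 0.8333333333

print(transformed_dataset)
# outputs: DataFrame[features: vector, label: double, rawPrediction: vector, probability: vector, prediction: double]
```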

The output of the LogisticRegression model contains both the `rawPrediction` and `probability` columns, so BinaryClassificationEvaluator can score it.
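As background on why the evaluator needs `rawPrediction`: BinaryClassificationEvaluator reads a score from that column (for a vector, the entry for class 1) and by default reports the area under the ROC curve. The pure-Python sketch below computes that metric as the fraction of (positive, negative) pairs in which the positive row has the higher score, counting ties as half. `area_under_roc` is a hypothetical helper for illustration, not Spark's implementation, and using the raw feature value as the score is only an assumption for the example.

```python
def area_under_roc(scores, labels):
    """AUC as the probability that a random positive outranks a random negative.

    Ties count as half a win. This is O(P * N), fine for a small illustration.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1.0]
    neg = [s for s, y in zip(scores, labels) if y == 0.0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Same features and labels as the gist's dataset; the feature itself is the score.
features = [0.0, 0.4, 0.5, 0.6, 1.0] * 10
labels = [0.0, 1.0, 0.0, 1.0, 1.0] * 10
print(area_under_roc(features, labels))  # 0.8333333333333334
```

A column of hard 0/1 predictions produces many ties and therefore a much coarser ranking, which is why the evaluator works from a continuous score column rather than from `prediction`.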

RandomForestClassifier does not return the same columns
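The random forest example first converts `label` with StringIndexer, which assigns indices by label frequency so that the most frequent value gets index 0.0. The pure-Python sketch below mimics that default ordering; `string_index` is a hypothetical helper, not a PySpark function, and breaking ties by string value is an assumption. On this gist's labels, 1.0 (30 rows) maps to index 0.0 and 0.0 (20 rows) maps to index 1.0, so `ix_label` is the reverse of `label`.

```python
from collections import Counter


def string_index(labels):
    """Map each distinct label to a float index, most frequent label first.

    Hypothetical helper mimicking StringIndexer's frequency ordering.
    Labels are compared as strings; ties are broken by string value (assumption).
    """
    counts = Counter(str(label) for label in labels)
    # Sort by descending count, then by string value for a deterministic order.
    ordered = sorted(counts, key=lambda s: (-counts[s], s))
    mapping = {s: float(i) for i, s in enumerate(ordered)}
    return [mapping[str(label)] for label in labels], mapping


labels = [0.0, 1.0, 0.0, 1.0, 1.0] * 10
indexed, mapping = string_index(labels)
print(mapping)  # {'1.0': 0.0, '0.0': 1.0}
```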

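```python
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import StringIndexer

rf = (RandomForestClassifier()
      .setFeaturesCol("features")
      .setLabelCol("ix_label"))

# Index the label column. StringIndexer also records the label values as
# column metadata, which the Spark 1.x tree classifiers use to get the
# number of classes.
stringIndexer = StringIndexer(inputCol="label",
                              outputCol="ix_label")

ix_model = stringIndexer.fit(dataset)
ix_dataset = ix_model.transform(dataset)

rf_model = rf.fit(ix_dataset)
transformed_dataset = rf_model.transform(ix_dataset)

print(transformed_dataset)
# outputs: DataFrame[features: vector, label: double, ix_label: double, prediction: double]
```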

The `rawPrediction` and `probability` columns are missing, so this output cannot be passed to BinaryClassificationEvaluator, which reads `rawPrediction` by default.
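For context on what the missing columns would hold: a random forest's raw prediction is usually an aggregate over trees, either vote counts or summed per-tree class distributions depending on the implementation and version, and the probability is that vector normalized to sum to 1. The sketch below assumes each tree reports a class-probability vector for one row; `aggregate_forest` is a hypothetical helper and the per-tree outputs are made-up inputs, not Spark's implementation or output.

```python
def aggregate_forest(tree_outputs):
    """Combine per-tree class distributions for one row.

    Returns (raw_prediction, probability, prediction), mirroring the three
    columns a probabilistic classifier typically emits.
    """
    num_classes = len(tree_outputs[0])
    raw = [0.0] * num_classes
    for dist in tree_outputs:
        if len(dist) != num_classes:
            raise ValueError("all trees must report the same number of classes")
        for k, p in enumerate(dist):
            raw[k] += p  # sum each class's weight across trees
    total = sum(raw)
    probability = [r / total for r in raw]
    # Predicted class is the index with the largest aggregated weight.
    prediction = float(max(range(num_classes), key=lambda k: raw[k]))
    return raw, probability, prediction


# Three hypothetical trees voting on one row with two classes.
raw, prob, pred = aggregate_forest([[1.0, 0.0], [0.25, 0.75], [0.0, 1.0]])
print(raw, pred)  # [1.25, 1.75] 1.0
```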
