This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example: evaluating a multi-label classifier with Spark's MultilabelMetrics.
from pyspark.mllib.evaluation import MultilabelMetrics

# Each element pairs a predicted label vector with the true label vector for
# one instance; entry i is 1.0 when label i is assigned, 0.0 otherwise.
# NOTE(review): assumes `sc` is a live SparkContext in the enclosing
# session (e.g. a pyspark shell) — confirm against the calling environment.
scoreAndLabels = sc.parallelize([
    ([0.0, 1.0, 1.0], [1.0, 0.0, 0.0]),
    ([0.0, 0.0, 1.0], [0.0, 0.0, 1.0]),
    ([1.0, 0.0, 0.0], [1.0, 0.0, 0.0]),
    ([0.0, 1.0, 0.0], [0.0, 1.0, 0.0])])

# MultilabelMetrics wraps the RDD of (prediction, label) pairs and exposes
# precision/recall/F1 measures, both per-label and aggregated.
metrics = MultilabelMetrics(scoreAndLabels)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def transform(self, df): | |
""" | |
Make predictions for each instance | |
:param df: dataframe with a `features` column | |
:type df: pyspark.sql.DataFrame | |
:return: prediction vectors for each instance | |
:rtype: pyspark.sql.DataFrame | |
""" | |
model_preds = [] | |
for i, model in enumerate(self.fitted_mlc.stages[0].models): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pyspark.ml.classification import OneVsRest, OneVsRestModel
from pyspark.ml.linalg import DenseVector
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.sql.functions import (
    collect_list,
    desc,
    lit,
    monotonically_increasing_id,
    udf,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def preprocess(df): | |
""" | |
Prepare data for multi-label classifier | |
:param df: dataframe with features and labels for each instance | |
:type df: pyspark.sql.DataFrame | |
""" | |
cat_cols = df.schema.fieldNames()[:-1] | |
indexers = [ | |
StringIndexer( |