@mkaranasou
Created March 19, 2021 16:58
Get feature permutations, one for each row, using pyspark
from pyspark.sql import DataFrame
from pyspark.sql import functions as F


def get_features_permutations(
    df: DataFrame,
    feature_names: list,
    output_col='features_permutations'
):
    """
    Creates an array column holding the feature names and shuffles it
    independently for each row.
    The result is a dataframe with a column `output_col` that contains:
    [feat2, feat4, feat3, feat1],
    [feat3, feat4, feat2, feat1],
    [feat1, feat2, feat4, feat3],
    ...
    """
    return df.withColumn(
        output_col,
        # F.shuffle (Spark 2.4+) is non-deterministic: each row gets its own random order
        F.shuffle(
            # Build a constant array of the feature names as string literals
            F.array(*[F.lit(f) for f in feature_names])
        )
    )
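
To reason about the output without starting a Spark session, the per-row behavior can be sketched in plain Python. This is only an analogue under stated assumptions, not the Spark implementation: `permute_features_locally` and its `seed` parameter are hypothetical names introduced here for illustration, and `F.shuffle` itself accepts no seed.

```python
import random


def permute_features_locally(rows, feature_names,
                             output_col="features_permutations", seed=None):
    """Plain-Python analogue: attach an independent random ordering
    of feature_names to every row (rows are dicts here, not Spark Rows)."""
    rng = random.Random(seed)  # hypothetical seed, only for repeatable local runs
    result = []
    for row in rows:
        new_row = dict(row)  # copy so the input rows are left untouched
        # sample with k == len(...) returns a new list in random order
        new_row[output_col] = rng.sample(feature_names, k=len(feature_names))
        result.append(new_row)
    return result


rows = [{"id": 1}, {"id": 2}, {"id": 3}]
features = ["feat1", "feat2", "feat3", "feat4"]
permuted = permute_features_locally(rows, features, seed=0)
for r in permuted:
    print(r["id"], r["features_permutations"])
```

Every row ends up with the same set of feature names in its own order, which is the property the Spark version provides through `output_col`.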