Created
October 1, 2021 15:47
-
-
Save melissakou/f6c3c91a3a1a952f623cd0a7418ea5cd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def key_salting_join(left_df, right_df, left_key, right_key, how, coarseness):
    """Join two Spark DataFrames on a salted key to spread out skewed keys.

    Every left row is tagged with one salt in ``[0, coarseness)`` (derived from
    its row id) and every right row is replicated once per salt. Joining on
    ``"<key>-<salt>"`` therefore spreads the rows of one hot key over up to
    ``coarseness`` distinct join keys while still pairing each left row with
    every right row that has the same original key.

    Args:
        left_df (spark.DataFrame): left dataframe to be joined (its rows are
            salted, not replicated).
        right_df (spark.DataFrame): right dataframe to be joined (its rows are
            replicated ``coarseness`` times).
        left_key (str): key in left dataframe to join on.
        right_key (str): key in right dataframe to join on.
        how (str): join type, argument for pyspark.sql.DataFrame.join.
            (https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrame.join.html)
            Right and full outer joins are rejected, see ``Raises``.
        coarseness (int): how many distinct salt values are added to the
            original key. The larger this value, the more uniform the key
            distribution becomes, but the right dataframe grows by this factor.

    Returns:
        spark.DataFrame: joined dataframe. The temporary salted key is dropped;
        the original key columns of both inputs are kept.

    Raises:
        ValueError: if ``coarseness`` is smaller than 1, or if ``how`` is a
            right/full outer join. Those join types would emit the unmatched
            replicas of the right rows and so return spurious extra rows.

    Note:
        The temporary column names ``"dummy"`` and ``"salted_key"`` are used on
        both inputs; existing columns with these names would be overwritten.
    """
    if coarseness < 1:
        # coarseness == 0 would make `% coarseness` null and explode the right
        # side into zero rows, silently producing a join with no matches.
        raise ValueError(f"coarseness must be a positive integer, got {coarseness!r}")

    # Spark itself ignores case and underscores in join type names.
    unsupported_join_types = {"right", "rightouter", "full", "fullouter", "outer"}
    if isinstance(how, str) and how.lower().replace("_", "") in unsupported_join_types:
        raise ValueError(
            f"join type {how!r} is not supported by key salting: the right "
            "dataframe is replicated, so unmatched replicas would appear in the result"
        )

    def _salted(key_name):
        # concat_ws skips nulls, so a null key would collapse to just the salt
        # and match null keys on the other side. Keep the salted key null
        # instead, so null keys never match (same as a plain equi-join).
        key_col = F.col(key_name)
        return F.when(key_col.isNotNull(), F.concat_ws("-", key_col, F.col("dummy")))

    # tag each left row with one salt value and build its salted key
    left_df = left_df.withColumn("dummy", F.monotonically_increasing_id() % coarseness) \
                     .withColumn("salted_key", _salted(left_key)) \
                     .drop("dummy")

    # replicate each right row once per salt value so every salted left key can match
    right_df = right_df.withColumn("dummy", F.explode(F.array([F.lit(i) for i in range(coarseness)]))) \
                       .withColumn("salted_key", _salted(right_key)) \
                       .drop("dummy")

    # join with the artificial salting key, then discard it
    joined_df = left_df.join(right_df, on="salted_key", how=how).drop("salted_key")

    return joined_df
# Example usage: salted left join on "key", then count the joined rows.
# NOTE(review): assumes df1 and df2 are Spark DataFrames defined beforehand.
key_salting_join(
    left_df=df1,
    right_df=df2,
    left_key="key",
    right_key="key",
    how="left",
    coarseness=50,
).count()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment