@melissakou
Created October 1, 2021 15:47
from pyspark.sql import functions as F


def key_salting_join(left_df, right_df, left_key, right_key, how, coarseness):
    """Implementation of a key-salting join.

    Args:
        left_df (spark.DataFrame): left dataframe to be joined.
        right_df (spark.DataFrame): right dataframe to be joined.
        left_key (str): key in the left dataframe to join on.
        right_key (str): key in the right dataframe to join on.
        how (str): join type, passed to pyspark.sql.DataFrame.join.
            (https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrame.join.html)
        coarseness (int): how many distinct random suffixes are appended to the original key.
            The larger this value, the more uniform the key distribution becomes,
            but the more the right dataframe is exploded in size.

    Returns:
        spark.DataFrame: joined dataframe.
    """
    # salt the left key: append a pseudo-random suffix in [0, coarseness) to each row's key
    left_df = left_df.withColumn("dummy", F.monotonically_increasing_id() % coarseness) \
                     .withColumn("salted_key", F.concat_ws("-", F.col(left_key), F.col("dummy"))) \
                     .drop("dummy")
    # explode the right key: replicate each row once per possible suffix so every salted left key finds a match
    right_df = right_df.withColumn("dummy", F.explode(F.array([F.lit(i) for i in range(coarseness)]))) \
                       .withColumn("salted_key", F.concat_ws("-", F.col(right_key), F.col("dummy"))) \
                       .drop("dummy")
    # join on the artificial salted key, then drop it
    joined_df = left_df.join(right_df, on="salted_key", how=how).drop("salted_key")
    return joined_df
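

# --- Minimal usage sketch (hypothetical): the df1/df2 below and their "key" column
# are illustrative assumptions, not part of the original gist, and require an active
# SparkSession; the left side is deliberately skewed toward key "a" to show the
# kind of hot key that salting spreads across partitions.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df1 = spark.createDataFrame([("a", i) for i in range(1000)] + [("b", i) for i in range(10)],
                            ["key", "left_value"])
df2 = spark.createDataFrame([("a", 10), ("b", 20)], ["key", "right_value"])

# Salted left join: each left key gets one of 50 random suffixes and each right key
# is replicated 50 times, so matching work for the skewed key "a" is distributed.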
key_salting_join(left_df=df1, right_df=df2, left_key="key", right_key="key", how="left", coarseness=50).count()