# exercise1.py (forked from aialenti/exercise1.py, gist by @yaravind)
# Exercise 1: fix a skewed join by salting the join key.
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import avg, broadcast, col, concat, lit, rand, round, when
from pyspark.sql.types import IntegerType
# Create the Spark session
spark = SparkSession.builder \
    .master("local") \
    .config("spark.sql.autoBroadcastJoinThreshold", -1) \
    .config("spark.executor.memory", "500mb") \
    .appName("Exercise1") \
    .getOrCreate()
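# Note on the config above: setting spark.sql.autoBroadcastJoinThreshold to -1 disables
# automatic broadcast joins, so the products/sales join below runs as a shuffle join and
# the skew on product_id actually shows up; the small executor memory is just to make the
# skewed partition noticeable in this local exercise.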
# Read the source tables
products_table = spark.read.parquet("/Users/o60774/Downloads/product-sales/products_parquet")
sales_table = spark.read.parquet("/Users/o60774/Downloads/product-sales/sales_parquet")
sellers_table = spark.read.parquet("/Users/o60774/Downloads/product-sales/sellers_parquet")
# Step 1 - Check and select the skewed keys
# In this case we retrieve only the single most frequent key: it will be the only salted key.
results = sales_table.groupby(sales_table["product_id"]).count().sort(col("count").desc()).limit(1).collect()
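# Optional check (not part of the original gist): inspect the top counts to see how skewed
# product_id really is before deciding which keys to salt.
# sales_table.groupby("product_id").count().sort(col("count").desc()).show(10)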
# Step 2 - What we want to do is:
# a. Duplicate the entries that we have in the dimension table for the most common products, e.g.
# product_0 will become: product_0-1, product_0-2, product_0-3 and so on
# b. On the sales table, we are going to replace "product_0" with a random duplicate (e.g. some of them
# will be replaced with product_0-1, others with product_0-2, etc.)
# Using the new "salted" key will unskew the join
# Let's create a dataset to do the trick
REPLICATION_FACTOR = 101
l = []
replicated_products = []
for _r in results:
    replicated_products.append(_r["product_id"])
    # print(replicated_products)
    for _rep in range(0, REPLICATION_FACTOR):
        l.append((_r["product_id"], _rep))
    # print(l)
rdd = spark.sparkContext.parallelize(l)
replicated_df = rdd.map(lambda x: Row(repl_product_id=x[0], replication=int(x[1])))
replicated_df = spark.createDataFrame(replicated_df)
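# Equivalent sketch (kept as a comment so the original construction above is what runs):
# the replication DataFrame can also be built directly from the list of tuples, skipping
# the intermediate RDD.
# replicated_df = spark.createDataFrame(l, ["repl_product_id", "replication"])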
# Step 3: Generate the salted key
products_table = products_table \
    .join(
        broadcast(replicated_df),
        products_table["product_id"] == replicated_df["repl_product_id"],
        "left"
    ).withColumn(
        "salted_join_key",
        when(replicated_df["replication"].isNull(), products_table["product_id"])
        .otherwise(concat(replicated_df["repl_product_id"], lit("-"), replicated_df["replication"]))
    )
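# Optional check (not part of the original gist): every salted product should now appear
# REPLICATION_FACTOR times in the dimension table, while every other product appears once.
# products_table.groupby("product_id").count().sort(col("count").desc()).show(5)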
sales_table = sales_table \
    .withColumn(
        "salted_join_key",
        when(
            sales_table["product_id"].isin(replicated_products),
            concat(sales_table["product_id"], lit("-"),
                   round(rand() * (REPLICATION_FACTOR - 1), 0).cast(IntegerType()))
        ).otherwise(sales_table["product_id"])
    )
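# Optional check (not part of the original gist): the rows of the hot product should now be
# spread across roughly REPLICATION_FACTOR distinct salted keys instead of a single key.
# sales_table.groupby("salted_join_key").count().sort(col("count").desc()).show(10)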
# Step 4: Finally let's do the join
result = sales_table \
    .join(
        products_table,
        sales_table["salted_join_key"] == products_table["salted_join_key"],
        "inner"
    ) \
    .agg(avg(products_table["price"] * sales_table["num_pieces_sold"]))
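# The aggregation has no groupBy, so result is a single row with the average revenue per
# sale line (price * num_pieces_sold). Salting the key does not change this result, only
# how evenly the join work is spread across partitions.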
result.show()
print("Ok")