-
-
Save yaravind/813dd46b2d32c6de8c415c636c8a46c4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# --- Imports ---------------------------------------------------------------
# Explicit function imports instead of the original `import *` wildcard so
# the namespace stays auditable.  NOTE: `round` here is pyspark's Column
# round and deliberately shadows the builtin, exactly as the wildcard did.
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import (
    avg,
    broadcast,
    col,
    concat,
    lit,
    rand,
    round,
    when,
)
from pyspark.sql.types import IntegerType

# --- Spark session ---------------------------------------------------------
# autoBroadcastJoinThreshold = -1 disables automatic broadcast joins so the
# join skew (and the manual salting fix below) is actually exercised; the
# small executor memory keeps the exercise realistic on a single machine.
spark = (
    SparkSession.builder
    .master("local")
    .config("spark.sql.autoBroadcastJoinThreshold", -1)
    .config("spark.executor.memory", "500mb")
    .appName("Exercise1")
    .getOrCreate()
)

# --- Source tables ---------------------------------------------------------
# NOTE(review): hard-coded absolute paths — assumes the product-sales parquet
# dump is present locally; adjust BASE_PATH before running elsewhere.
BASE_PATH = "/Users/o60774/Downloads/product-sales"
products_table = spark.read.parquet(BASE_PATH + "/products_parquet")
sales_table = spark.read.parquet(BASE_PATH + "/sales_parquet")
sellers_table = spark.read.parquet(BASE_PATH + "/sellers_parquet")
# --- Step 1: find the skewed key(s) ----------------------------------------
# Count sales per product and keep only the single most frequent product_id:
# that is the only key that will be salted.  (The original comment claimed
# "top 100 keys", but the code takes limit(1) — raise the limit() argument
# to salt more of the hottest keys.)
results = (
    sales_table
    .groupby(sales_table["product_id"])
    .count()
    .sort(col("count").desc())
    .limit(1)
    .collect()
)
# --- Step 2: build the replication (salt) dimension ------------------------
# For each skewed product we will:
#   a. replicate its row in the products dimension REPLICATION_FACTOR times,
#      keyed product_0-0, product_0-1, ... (driven by this helper dataframe);
#   b. on the sales side, rewrite each matching row to point at one replica
#      chosen at random (done in Step 3).
# Joining on the new "salted" key then spreads the hot product evenly.
REPLICATION_FACTOR = 101

# All skewed product_ids collected from Step 1.
replicated_products = [row["product_id"] for row in results]

# One (product_id, replica_index) pair per replica, for EVERY skewed key.
# The replica generation is nested inside the loop over the skewed keys —
# the original flat loop leaked the outer loop variable and would only have
# replicated the last key had limit() been greater than 1.
salt_pairs = [
    (product_id, replica)
    for product_id in replicated_products
    for replica in range(REPLICATION_FACTOR)
]

replicated_df = spark.createDataFrame(
    spark.sparkContext.parallelize(salt_pairs).map(
        lambda pair: Row(repl_product_id=pair[0], replication=int(pair[1]))
    )
)
# --- Step 3: attach the salted join key to both sides ----------------------
# Products side: left-join against the broadcast replication dataframe.
# Non-skewed products get a null `replication` and keep their plain
# product_id as the key; each skewed product explodes into one row per
# replica, keyed "product_id-replica".
products_table = products_table.join(
    broadcast(replicated_df),
    products_table["product_id"] == replicated_df["repl_product_id"],
    "left",
).withColumn(
    "salted_join_key",
    when(
        replicated_df["replication"].isNull(),
        products_table["product_id"],
    ).otherwise(
        concat(
            replicated_df["repl_product_id"],
            lit("-"),
            replicated_df["replication"],
        )
    ),
)

# Sales side: rows for a skewed product get a uniformly random replica
# suffix in [0, REPLICATION_FACTOR - 1]; all other rows keep the plain
# product_id as their key.
is_skewed_product = sales_table["product_id"].isin(replicated_products)
random_replica = round(rand() * (REPLICATION_FACTOR - 1), 0).cast(IntegerType())
sales_table = sales_table.withColumn(
    "salted_join_key",
    when(
        is_skewed_product,
        concat(sales_table["product_id"], lit("-"), random_replica),
    ).otherwise(sales_table["product_id"]),
)
# --- Step 4: the (now unskewed) join ---------------------------------------
# Average revenue per sale = avg(price * num_pieces_sold) over the joined
# tables; joining on the salted key spreads the hot product across many
# partitions instead of piling it onto one.
result = sales_table.join(
    products_table,
    sales_table["salted_join_key"] == products_table["salted_join_key"],
    "inner",
).agg(avg(products_table["price"] * sales_table["num_pieces_sold"]))

# show() prints the dataframe itself and returns None — don't wrap it in
# print() (the original printed a spurious "None" line after the table).
result.show()
print("Ok")

# Release local Spark resources now that the job is finished.
spark.stop()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.