sales = spark.read.option("header", True).csv("sales_train_evaluation.csv")
# select d_1~d_100 and turn into long format
cols = ["d_" + str(i) for i in range(1, 101)]
sales = sales \
    .selectExpr("id", "item_id", "dept_id", "cat_id", "store_id", "state_id",
                "stack({}, {}) as (d, amount)".format(len(cols), ', '.join("'{}', {}".format(c, c) for c in cols))) \
    .cache()
# group by state_id
from pyspark.sql import functions as F

sales = spark.read.option("header", True).csv("sales_train_evaluation.csv")
# melt all day columns, then total sales per state
cols = sales.columns[6:]
groupby_state = sales \
    .selectExpr("id", "item_id", "dept_id", "cat_id", "store_id", "state_id",
                "stack({}, {}) as (d, amount)".format(len(cols), ', '.join("'{}', {}".format(c, c) for c in cols))) \
    .groupBy("state_id") \
    .agg(F.sum("amount").alias("amt_tot")) \
    .orderBy(F.col("amt_tot").desc())
groupby_state.show()
def key_salting_join(left_df, right_df, left_key, right_key, how, coarseness):
    """Implementation of a key salting join.

    Args:
        left_df (pyspark.sql.DataFrame): left dataframe to be joined.
        right_df (pyspark.sql.DataFrame): right dataframe to be joined.
        left_key (str): key in the left dataframe to join on.
        right_key (str): key in the right dataframe to join on.
        how (str): join type, passed through to pyspark.sql.DataFrame.join.
            (https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrame.join.html)
        coarseness (int): how many distinct random salt values are appended to the
            original key. The larger this value, the more uniform the salted key
            distribution becomes, but the salted side of the join grows by the same factor.
# df2: small table with a skewed key, "key-0" and "key-1" get one row each, "key-2" gets 18
rdd2 = sc.parallelize(range(20)).map(lambda x: [x])
df2 = rdd2.toDF(schema=["index"]) \
    .withColumn("key", F.when(F.col("index") < 2, F.concat_ws("-", F.lit("key"), F.col("index"))).otherwise("key-2")) \
    .cache()
df2.groupBy("key").count().show()
# df1: large, heavily skewed table, nearly all of its 1e8 rows fall under "key-2"
rdd1 = sc.parallelize(range(int(1e8)), 200).map(lambda x: [x])
df1 = rdd1.toDF(schema=["index"]) \
    .withColumn("key", F.when(F.col("index") < 2, F.concat_ws("-", F.lit("key"), F.col("index"))).otherwise("key-2")) \
    .cache()
df1.groupBy("key").count().show()