Skip to content

Instantly share code, notes, and snippets.

@vanheck
Last active March 13, 2024 13:50
Show Gist options
  • Save vanheck/bfcadf7396d765ddd2fff5f544fd7cf2 to your computer and use it in GitHub Desktop.
Save vanheck/bfcadf7396d765ddd2fff5f544fd7cf2 to your computer and use it in GitHub Desktop.
PySpark: how to avoid explode when grouping by both the top-level structure and a nested structure (code optimisation)
import pyspark.sql.functions as F
import pyspark.sql.types as T
# Sample requests. Each request has an "id", a top-level "typeId" and an
# optional list of "items"; each item has an "itemType", a boolean "flag" and
# an optional list of "event" records.
rows = [
    {
        "id": 1,
        "typeId": 1,
        "items": [
            {"itemType": 1, "flag": False, "event": None},
            {"itemType": 3, "flag": True, "event": [{"info1": ""}, {"info1": ""}]},
            {"itemType": 3, "flag": True, "event": [{"info1": ""}, {"info1": ""}]},
        ],
    },
    # Request without any items.
    {"id": 2, "typeId": 2, "items": None},
    {
        "id": 3,
        "typeId": 1,
        "items": [
            {"itemType": 1, "flag": False, "event": None},
            {"itemType": 6, "flag": False, "event": [{"info1": ""}]},
            {"itemType": 6, "flag": False, "event": None},
        ],
    },
    {
        "id": 4,
        "typeId": 2,
        "items": [
            {"itemType": 1, "flag": True, "event": [{"info1": ""}]},
        ],
    },
    # Request without any items.
    {"id": 5, "typeId": 3, "items": None},
]
# Explicit schema for the sample rows, built from the innermost type outwards.
event_struct = T.StructType([
    T.StructField("info1", T.StringType()),
])
item_struct = T.StructType([
    T.StructField("itemType", T.IntegerType()),
    T.StructField("flag", T.BooleanType()),
    T.StructField("event", T.ArrayType(event_struct)),
])
schema = T.StructType([
    # "id" is the only non-nullable column.
    T.StructField("id", T.IntegerType(), nullable=False),
    T.StructField("typeId", T.IntegerType()),
    T.StructField("items", T.ArrayType(item_struct), nullable=True),
])
# NOTE: relies on a SparkSession named `spark` that is created outside this
# snippet.
df = spark.createDataFrame(data=rows, schema=schema)
# ============
# Columns that define the groups of the top (request) layer.
layer1_groups = ["typeId"]
# Number of requests in every top-layer group.
totaldf = df.groupBy(*layer1_groups).agg(
    F.count(F.lit(1)).alias("requests"),
)
# Attach the per-group request total to every request so that it is still
# available after the nested items have been exploded.
df = df.join(totaldf, on=layer1_groups, how="inner")
# Reaching the nested layer requires exploding "items" into one row per item.
# explode_outer keeps requests whose "items" is null or empty by emitting a
# single placeholder row in which "I" (and every field taken from it) is null.
exploded_df = (
    df.withColumn("I", F.explode_outer("items"))
    # Remember which rows carry a real item before "I" is dropped; without
    # this marker the placeholder rows would be counted as items below.
    .withColumn("hasItem", F.col("I").isNotNull())
    .select("*", "I.*")
    .drop("items", "I")
)
# Number of events attached to each item. F.greatest(..., 0) turns the
# negative or null size reported for a missing "event" array into 0.
exploded_df = exploded_df.withColumn("eSize", F.greatest(F.size("event"), F.lit(0)))
# Columns that define the groups of the nested (item) layer.
layer2_groups = ["itemType"]
# Per-request statistics for every (request, layer-1 group, layer-2 group).
each_requests = exploded_df.groupby(["id", *layer1_groups, *layer2_groups]).agg(
    # Constant within a request, so any value of the group will do.
    F.first("requests").alias("requests"),
    # Count real items only. The previous F.count(F.lit(1)) also counted the
    # explode_outer placeholder row, so a request without items reported
    # ItemCount == 1 and was later treated as a request "with items".
    F.count_if(F.col("hasItem")).alias("ItemCount"),
    # Number of items whose flag is set (booleans summed as 0/1).
    F.sum(F.col("flag").cast(T.ByteType())).alias("fItemCount"),
    # Total number of events over the items of the group.
    F.sum("eSize").alias("eCount"),
)
# Aggregate the per-request statistics over the request "id" to obtain the
# final results for every (layer-1 group, layer-2 group) combination.
requests_results = each_requests.groupby([*layer1_groups, *layer2_groups]).agg(
    # Total number of requests in the layer-1 group (constant per group).
    F.first("requests").alias("requests"),
    # Number of requests that have at least one item / one flagged item.
    F.count_if(F.col("ItemCount") > 0).alias("requestsWithItems"),
    F.count_if(F.col("fItemCount") > 0).alias("requestsWith_fItems"),
    # Totals over all requests of the group.
    F.sum("ItemCount").alias("ItemCount"),
    F.sum("fItemCount").alias("fItemCount"),
    F.sum("eCount").alias("eCount"),
)
# DataFrame.show() only prints and returns None, so it is called separately;
# chaining it onto the assignment left `requests_results` bound to None.
requests_results.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment