Examining the execution plan of Spark on bucketed datasets, to verify whether it is smart enough to avoid a wide dependency.
PS: When trying things out in spark-shell, note that for small datasets the join will probably be a broadcast hash join (with a BroadcastExchange) in the physical plan by default. Example:
./spark-shell
val r = List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4)).toDF("id", "x", "y")
r.write.mode("overwrite").format("parquet").saveAsTable("hello1")
r.write.mode("overwrite").format("parquet").saveAsTable("hello2")
val x1 = spark.table("hello1")
val x2 = spark.table("hello2")
scala> x1.join(x2, "id")
res1: org.apache.spark.sql.DataFrame = [id: int, x: int ... 3 more fields]
scala> res1.explain
== Physical Plan ==
*(2) Project [id#22, x#23, y#24, x#32, y#33]
+- *(2) BroadcastHashJoin [id#22], [id#31], Inner, BuildRight
:- *(2) Project [id#22, x#23, y#24]
: +- *(2) Filter isnotnull(id#22)
: +- *(2) FileScan parquet default.hello[id#22,x#23,y#24] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/spark-warehouse/h..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:int,x:int,y:int>
+- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)))
+- *(1) Project [id#31, x#32, y#33]
+- *(1) Filter isnotnull(id#31)
+- *(1) FileScan parquet default.hello[id#31,x#32,y#33] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/spark-warehouse/h..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:int,x:int,y:int>
You can see the second table goes through a BroadcastExchange step, i.e. its data is broadcast to every executor so the join happens on the mapper/executor side itself. The first table takes part in the BroadcastHashJoin directly and doesn't need an Exchange of its own, for obvious reasons: it is simply streamed against the broadcast relation.
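The size cutoff for this behaviour is the standard spark.sql.autoBroadcastJoinThreshold conf (10 MB by default). As a hedged aside, not part of the original experiment: you can also request a broadcast explicitly with the broadcast() hint, regardless of the threshold.

import org.apache.spark.sql.functions.broadcast

// Force the small side to be broadcast; the plan should again show
// BroadcastHashJoin with a BroadcastExchange under the hinted side.
val hinted = x1.join(broadcast(x2), "id")
hinted.explain()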
To see the shuffle-based plan instead, disable broadcast joins and rebuild the tables:
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4))
.toDF("id", "x", "y").write.mode("overwrite").format("parquet").saveAsTable("hello1")
List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4))
.toDF("id", "x", "y").write.mode("overwrite").format("parquet").saveAsTable("hello2")
val x1 = spark.table("hello1")
val x2 = spark.table("hello2")
scala> x1.join(x2, "id")
res3: org.apache.spark.sql.DataFrame = [id: int, x: int ... 3 more fields]
scala> res3.explain
== Physical Plan ==
*(5) Project [id#32, x#33, y#34, x#39, y#40]
+- *(5) SortMergeJoin [id#32], [id#38], Inner
:- *(2) Sort [id#32 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(id#32, 200)
: +- *(1) Project [id#32, x#33, y#34]
: +- *(1) Filter isnotnull(id#32)
: +- *(1) FileScan parquet default.hello1[id#32,x#33,y#34] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/spark-warehouse/h..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:int,x:int,y:int>
+- *(4) Sort [id#38 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(id#38, 200)
+- *(3) Project [id#38, x#39, y#40]
+- *(3) Filter isnotnull(id#38)
+- *(3) FileScan parquet default.hello2[id#38,x#39,y#40] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/spark-warehouse/h..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:int,x:int,y:int>
That's an Exchange step on each table, along with a Sort to get matching keys into the same partition: this corresponds to the shuffle that happens on both sides to bring all rows for a given key to the same reducer before joining.
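An aside to make the plan's numbers traceable: the 200 in Exchange hashpartitioning(id#32, 200) is simply the default value of spark.sql.shuffle.partitions, a standard Spark SQL conf.

// Check (and, if desired, shrink) the shuffle parallelism; this changes the
// width of the exchange but does not remove it.
spark.conf.get("spark.sql.shuffle.partitions") // "200" unless overridden
// spark.conf.set("spark.sql.shuffle.partitions", "2")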
scala> List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4)).toDF("id", "x", "y").write.bucketBy(2, "id").mode("overwrite").format("parquet").saveAsTable("bucketedhello")
19/08/02 02:19:00 WARN HiveExternalCatalog: Persisting bucketed data source table `default`.`bucketedhello` into Hive metastore in Spark SQL specific format, which is NOT compatible with Hive.
scala> List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4)).toDF("id", "x", "y").write.bucketBy(2, "id").mode("overwrite").format("parquet").saveAsTable("bucketedhello2")
19/08/02 02:19:05 WARN HiveExternalCatalog: Persisting bucketed data source table `default`.`bucketedhello2` into Hive metastore in Spark SQL specific format, which is NOT compatible with Hive.
scala> val x1 = spark.table("bucketedhello")
x1: org.apache.spark.sql.DataFrame = [id: int, x: int ... 1 more field]
scala> val x2 = spark.table("bucketedhello2")
x2: org.apache.spark.sql.DataFrame = [id: int, x: int ... 1 more field]
scala> x1.join(x2, "id")
res9: org.apache.spark.sql.DataFrame = [id: int, x: int ... 3 more fields]
scala> res9.collect
res10: Array[org.apache.spark.sql.Row] = Array([2,3,4,3,4], [2,3,4,3,4], [2,3,4,3,4], [2,3,4,3,4], [1,2,3,2,3], [1,2,3,2,3], [1,2,3,3,4], [1,2,3,2,3], [1,2,3,2,3], [1,2,3,3,4], [1,3,4,2,3], [1,3,4,2,3], [1,3,4,3,4])
scala> res9.explain
== Physical Plan ==
*(3) Project [id#110, x#111, y#112, x#117, y#118]
+- *(3) SortMergeJoin [id#110], [id#116], Inner
:- *(1) Sort [id#110 ASC NULLS FIRST], false, 0
: +- *(1) Project [id#110, x#111, y#112]
: +- *(1) Filter isnotnull(id#110)
: +- *(1) FileScan parquet default.bucketedhello[id#110,x#111,y#112] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/spark-warehouse/b..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:int,x:int,y:int>, SelectedBucketsCount: 2 out of 2
+- *(2) Sort [id#116 ASC NULLS FIRST], false, 0
+- *(2) Project [id#116, x#117, y#118]
+- *(2) Filter isnotnull(id#116)
+- *(2) FileScan parquet default.bucketedhello2[id#116,x#117,y#118] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/spark-warehouse/b..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:int,x:int,y:int>, SelectedBucketsCount: 2 out of 2
We clearly see an optimisation here. Since both sides are "known partitions" (the bucket layout is recorded in the metastore), Spark avoided the Exchange on both sides, leaving only a Sort straight after the read. This narrow dependency is much faster than the shuffle in the previous run.
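To confirm what Spark recorded for the table, DESCRIBE FORMATTED shows the bucket spec (a quick sanity check; the exact row labels can vary slightly by version):

// Look for rows like "Num Buckets: 2" and "Bucket Columns: [`id`]".
spark.sql("DESCRIBE FORMATTED bucketedhello").show(50, false)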
Now, what happens when only one side is bucketed?
scala> List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4)).toDF("id", "x", "y").write.mode("overwrite").format("parquet").saveAsTable("nonbucket")
scala> List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4)).toDF("id", "x", "y")
.write.bucketBy(2, "id").mode("overwrite").format("parquet").saveAsTable("bucket")
19/08/02 02:25:09 WARN HiveExternalCatalog: Persisting bucketed data source table `default`.`bucket` into Hive metastore in Spark SQL specific format, which is NOT compatible with Hive.
scala> val x1 = spark.table("bucket")
x1: org.apache.spark.sql.DataFrame = [id: int, x: int ... 1 more field]
scala> val x2 = spark.table("nonbucket")
x2: org.apache.spark.sql.DataFrame = [id: int, x: int ... 1 more field]
scala> x1.join(x2, "id")
res16: org.apache.spark.sql.DataFrame = [id: int, x: int ... 3 more fields]
scala> res16.explain
== Physical Plan ==
*(5) Project [id#162, x#163, y#164, x#169, y#170]
+- *(5) SortMergeJoin [id#162], [id#168], Inner
:- *(2) Sort [id#162 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(id#162, 200)
: +- *(1) Project [id#162, x#163, y#164]
: +- *(1) Filter isnotnull(id#162)
: +- *(1) FileScan parquet default.bucket[id#162,x#163,y#164] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/spark-warehouse/b..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:int,x:int,y:int>, SelectedBucketsCount: 2 out of 2
+- *(4) Sort [id#168 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(id#168, 200)
+- *(3) Project [id#168, x#169, y#170]
+- *(3) Filter isnotnull(id#168)
+- *(3) FileScan parquet default.nonbucket[id#168,x#169,y#170] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/spark-warehouse/n..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:int,x:int,y:int>
We see the shuffle back again, on both sides this time. This must be because of the constraint that the bucketing trick only works when both sides are partitioned the same way on the join key; with one side unbucketed, Spark falls back to hash-partitioning both sides into the default 200 partitions.
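A hedged sketch of the straightforward fix, mirroring the bucketedhello run above: bucket the second table too, with the same bucket count on the same key (nonbucket2 here is a made-up table name for illustration):

// Rewrite the previously unbucketed data with the same bucket spec.
List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4)).toDF("id", "x", "y")
  .write.bucketBy(2, "id").mode("overwrite").format("parquet")
  .saveAsTable("nonbucket2") // hypothetical table name
// Expected, per the earlier run: SortMergeJoin with a Sort on each side
// and no Exchange.
spark.table("bucket").join(spark.table("nonbucket2"), "id").explain()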
Another approach: keep one table bucketed, and explicitly repartition the other side on the join key with a matching number of partitions.
scala> List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4)).toDF("id", "x", "y").write.bucketBy(2, "id").mode("overwrite").format("parquet").saveAsTable("bucketfinal")
scala> List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4)).toDF("id", "x", "y").repartition(2, $"id")
res35: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [id: int, x: int ... 1 more field]
scala> val x1 = spark.table("bucketfinal")
x1: org.apache.spark.sql.DataFrame = [id: int, x: int ... 1 more field]
scala> x1.join(res35, "id").explain
== Physical Plan ==
*(3) Project [id#348, x#349, y#350, x#368, y#369]
+- *(3) SortMergeJoin [id#348], [id#367], Inner
:- *(1) Sort [id#348 ASC NULLS FIRST], false, 0
: +- *(1) Project [id#348, x#349, y#350]
: +- *(1) Filter isnotnull(id#348)
: +- *(1) FileScan parquet default.bucketfinal[id#348,x#349,y#350] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/spark-warehouse/b..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:int,x:int,y:int>, SelectedBucketsCount: 2 out of 2
+- *(2) Sort [id#367 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(id#367, 2)
+- LocalTableScan [id#367, x#368, y#369]
We got away with a single shuffle: only one Exchange (on the repartitioned side) instead of two, while the bucketed side is read straight into its Sort.
But we don't always get away with it. What if the second dataset was repartitioned by count alone, without the join key?
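res22's definition was dropped from this transcript; judging by the Exchange RoundRobinPartitioning(2) in the plan below, it was presumably built with a plain repartition(2), something like this hypothetical reconstruction:

// Hypothetical: repartition(2) round-robins rows across 2 partitions and
// carries no information about the "id" column.
val res22 = List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4))
  .toDF("id", "x", "y")
  .repartition(2)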
scala> val x1 = spark.table("bucketfinal")
x1: org.apache.spark.sql.DataFrame = [id: int, x: int ... 1 more field]
scala> x1.join(res22, "id")
res33: org.apache.spark.sql.DataFrame = [id: int, x: int ... 3 more fields]
scala> res33.explain
== Physical Plan ==
*(4) Project [id#348, x#349, y#350, x#220, y#221]
+- *(4) SortMergeJoin [id#348], [id#219], Inner
:- *(2) Sort [id#348 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(id#348, 200)
: +- *(1) Project [id#348, x#349, y#350]
: +- *(1) Filter isnotnull(id#348)
: +- *(1) FileScan parquet default.bucketfinal[id#348,x#349,y#350] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/spark-warehouse/b..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:int,x:int,y:int>, SelectedBucketsCount: 2 out of 2
+- *(3) Sort [id#219 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(id#219, 200)
+- Exchange RoundRobinPartitioning(2)
+- LocalTableScan [id#219, x#220, y#221]
We are back to a shuffle again! A round-robin repartition says nothing about where a given id lives, so it cannot satisfy the join's hash-partitioning requirement, and Spark stacks an Exchange hashpartitioning on top of it.
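The contrast in two lines (a sketch on a throwaway DataFrame, assuming the usual spark-shell implicits): repartition(n) yields RoundRobinPartitioning, while repartition(n, col) yields hash partitioning on the column, and only the latter satisfies the join's distribution requirement.

val df = List((1, 2, 3), (2, 3, 4)).toDF("id", "x", "y")
df.repartition(2).explain()         // Exchange RoundRobinPartitioning(2)
df.repartition(2, $"id").explain()  // Exchange hashpartitioning(id, 2)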
Next, aggregations. First a table that is only partitionBy("x") (directory partitioning), with no bucketing:
scala> List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4)).toDF("id", "x", "y").write.partitionBy("x").mode("overwrite").format("parquet").options(Map("path" -> "data2")).saveAsTable("hello2")
scala> val x2 = spark.table("hello2")
x2: org.apache.spark.sql.DataFrame = [id: int, y: int ... 1 more field]
scala> x2.groupBy("id").agg(count("id"))
res19: org.apache.spark.sql.DataFrame = [id: int, count(id): bigint]
scala> x2.groupBy("id").agg(count("id")).explain
== Physical Plan ==
*(2) HashAggregate(keys=[id#259], functions=[count(id#259)])
+- Exchange hashpartitioning(id#259, 200)
+- *(1) HashAggregate(keys=[id#259], functions=[partial_count(id#259)])
+- *(1) Project [id#259]
+- *(1) FileScan parquet default.hello2[id#259,x#261] Batched: true, Format: Parquet, Location: CatalogFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/data2], PartitionCount: 2, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int>
Now bucket by "id" as well:
scala> List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4)).toDF("id", "x", "y").write.partitionBy("x").bucketBy(10, "id").mode("overwrite").format("parquet").options(Map("path" -> "data2")).saveAsTable("hello2")
19/08/07 09:32:59 WARN HiveExternalCatalog: Persisting bucketed data source table `default`.`hello2` into Hive metastore in Spark SQL specific format, which is NOT compatible with Hive.
scala> val x2 = spark.table("hello2")
x2: org.apache.spark.sql.DataFrame = [id: int, y: int ... 1 more field]
scala> x2.groupBy("id").agg(count("id")).explain
== Physical Plan ==
*(1) HashAggregate(keys=[id#303], functions=[count(id#303)])
+- *(1) HashAggregate(keys=[id#303], functions=[partial_count(id#303)])
+- *(1) Project [id#303]
+- *(1) FileScan parquet default.hello2[id#303,x#305] Batched: true, Format: Parquet, Location: CatalogFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/data2], PartitionCount: 2, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int>, SelectedBucketsCount: 10 out of 10
If the input is bucketed on "id", the shuffle is avoided both in the aggregations on "id" that build individual features, and (as the next run shows) in the join of those individual feature outputs into a wider table.
scala> List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4)).toDF("id", "x", "y").write.partitionBy("x").bucketBy(10, "id").mode("overwrite").format("parquet").options(Map("path" -> "data2")).saveAsTable("hello2")
19/08/10 14:41:55 WARN HiveExternalCatalog: Persisting bucketed data source table `default`.`hello2` into Hive metastore in Spark SQL specific format, which is NOT compatible with Hive.
scala> val x2 = spark.table("hello2")
x2: org.apache.spark.sql.DataFrame = [id: int, y: int ... 1 more field]
scala> val x1 = spark.table("hello2")
x1: org.apache.spark.sql.DataFrame = [id: int, y: int ... 1 more field]
scala> spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
scala> x1.groupBy("id").agg(max("x"))
res4: org.apache.spark.sql.DataFrame = [id: int, max(x): int]
scala> x2.groupBy("id").agg(max("y"))
res5: org.apache.spark.sql.DataFrame = [id: int, max(y): int]
scala> res4.join(res5, Seq("id"), "outer")
res6: org.apache.spark.sql.DataFrame = [id: int, max(x): int ... 1 more field]
scala> res4.join(res5, Seq("id"), "outer").explain
== Physical Plan ==
*(3) Project [coalesce(id#44, id#74) AS id#82, max(x)#57, max(y)#64]
+- SortMergeJoin [id#44], [id#74], FullOuter
:- *(1) Sort [id#44 ASC NULLS FIRST], false, 0
: +- *(1) HashAggregate(keys=[id#44], functions=[max(x#46)])
: +- *(1) HashAggregate(keys=[id#44], functions=[partial_max(x#46)])
: +- *(1) FileScan parquet default.hello2[id#44,x#46] Batched: true, Format: Parquet, Location: CatalogFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/data2], PartitionCount: 2, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int>, SelectedBucketsCount: 10 out of 10
+- *(2) Sort [id#74 ASC NULLS FIRST], false, 0
+- *(2) HashAggregate(keys=[id#74], functions=[max(y#75)])
+- *(2) HashAggregate(keys=[id#74], functions=[partial_max(y#75)])
+- *(2) Project [id#74, y#75]
+- *(2) FileScan parquet default.hello2[id#74,y#75,x#76] Batched: true, Format: Parquet, Location: CatalogFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/data2], PartitionCount: 2, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int,y:int>, SelectedBucketsCount: 10 out of 10
If we don't bucket, the shuffle cost comes back, as expected.
scala> List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4)).toDF("id", "x", "y").write.partitionBy("x").mode("overwrite").format("parquet").options(Map("path" -> "data2")).saveAsTable("hello2")
scala> val x1 = spark.table("hello2")
x1: org.apache.spark.sql.DataFrame = [id: int, y: int ... 1 more field]
scala> val x2 = spark.table("hello2")
x2: org.apache.spark.sql.DataFrame = [id: int, y: int ... 1 more field]
scala> x1.groupBy("id").agg(max("x"))
res9: org.apache.spark.sql.DataFrame = [id: int, max(x): int]
scala> x2.groupBy("id").agg(max("y"))
res10: org.apache.spark.sql.DataFrame = [id: int, max(y): int]
scala> res9.join(res10, Seq("id"), "outer")
res11: org.apache.spark.sql.DataFrame = [id: int, max(x): int ... 1 more field]
scala> res9.join(res10, Seq("id"), "outer").explain
== Physical Plan ==
*(5) Project [coalesce(id#106, id#136) AS id#144, max(x)#119, max(y)#126]
+- SortMergeJoin [id#106], [id#136], FullOuter
:- *(2) Sort [id#106 ASC NULLS FIRST], false, 0
: +- *(2) HashAggregate(keys=[id#106], functions=[max(x#108)])
: +- Exchange hashpartitioning(id#106, 200)
: +- *(1) HashAggregate(keys=[id#106], functions=[partial_max(x#108)])
: +- *(1) FileScan parquet default.hello2[id#106,x#108] Batched: true, Format: Parquet, Location: CatalogFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/data2], PartitionCount: 2, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int>
+- *(4) Sort [id#136 ASC NULLS FIRST], false, 0
+- *(4) HashAggregate(keys=[id#136], functions=[max(y#137)])
+- Exchange hashpartitioning(id#136, 200)
+- *(3) HashAggregate(keys=[id#136], functions=[partial_max(y#137)])
+- *(3) Project [id#136, y#137]
+- *(3) FileScan parquet default.hello2[id#136,y#137,x#138] Batched: true, Format: Parquet, Location: CatalogFileIndex[file:/Users/afsalthaj/Downloads/spark-2.4.3-bin-hadoop2.7/bin/data2], PartitionCount: 2, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int,y:int>
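So without bucketing, each aggregation pays an Exchange hashpartitioning(id, 200) before its final HashAggregate, and only then can the SortMergeJoin run. To wrap up, here is the whole bucketed-feature pattern from the transcript in one place (a sketch assembled from the runs above, reusing the same table and path names):

import org.apache.spark.sql.functions.max

// Write once, partitioned by "x" and bucketed on the entity key "id".
List((1, 2, 3), (1, 2, 3), (2, 3, 4), (2, 3, 4), (1, 3, 4))
  .toDF("id", "x", "y")
  .write.partitionBy("x").bucketBy(10, "id")
  .mode("overwrite").format("parquet")
  .options(Map("path" -> "data2"))
  .saveAsTable("hello2")

// Each per-feature aggregation reuses the bucket layout: no Exchange.
val t = spark.table("hello2")
val featX = t.groupBy("id").agg(max("x"))
val featY = t.groupBy("id").agg(max("y"))

// The outer join of the features is also Exchange-free.
featX.join(featY, Seq("id"), "outer").explain()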