// Custom Spark physical-plan rule for Adaptive Query Execution (AQE) that forces post-shuffle partition coalescing for `repartitionByRange` with a user-specified number of partitions.
package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.physical.RangePartitioning
import org.apache.spark.sql.execution.CoalescedPartitionSpec
import org.apache.spark.sql.internal.SQLConf
/**
 * Adaptive-execution rule that coalesces the post-shuffle partitions of range-partitioned
 * shuffle stages (the exchanges introduced by `repartitionByRange`).
 *
 * The rule works in three steps:
 *  - It finds every `ShuffleQueryStageExec` whose `ShuffleExchangeExec` uses `RangePartitioning`.
 *  - It computes coalesced partition specs from those stages' map-output statistics with
 *    `ShufflePartitionsUtil.coalescePartitions`.
 *  - It wraps each matching stage in a `CustomShuffleReaderExec` described as "custom_coalesced".
 *
 * NOTE(review): this listing appears to have lost lines, so it does not compile as shown.
 * Closing braces are missing throughout, and parts of several expressions are flagged below.
 * The "presumably" notes are reconstructions to confirm against the original source.
 *
 * @param session the active session; this rule only uses it to read the SQL configuration.
 */
case class ForceRepartitionByRangeCoalescePartitions(session: SparkSession) extends Rule[SparkPlan] {
// SQL configuration of `session`; being a `def`, it is re-read on every access.
private def conf = session.sessionState.conf
/**
 * Inserts coalescing shuffle readers over range-partitioned shuffle stages of `plan`.
 *
 * The visible code returns `plan` unchanged in two cases:
 *  - shuffle-partition coalescing is disabled;
 *  - the plan already contains a `CustomShuffleReaderExec`.
 * It only inserts readers when statistics exist and all matched stages share the same
 * pre-shuffle partition count. The remaining branches presumably return `plan` unchanged.
 *
 * @param plan the physical plan being re-optimized by AQE.
 * @return the plan with matching stages wrapped in `CustomShuffleReaderExec`, or `plan`.
 */
override def apply(plan: SparkPlan): SparkPlan = {
// Honor the global switch for AQE shuffle-partition coalescing.
if (!conf.coalesceShufflePartitionsEnabled) {
return plan
// NOTE(review): the first operand of the condition below appears to be missing. Given the
// comment inside it, it is presumably `!plan.collectLeaves().forall(_.isInstanceOf[QueryStageExec])`
// (confirm). The visible operand bails out if any custom shuffle reader is already present.
// That also keeps this rule from wrapping a stage it has already wrapped.
if (
|| plan.find(_.isInstanceOf[CustomShuffleReaderExec]).isDefined
) {
// If not all leaf nodes are query stages, it's not safe to reduce the number of
// shuffle partitions, because we may break the assumption that all children of a spark plan
// have same number of output partitions.
return plan
// Collects the shuffle query stages whose exchange uses RangePartitioning. A matched stage is
// returned without descending into its children; every other node is searched recursively.
def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match {
case stage @ ShuffleQueryStageExec(_, ShuffleExchangeExec(RangePartitioning(_, _), _, _)) => Seq(stage)
case _ => plan.children.flatMap(collectShuffleStages)
val shuffleStages = collectShuffleStages(plan)
// Apply to specific ShuffleExchanges introduced by repartitionByRange.
if (shuffleStages.isEmpty) {
// NOTE(review): this branch body appears to be missing; presumably `plan` (nothing to coalesce).
} else {
// `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions,
// we should skip it when calculating the `partitionStartIndices`.
val validMetrics = shuffleStages.flatMap(_.mapStats)
// We may have different pre-shuffle partition numbers, don't reduce shuffle partition number
// in that case. For example when we union fully aggregated data (data is arranged to a single
// partition) and a result of a SortMergeJoin (multiple partitions).
// NOTE(review): the start of the expression below appears truncated; presumably
// `validMetrics.map(stats => stats.bytesByPartitionId.length).distinct`.
val distinctNumPreShufflePartitions = => stats.bytesByPartitionId.length).distinct
// Coalesce only when statistics exist and every matched stage has the same pre-shuffle
// partition count.
if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) {
// Log per-stage input statistics; they are only computed when debug logging is enabled.
if (log.isDebugEnabled) {
validMetrics.foreach { statistics =>
logDebug(s"Input shuffleId:${statistics.shuffleId} partitions:${statistics.bytesByPartitionId.length}")
logDebug(s"Max partition size :${statistics.bytesByPartitionId.max}")
logDebug(s"Min partition size :${statistics.bytesByPartitionId.min}")
// NOTE(review): the receiver of `.foreach` below appears to be missing. Given the
// `(partSize, partId)` pattern, it is presumably `statistics.bytesByPartitionId.zipWithIndex`.
if (log.isTraceEnabled)
.foreach { case (partSize, partId) => logTrace(s"Input partition $partId size $partSize in bytes") }
// We fall back to Spark default parallelism if the minimum number of coalesced partitions
// is not set, so to avoid perf regressions compared to no coalescing.
// NOTE(review): the line below does not apply that fallback itself. If this config is optional
// in the target Spark version (so `getConf` returns an Option), something like
// `.getOrElse(session.sparkContext.defaultParallelism)` is needed before passing it as
// `minNumPartitions`. Confirm.
val minPartitionNum = conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM)
// NOTE(review): this call appears to be missing its first argument (presumably
// `validMetrics.toArray`) and its closing `)`.
val partitionSpecs = ShufflePartitionsUtil.coalescePartitions(
advisoryTargetSize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES),
minNumPartitions = minPartitionNum
// NOTE(review): the second string below is not passed to any call; presumably it was the
// argument of a lost `logInfo(...)`.
logInfo(s"Number of shuffle stages to coalesce ${validMetrics.length}")
s"Reduce number of partitions from ${validMetrics.head.bytesByPartitionId.length} to ${partitionSpecs.size}"
// Log per-output-partition summaries; they are only computed when debug logging is enabled.
if (log.isDebugEnabled) {
// NOTE(review): presumably `.zipWithIndex` was lost between `partitionSpecs` and `.map`.
// The typed pattern `s: CoalescedPartitionSpec` throws a MatchError for any other spec type.
val outputPartitionStatistics = partitionSpecs
.map { case (s: CoalescedPartitionSpec, id) =>
// NOTE(review): the expression below appears truncated. It presumably sums the sizes of
// input partitions [startReducerIndex, endReducerIndex) across `validMetrics`.
val newPartitionSize =
(s.startReducerIndex until s.endReducerIndex)
(id, s.startReducerIndex -> s.endReducerIndex, newPartitionSize)
logDebug(s"Output partition maxsize :${outputPartitionStatistics.maxBy { case (_, _, size) => size }}")
logDebug(s"Output partition min size :${outputPartitionStatistics.minBy { case (_, _, size) => size }}")
// Trace which input-partition range each output partition covers.
// NOTE(review): the message below is not passed to any call; presumably a lost `logTrace(`.
if (log.isTraceEnabled) {
outputPartitionStatistics.foreach { case (id, (startPartId, endPartId), size) =>
s"Output partition id:$id include input partitions: range from $startPartId to $endPartId, " +
s"number ${endPartId - startPartId}, size $size."
// This transformation adds new nodes, so we must use `transformUp` here.
// NOTE(review): the right-hand side of `stageIds` appears to be missing; presumably
// `shuffleStages.map(_.id).toSet`.
val stageIds =
plan.transformUp {
// even for shuffle exchange whose input RDD has 0 partition, we should still update its
// `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
// number of output partitions.
// NOTE(review): `stageIds.contains(` appears truncated; presumably `stageIds.contains(stage.id)`.
case stage: ShuffleQueryStageExec if stageIds.contains( =>
logDebug(s"Apply custom coalesce to $stage")
// Replace the stage with a reader over the coalesced partition specs computed above.
CustomShuffleReaderExec(stage, partitionSpecs, "custom_coalesced")
} else {
// NOTE(review): this `else` body is not visible. It presumably returns `plan` unchanged when
// the pre-shuffle partition counts differ or no statistics exist.
// NOTE(review): the two lines below look like the start of a separate fragment rather than
// part of this rule. `SparkConf` is not among the visible imports, and the `SparkSession`
// expression is cut off.
val sparkConf = new SparkConf()
implicit val spark: SparkSession = SparkSession
