Custom Spark physical plan optimisation rule that forces AQE post-shuffle coalescing for repartitionByRange with a user-specified number of partitions.
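Background: in Spark 3.0.x, a shuffle introduced by `repartitionByRange(n, ...)` with an explicit partition count is normally excluded from AQE's post-shuffle coalescing, because the user asked for exactly n partitions. The rule below re-applies the built-in coalescing logic (`ShufflePartitionsUtil.coalescePartitions`) to those range-partitioned shuffle stages only. A minimal configuration sketch for the settings the rule reads, assuming Spark 3.0.x; the variable name and the concrete values are illustrative, not recommendations:

import org.apache.spark.SparkConf

// Illustrative AQE settings (Spark 3.0.x config keys); tune the values for your workload.
val aqeConf = new SparkConf()
  .set("spark.sql.adaptive.enabled", "true")                           // AQE must be on for query stage rules to run
  .set("spark.sql.adaptive.coalescePartitions.enabled", "true")        // checked by the rule's first guard
  .set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128m")      // advisoryTargetSize used when coalescing
  .set("spark.sql.adaptive.coalescePartitions.minPartitionNum", "16")  // lower bound; the rule falls back to defaultParallelism if unset
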
package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.physical.RangePartitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{CoalescedPartitionSpec, SparkPlan}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.internal.SQLConf

case class ForceRepartitionByRangeCoalescePartitions(session: SparkSession) extends Rule[SparkPlan] {

  private def conf = session.sessionState.conf

  override def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.coalesceShufflePartitionsEnabled) {
      return plan
    }
    if (
      !plan.collectLeaves().forall(_.isInstanceOf[QueryStageExec])
        || plan.find(_.isInstanceOf[CustomShuffleReaderExec]).isDefined
    ) {
      // If not all leaf nodes are query stages, it's not safe to reduce the number of
      // shuffle partitions, because we may break the assumption that all children of a Spark plan
      // have the same number of output partitions.
      return plan
    }

    // Apply only to shuffle exchanges introduced by repartitionByRange, i.e. range partitioning.
    def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match {
      case stage @ ShuffleQueryStageExec(_, ShuffleExchangeExec(RangePartitioning(_, _), _, _)) => Seq(stage)
      case _ => plan.children.flatMap(collectShuffleStages)
    }

    val shuffleStages = collectShuffleStages(plan)
    if (shuffleStages.isEmpty) {
      plan
    } else {
      // `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions,
      // so skip such stages when calculating the coalesced partition specs.
      val validMetrics = shuffleStages.flatMap(_.mapStats)
      // We may have different pre-shuffle partition numbers; don't reduce the shuffle partition
      // number in that case. For example, when we union fully aggregated data (arranged into a
      // single partition) with the result of a SortMergeJoin (multiple partitions).
      val distinctNumPreShufflePartitions =
        validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
      if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) {
        if (log.isDebugEnabled) {
          validMetrics.foreach { statistics =>
            logDebug(s"Input shuffleId:${statistics.shuffleId} partitions:${statistics.bytesByPartitionId.length}")
            logDebug(s"Max partition size :${statistics.bytesByPartitionId.max}")
            logDebug(s"Min partition size :${statistics.bytesByPartitionId.min}")
            if (log.isTraceEnabled) {
              statistics
                .bytesByPartitionId
                .zipWithIndex
                .foreach { case (partSize, partId) => logTrace(s"Input partition $partId size $partSize in bytes") }
            }
          }
        }
        // Fall back to Spark default parallelism if the minimum number of coalesced partitions
        // is not set, so as to avoid perf regressions compared to no coalescing.
        val minPartitionNum = conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM)
          .getOrElse(session.sparkContext.defaultParallelism)
        val partitionSpecs = ShufflePartitionsUtil.coalescePartitions(
          validMetrics.toArray,
          advisoryTargetSize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES),
          minNumPartitions = minPartitionNum
        )
logInfo(s"Number of shuffle stages to coalesce ${validMetrics.length}")
logInfo(
s"Reduce number of partitions from ${validMetrics.head.bytesByPartitionId.length} to ${partitionSpecs.size}"
)
if (log.isDebugEnabled) {
val outputPartitionStatistics = partitionSpecs
.zipWithIndex
.map { case (s: CoalescedPartitionSpec, id) =>
val newPartitionSize =
(s.startReducerIndex until s.endReducerIndex)
.map(validMetrics.head.bytesByPartitionId)
.sum
(id, s.startReducerIndex -> s.endReducerIndex, newPartitionSize)
}
logDebug(s"Output partition maxsize :${outputPartitionStatistics.maxBy { case (_, _, size) => size }}")
logDebug(s"Output partition min size :${outputPartitionStatistics.minBy { case (_, _, size) => size }}")
if (log.isTraceEnabled) {
outputPartitionStatistics.foreach { case (id, (startPartId, endPartId), size) =>
logTrace(
s"Output partition id:$id include input partitions: range from $startPartId to $endPartId, " +
s"number ${endPartId - startPartId}, size $size."
)
}
}
}
        // This transformation adds new nodes, so we must use `transformUp` here.
        val stageIds = shuffleStages.map(_.id).toSet
        plan.transformUp {
          // Even for a shuffle exchange whose input RDD has 0 partitions, we should still update
          // its partition specs, so that all the leaf shuffles in a stage have the same number
          // of output partitions.
          case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) =>
            logDebug(s"Apply custom coalesce to $stage")
            CustomShuffleReaderExec(stage, partitionSpecs, "custom_coalesced")
        }
      } else {
        plan
      }
    }
  }
}

// Register the rule through SparkSessionExtensions so it runs as an AQE query stage preparation rule.
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.adaptive.ForceRepartitionByRangeCoalescePartitions

val sparkConf = new SparkConf()
implicit val spark: SparkSession = SparkSession
  .builder()
  .config(sparkConf)
  .withExtensions(_.injectQueryStagePrepRule(ForceRepartitionByRangeCoalescePartitions))
  .getOrCreate()
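
With the rule injected, a range repartition with an explicit partition count should be coalesced by AQE after the shuffle, targeting roughly the advisory partition size per output partition. A minimal usage sketch; the input/output paths and the `event_ts` column are hypothetical:

import org.apache.spark.sql.functions.col

// Hypothetical job: the user-specified 2000 range partitions would normally be kept as-is,
// but with the injected rule AQE coalesces them after the shuffle based on actual sizes.
spark.read.parquet("/data/events")               // hypothetical input path
  .repartitionByRange(2000, col("event_ts"))     // explicit partition count + range partitioning
  .write
  .parquet("/data/events_range_partitioned")     // hypothetical output path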