Batching of RDDs. Allows splitting the tasks of a single RDD into batches and evaluating that RDD in multiple stages instead of scheduling all tasks at once. The main motivation is avoiding OOMs when each task requires a lot of memory to run, e.g. when training a model.
// Bring the `batch` extension method into scope for any RDD.
import org.apache.spark.rdd.batch.implicits._
// Each pair of lines below is an independent example (the name `rdd` is reused between them).
// Example: 100 partitions, 20 partitions per batch, i.e. 5 map-side batches.
val rdd = sc.parallelize(0 until 1000, 100)
val res = rdd.batch(numPartitionsPerBatch = 20)
// Example: 10 partitions, 4 partitions per batch, i.e. 3 map-side batches (4 + 4 + 2).
val rdd = sc.parallelize(Seq("a", "b", "c", "d", "e", "f", "g", "h"), 10)
val res = rdd.batch(numPartitionsPerBatch = 4)
// Example: the batched RDD may itself be the result of a shuffle (20 partitions after repartition).
val rdd = sc.parallelize(Seq("a", "b", "c", "d", "e", "f", "g", "h"), 10).repartition(20)
val res = rdd.batch(numPartitionsPerBatch = 4)
// Example: tuple elements; 8 partitions, 3 partitions per batch, i.e. 3 map-side batches (3 + 3 + 2).
val rdd = sc.parallelize(Seq((1, true), (2, false), (3, true)), 8)
val res = rdd.batch(numPartitionsPerBatch = 3)
package org.apache.spark.rdd.batch
import scala.collection.mutable.{ArrayBuffer, HashMap => MutableMap}
import scala.reflect.ClassTag
import org.apache.spark._
import org.apache.spark.shuffle._
import org.apache.spark.rdd.RDD
// RDD batching is introduced to overcome OOMs that occur when all tasks of an RDD are scheduled at
// once while each executor has a limited amount of memory. This is a good fit for training models:
// it still leverages Spark parallelism, but avoids collecting on the driver. It is a tradeoff
// between computation time and memory usage per executor and driver.
// Batching of tasks in Spark works by mapping batches to multiple stages. Batch resolution is left
// outside of the RDD. Batching is split into two parts: 1 to N map stages and a single reduce
// stage. A map stage modifies values to add the partition index that is used as the key for the
// shuffle, therefore we serialize data per partition. Each map stage has a dependency on the
// previous map stage, but does not pull data from the shuffle reader. This blocks execution of
// other partitions until the i-th batch is finished. Once all map stages are complete, the reduce
// stage is launched to collect all shuffle output and remove the partition index from values.
/**
 * [[BatchPartition]] is a mirror for parent partition of original RDD, and keeps track of partition
 * task, so we can reconstruct partitions on reduce stage.
 *
 * Two batch partitions are equal when they have the same `rddId` and `slice`; `parent` does not
 * take part in equality or hashing.
 *
 * @param rddId id of the RDD this partition belongs to (reported as the map-side RDD in errors)
 * @param slice position of this partition within its RDD, returned by `index`
 * @param parent partition of the original RDD that this partition mirrors
 */
class BatchPartition(
    val rddId: Long,
    val slice: Int,
    val parent: Partition)
  extends Partition with Serializable {

  // Arithmetic is done on Long (rddId is Long) and truncated to Int at the end.
  override def hashCode: Int = (41 * (41 + rddId) + slice).toInt

  override def equals(other: Any): Boolean = other match {
    case that: BatchPartition => this.rddId == that.rddId && this.slice == that.slice
    case _ => false
  }

  /** Index of this partition, which is the slice number and not the parent partition index. */
  override def index: Int = slice

  override def toString: String = {
    s"${getClass.getSimpleName}(rddId=$rddId, index=$index, parent=$parent)"
  }
}
/**
 * [[BatchMapRDD]] is a map-side RDD to serialize partition data, and wait for all parent RDDs to
 * finish. Note that type 'T' must be serializable.
 *
 * Every output record is a pair `(originalPartitionIndex, value)`, so that a later shuffle keyed
 * by that index can put the value back into the partition it came from.
 *
 * @param parent original RDD to batch, not transient because it is used in compute method
 * @param previous previous batch that this RDD needs to wait before starting computation
 * @param batch set/batch of original RDD partitions that this RDD needs to evaluate
 */
class BatchMapRDD[T: ClassTag](
    var parent: RDD[T],
    @transient var previous: Option[BatchMapRDD[T]],
    private val batch: Array[Partition])
  extends RDD[(Int, T)](parent.sparkContext, Nil) {

  // Hash partitioner to spill data, this will create shuffle output per original partition
  // since partitioner uses mod(key, len) function
  val part: HashPartitioner = new HashPartitioner(parent.partitions.length)

  // We need to remap original partitions, since Spark checks for valid partition indices, e.g.
  // should start from 0 and cover all splits.
  override def getPartitions: Array[Partition] = {
    // NOTE(review): this expression was truncated in the reviewed text; it is rebuilt from the
    // surviving `case (x, index)` pattern and the BatchPartition(rddId, slice, parent) signature,
    // using this RDD's own `id` as rddId - confirm.
    batch.zipWithIndex.map { case (x, index) =>
      new BatchPartition(id, index, x): Partition
    }
  }

  /**
   * Get entire graph of dependencies for this batch RDD, for example
   * original <- RDD1 <- RDD2 <- RDD3 will result in Seq(RDD1, RDD2) for RDD3 and empty sequence
   * for RDD1.
   *
   * @return all preceding batch RDDs, oldest first, excluding this RDD
   */
  def getBatchDependencies: Seq[BatchMapRDD[T]] = previous match {
    case Some(batchRdd) => batchRdd.getBatchDependencies :+ batchRdd
    case None => Seq.empty
  }

  // Extract dependencies of map-side RDD, if this is the first batch, it maps directly to parent,
  // otherwise everything is shuffle dependency on its 'previous' RDD. It is important to reuse
  // the same partitioner for each shuffle dependency.
  override def getDependencies: Seq[Dependency[_]] = previous match {
    case Some(batchRdd) =>
      // do not enable sort or map-side aggregation
      // NOTE(review): the serializer is passed directly (no Option wrapper) here, while earlier
      // comments mentioned a Some() wrapper required on Spark 1.x - confirm against the Spark
      // version this is built for.
      new ShuffleDependency[Int, T, T](batchRdd, part,
        SparkEnv.get.serializer, None, None, false) :: Nil
    case None =>
      new OneToOneDependency[T](parent) :: Nil
  }

  // Compute method maps each value as (splitIndex, originalValue), where splitIndex is index of
  // parent/original RDD partition, and originalValue is a value from original RDD for that
  // partition. This will allow to reduce output correctly into the same number of partitions
  override def compute(split: Partition, context: TaskContext): Iterator[(Int, T)] = {
    val partition = split.asInstanceOf[BatchPartition]
    val splitIndex = partition.parent.index
    // Data is always read from the original RDD, never from the shuffle of the previous batch;
    // the shuffle dependency only orders the stages.
    parent.iterator(partition.parent, context).map { x => (splitIndex, x) }
  }

  override def clearDependencies(): Unit = {
    // Let the base class drop its cached dependency list as well, otherwise the lineage stays
    // reachable after this RDD is checkpointed.
    super.clearDependencies()
    parent = null
    previous = null
  }
}
/**
 * [[BatchReduceRDD]] reduces each output from [[BatchMapRDD]] and returns RDD that has original
 * number of partitions and similar data distribution, meaning it is safe to rely on the same order
 * of data in each partition.
 *
 * Construction fails with `IllegalStateException` when the chain of map-side RDDs evaluates some
 * original partition more than once, or does not cover every original partition.
 *
 * @param rdd RDD of last map-side stage in chain of batches
 */
class BatchReduceRDD[T: ClassTag](
    @transient var rdd: BatchMapRDD[T])
  extends RDD[T](rdd) {

  // Index to map original partitions to shuffle dependency that points to shuffle output for that
  // partition.
  private val shuffleSplitIndex: Map[Int, Dependency[_]] = buildShuffleDependencies()

  // Partitions are exactly the partitions of the original (un-batched) RDD.
  override def getPartitions: Array[Partition] = rdd.parent.partitions

  override def getDependencies: Seq[Dependency[_]] = {
    // NOTE(review): the serializer is passed directly (no Option wrapper) here, while earlier
    // comments mentioned a Some() wrapper required on Spark 1.x - confirm against the Spark
    // version this is built for.
    new ShuffleDependency[Int, T, T](rdd, rdd.part, SparkEnv.get.serializer,
      None, None, false) :: Nil
  }

  /**
   * Build the lookup from original partition index to the shuffle dependency whose map side is
   * the batch RDD that computed that partition.
   *
   * @return map of original partition index -> shuffle dependency holding its output
   * @throws IllegalStateException if an original partition is computed by more than one batch,
   *                               or if not every original partition is covered
   */
  private def buildShuffleDependencies(): Map[Int, Dependency[_]] = {
    // all RDDs that this reduce step depends on
    val rddDependencies = this.rdd.getBatchDependencies :+ this.rdd

    // index of original partition index to map-side RDD
    val batchPartitionIndex = new MutableMap[Int, RDD[_]]()
    rddDependencies.foreach { mapRdd =>
      mapRdd.getPartitions.foreach {
        case bpart: BatchPartition =>
          // batch partition index is not unique, but original partition index is, here we also do
          // some sanity check to make sure that there are no two or more batch partitions that
          // compute the same original partition
          if (batchPartitionIndex.contains(bpart.parent.index)) {
            throw new IllegalStateException(s"Map-side RDD ${bpart.rddId} contains duplicate " +
              s"partition ${bpart.parent} (${bpart.parent.index}) that maps to $bpart. This " +
              "implies that batch map was evaluated more than once for original partition")
          }
          batchPartitionIndex.put(bpart.parent.index, mapRdd)
        case other =>
          sys.error(s"Unexpected partition $other found that is not batch partition")
      }
    }

    // build index for shuffle dependencies for RDD deps
    // we also need to add this reduce-side dependency to build full map
    val shuffleIndex = new MutableMap[RDD[_], Dependency[_]]()
    (rddDependencies :+ this).foreach { dependent =>
      // `dependencies` (not getDependencies) is used so that the cached dependency instances,
      // i.e. the ones Spark itself schedules, end up in the index.
      dependent.dependencies.foreach {
        case shuffleDep: ShuffleDependency[_, _, _] =>
          shuffleIndex.put(shuffleDep.rdd, shuffleDep)
        case _ => // no-op for one-to-one or range dependencies
      }
    }

    // merge two data structures
    // NOTE(review): the head of this expression and the `${...}` argument in the message below
    // were truncated in the reviewed text; `batchPartitionIndex.map { ... }.toMap` and the RDD id
    // are reconstructions - confirm.
    val partitionMap: Map[Int, Dependency[_]] = batchPartitionIndex.map {
      case (partIndex, mapRdd) =>
        // extract shuffle dependency associated with map-side RDD, this is different than looking
        // up dependencies for that RDD, since we are looking for shuffle that has that rdd as
        // dependency
        val shuffleDep = shuffleIndex.getOrElse(mapRdd, sys.error(s"Failed to find shuffle for " +
          s"rdd $mapRdd ${mapRdd.id} when resolving partition index $partIndex"))
        (partIndex, shuffleDep)
    }.toMap

    // check that we covered all original partitions
    if (partitionMap.keys.size != this.getPartitions.length) {
      throw new IllegalStateException(
        s"Partition-dependency map has ${partitionMap.keys.size} partitions, but RDD should " +
        s"have ${this.getPartitions.length} partitions; map = $partitionMap")
    }
    partitionMap
  }

  // Preferred locations are taken from the map output tracker for the shuffle that holds the
  // output of the requested partition.
  override protected def getPreferredLocations(partition: Partition): Seq[String] = {
    val locatedDependency = shuffleSplitIndex.getOrElse(partition.index,
      sys.error(s"Failed to locate shuffle dependency for $partition (${partition.index})"))
    val locatedShuffleDependency = locatedDependency.asInstanceOf[ShuffleDependency[_, _, _]]
    val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
    tracker.getPreferredLocationsForShuffle(locatedShuffleDependency, partition.index)
  }

  // Each dependency is assumed to be a shuffle dependency
  private def iteratorForDependency(
      dep: Dependency[_], partition: Partition, context: TaskContext): Iterator[_] = {
    // NOTE(review): the receiver of `getReader` was missing in the reviewed text; the shuffle
    // manager of the current SparkEnv is assumed - confirm.
    SparkEnv.get.shuffleManager
      .getReader(dep.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle,
        partition.index, partition.index + 1, context)
      .read()
  }

  // Read shuffle output and remap each value to remove partition index
  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    val shuffleDep = shuffleSplitIndex.getOrElse(split.index,
      sys.error(s"Failed to locate rdd for partition $split (${split.index})"))
    // The output of an original partition lives in exactly one shuffle, so it is read once.
    // Previously the same shuffle was read once per entry of `dependencies`, which is equivalent
    // only while there is a single dependency and would duplicate records otherwise.
    iteratorForDependency(shuffleDep, split, context)
      .asInstanceOf[Iterator[(Int, T)]]
      .map { case (_, value) => value }
  }

  override def clearDependencies(): Unit = {
    // Let the base class drop its cached dependency list as well, otherwise the lineage stays
    // reachable after this RDD is checkpointed.
    super.clearDependencies()
    rdd = null
  }
}
package org.apache.spark.rdd.batch
import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.batch._
/** Implicit extensions that add batching to any `RDD`. */
package object implicits {

  /** Adds the `batch` method to `RDD[T]`. */
  implicit class BatchedRDDFunctions[T : ClassTag](rdd: RDD[T]) {

    /**
     * Evaluate this RDD in consecutive batches of at most `numPartitionsPerBatch` partitions.
     *
     * @param numPartitionsPerBatch maximum number of original partitions per batch, must be > 0
     * @return the RDD itself when it has at most `numPartitionsPerBatch` partitions (this includes
     *         an RDD without partitions), otherwise a [[BatchReduceRDD]] on top of a chain of
     *         [[BatchMapRDD]]s, one per batch
     * @throws IllegalArgumentException if `numPartitionsPerBatch` is not positive
     */
    def batch(numPartitionsPerBatch: Int): RDD[T] = {
      require(numPartitionsPerBatch > 0,
        s"Positive number of partitions per batch is required, found $numPartitionsPerBatch")
      // if requested num partitions is greater than or equal to total number of RDD partitions
      // we just return RDD itself, since it would result in single batch; empty RDD will fall into
      // this condition as well
      if (rdd.partitions.length <= numPartitionsPerBatch) {
        rdd
      } else {
        // total batches generated, do not use sc.defaultParallelism
        // consecutive, non-overlapping groups of partitions; the last group may be smaller
        val batches = rdd.partitions.grouped(numPartitionsPerBatch)
        // build RDD operations graph, looks like this:
        // reduce -> map -> map -> map -> map -> mapPartitions -> RDD
        // every map-side RDD is linked to the one created for the previous batch
        val lastMapRdd = batches.foldLeft(Option.empty[BatchMapRDD[T]]) { (previous, batch) =>
          Some(new BatchMapRDD[T](rdd, previous, batch))
        }
        lastMapRdd match {
          case Some(mapPart) =>
            new BatchReduceRDD(mapPart)
          case None =>
            throw new IllegalStateException(
              "No batches generated for map-side RDD using " +
              s"$numPartitionsPerBatch partitions per batch")
        }
      }
    }
  }
}
import org.apache.spark.HashPartitioner
import org.apache.spark.rdd.ParallelCollectionRDD
import org.apache.spark.rdd.batch.{BatchMapRDD, BatchReduceRDD}
// Tests for RDD batching: argument validation and pass-through behaviour of `batch`, structure of
// BatchMapRDD / BatchReduceRDD, and equality of batched results with the un-batched RDD.
// NOTE(review): closing braces and some lines of this suite are not present in the reviewed text;
// the gaps are flagged below, the code itself is left untouched.
class BatchMapReduceSuite extends UnitTestSpec with SparkLocalMode with BeforeAndAfter {
// NOTE(review): fixture bodies are missing; presumably they start/stop the context behind `sc` -
// confirm against SparkLocalMode.
before {
after {
// Non-positive batch sizes are rejected with IllegalArgumentException.
test("invalid batch size <= 0") {
val rdd = sc.parallelize(0 until 10, 8)
var err = intercept[IllegalArgumentException] { rdd.batch(-1) }
assert(err.getMessage.contains("Positive number of partitions per batch is required"))
err = intercept[IllegalArgumentException] { rdd.batch(0) }
assert(err.getMessage.contains("Positive number of partitions per batch is required"))
// An empty RDD is returned as is.
test("return original RDD, if it is empty") {
val rdd = sc.emptyRDD[Int]
val res = rdd.batch(10)
res should be (rdd)
// Batch size greater than or equal to the partition count (8) returns the RDD itself.
test("return original RDD, if it has fewer partitions than batch size") {
val rdd = sc.parallelize(0 until 10, 8)
val res1 = rdd.batch(10)
res1 should be (rdd)
val res2 = rdd.batch(8)
res2 should be (rdd)
// Batch size below the partition count wraps the RDD into BatchReduceRDD.
test("return batch reduce RDD, if number of partitions is larger than batch size") {
val rdd = sc.parallelize(0 until 10, 8)
val res1 = rdd.batch(4)
res1.isInstanceOf[BatchReduceRDD[_]] should be (true)
// we do allow to use batch size of 1 - one partition per stage
val res2 = rdd.batch(1)
res2.isInstanceOf[BatchReduceRDD[_]] should be (true)
// Chained map RDDs over the same original RDD expose equal hash partitioners.
test("batch map RDD should have same hash partitioner of original partitions") {
val rdd = sc.parallelize(0 until 10, 8)
val batch1 = new BatchMapRDD(rdd, None, rdd.partitions)
batch1.part.isInstanceOf[HashPartitioner] should be (true)
val batch2 = new BatchMapRDD(rdd, Some(batch1), rdd.partitions)
batch2.part should be (batch1.part)
// getBatchDependencies lists all earlier batches, oldest first, excluding the RDD itself.
test("batch map RDD should return correct number of dependencies") {
val rdd = sc.parallelize(0 until 10, 8)
val batch1 = new BatchMapRDD(rdd, None, rdd.partitions)
val batch2 = new BatchMapRDD(rdd, Some(batch1), rdd.partitions)
val batch3 = new BatchMapRDD(rdd, Some(batch2), rdd.partitions)
batch1.getBatchDependencies should be (Seq.empty)
batch2.getBatchDependencies should be (Seq(batch1))
batch3.getBatchDependencies should be (Seq(batch1, batch2))
// Both batches cover all partitions, so every original partition is mapped twice.
test("batch reduce RDD should fail if original partitions are reused") {
val rdd = sc.parallelize(0 until 10, 8)
val batch1 = new BatchMapRDD(rdd, None, rdd.partitions)
val batch2 = new BatchMapRDD(rdd, Some(batch1), rdd.partitions)
val err = intercept[IllegalStateException] {
new BatchReduceRDD(batch2)
// NOTE(review): the interpolated expression in `${}` is empty here; presumably it was the id of
// the map-side RDD reported in the error - confirm.
assert(err.getMessage.contains(s"Map-side RDD ${} contains duplicate partition"))
// take(3) + drop(4) leaves original partition 3 uncovered: 7 of 8 partitions are mapped.
test("batch reduce RDD should fail if original partitions are different from partition map") {
val rdd = sc.parallelize(0 until 10, 8)
val batch1 = new BatchMapRDD(rdd, None, rdd.partitions.take(3))
val batch2 = new BatchMapRDD(rdd, Some(batch1), rdd.partitions.drop(4))
val err = intercept[IllegalStateException] {
new BatchReduceRDD(batch2)
// NOTE(review): the start of this assertion is missing; presumably an
// assert(err.getMessage.contains( call precedes the string below - confirm.
"Partition-dependency map has 7 partitions, but RDD should have 8 partitions"))
// take(3) + drop(3) covers every partition exactly once.
test("batch reduce RDD should return same partitions as original RDD") {
val rdd = sc.parallelize(0 until 10, 8)
val batch1 = new BatchMapRDD(rdd, None, rdd.partitions.take(3))
val batch2 = new BatchMapRDD(rdd, Some(batch1), rdd.partitions.drop(3))
val res = new BatchReduceRDD(batch2)
res.partitions should be (rdd.partitions)
// The tests below compare per-partition contents (glom) of batched and original RDDs.
test("compute correctness test - no-parent int RDD") {
val rdd = sc.parallelize(0 until 10, 10)
val res = rdd.batch(3)
res.glom.collect should be (rdd.glom.collect)
test("compute correctness test - no-parent char RDD") {
val rdd = sc.parallelize(Seq("a", "b", "c", "d", "e", "f", "g", "h"), 10)
val res = rdd.batch(4)
res.glom.collect should be (rdd.glom.collect)
test("compute correctness test - no-parent complex type RDD") {
val rdd = sc.parallelize(
Seq((1, true), (2, false), (3, true), (4, false), (5, true), (6, false)), 8)
val res = rdd.batch(3)
res.glom.collect should be (rdd.glom.collect)
// Original RDD has narrow (map/filter) parents.
test("compute correctness test - one-parent int RDD") {
val rdd = sc.parallelize(0 until 100, 10).map { x => x * x }.filter { _ % 2 == 0 }
val res = rdd.batch(7)
res.glom.collect should be (rdd.glom.collect)
// Original RDD is itself the result of a shuffle (repartition).
test("compute correctness test - shuffle-parent int RDD") {
val rdd = sc.parallelize(0 until 100, 10).repartition(20)
val res = rdd.batch(7)
res.glom.collect should be (rdd.glom.collect)
