@shlomiv
Created May 19, 2016 20:49
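import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD

// Split an RDD by partition index: partitions [0, hotPartitions) go to the first
// result, partitions [hotPartitions, numPartitions) go to the second. Both results
// keep the parent's partition count; partitions outside their range are empty.
def splitByPartition[T: ClassTag](rdd: RDD[T], hotPartitions: Int): (RDD[T], RDD[T]) = {
  // Each partition emits exactly two iterators, (left, right): one is the
  // partition's data, the other is empty. mapPartitionsWithIndex is used instead of
  // TaskContext.get.partitionId, which reports the task's final partition and can
  // differ from this RDD's index (e.g. after a union or coalesce downstream).
  // Do not cache `splits`: its elements are lazy iterators consumed downstream.
  val splits = rdd.mapPartitionsWithIndex { (partId, iter) =>
    val left = if (partId < hotPartitions) iter else Iterator.empty
    val right = if (partId >= hotPartitions) iter else Iterator.empty
    Iterator(left, right)
  }
  // Keep the first iterator of each partition.
  val left = splits.mapPartitions { iter => iter.next() }
  // Skip the first iterator (without consuming it) and keep the second.
  val right = splits.mapPartitions { iter =>
    iter.next()
    iter.next()
  }
  (left, right)
}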
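To make the two-iterator trick easier to follow without a cluster, here is a Spark-free sketch. It models an "RDD" as a `Vector` of lazily produced partitions; `SplitByPartitionModel`, its `mapPartitionsWithIndex` stand-in and `collect` are hypothetical names for this illustration only, not Spark APIs. It shows which elements land in each output and that skipping the first iterator does not pull any data from it.

```scala
// Spark-free model of splitByPartition (hypothetical, for illustration only):
// a "partition" is a thunk that produces a fresh Iterator on each evaluation,
// mimicking how Spark recomputes an uncached RDD partition.
object SplitByPartitionModel {

  // Hypothetical stand-in for RDD.mapPartitionsWithIndex.
  def mapPartitionsWithIndex[T, U](parts: Vector[() => Iterator[T]])(
      f: (Int, Iterator[T]) => Iterator[U]): Vector[() => Iterator[U]] =
    parts.zipWithIndex.map { case (p, idx) => () => f(idx, p()) }

  def splitByPartition[T](parts: Vector[() => Iterator[T]], hotPartitions: Int)
      : (Vector[() => Iterator[T]], Vector[() => Iterator[T]]) = {
    // Each partition becomes Iterator(left, right); one of the two is empty.
    val splits = mapPartitionsWithIndex(parts) { (idx, iter) =>
      val left = if (idx < hotPartitions) iter else Iterator.empty
      val right = if (idx >= hotPartitions) iter else Iterator.empty
      Iterator(left, right)
    }
    // First iterator only.
    val left = splits.map(p => () => p().next())
    // Skip the first iterator without consuming it, return the second.
    val right = splits.map(p => () => { val it = p(); it.next(); it.next() })
    (left, right)
  }

  // Materialise every partition, like rdd.glom().collect() would.
  def collect[T](parts: Vector[() => Iterator[T]]): Vector[Vector[T]] =
    parts.map(p => p().toVector)
}
```

With three partitions and `hotPartitions = 2`, the first output holds the data of partitions 0 and 1 and an empty partition 2; the second output is the mirror image. As in the model, each real Spark output re-evaluates `rdd` on its own, so if the source is expensive, caching `rdd` itself (never `splits`) is one option to consider.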