Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
object Stage {

  /**
   * A source which feeds an initial batch into a provided flow and emits the batch once every element has been
   * processed, retrying unprocessed elements with an exponential backoff strategy.
   *
   * The provided flow must accept a set of elements and emit a subset containing the unprocessed items, or an empty
   * set if all elements were processed successfully. Results are fed back into the flow through a cycle, so the flow
   * is expected to emit a result for every set it receives; otherwise the loop cannot make progress.
   *
   * @param flow         a flow which processes sets at a time and emits corresponding sets containing unprocessed
   *                     items (if any)
   * @param minBackoff   base delay before the first retry; the base delay doubles with every further retry
   * @param maxBackoff   cap applied to the base delay before the random jitter is added
   * @param maxRetries   maximum number of retries; the limit is only enforced when this is greater than zero
   *                     (defaults to `UnlimitedRetries`, defined elsewhere — presumably a value <= 0, confirm)
   * @param randomFactor jitter: every delay is multiplied by a random value in `[1.0, 1.0 + randomFactor)`
   * @param batch        initial batch (set) of elements to process
   * @param logging      adapter used to log a warning before every retry
   * @tparam A element type of the batch
   * @return a source that emits the original `batch` once the feedback loop completes (i.e. after the flow reported
   *         an empty set of unprocessed items); it fails when `maxRetries` is positive and would be exceeded, or when
   *         `flow` itself fails
   */
  def retryBatchSource[A](flow: Flow[Set[A], Set[A], NotUsed])(
      minBackoff: FiniteDuration,
      maxBackoff: FiniteDuration,
      maxRetries: Int = UnlimitedRetries,
      randomFactor: Double = DefaultRandomFactor
  )(batch: Set[A])(implicit logging: LoggingAdapter): Source[Set[A], NotUsed] = {
    // Ignores every element and emits the original batch exactly once, when its upstream completes.
    val waitUntilCompletion: Flow[Set[A], Set[A], NotUsed] = Flow[Set[A]].fold(batch)((acc, _) => acc)

    // Feedback branch: an empty result ends the loop (takeWhile completes); a non-empty result is re-emitted
    // after a backoff delay, or the stream is failed once the retry limit is exceeded.
    val backoffDelay = Flow[Set[A]].takeWhile(_.nonEmpty).zipWithIndex.flatMapConcat {
      case (unprocessed, index) =>
        val retryCount = index.toInt + 1 // 1-based: used for the retry limit and for logging
        if (maxRetries > 0 && retryCount > maxRetries) {
          Source.failed(new Exception("Exceeded retry limit for batch retry. Failing."))
        } else {
          // The backoff exponent is zero-based so the first retry waits `minBackoff` (as Akka's restart logic does).
          val delay = calculateDelay(retryCount - 1, minBackoff, maxBackoff, randomFactor)
          logging.warning(
            s"Received unprocessed items. Retrying # $retryCount after delay: ${delay.length} ${delay.unit.name()}"
          )
          Source.single(unprocessed).delay(delay, OverflowStrategy.backpressure)
        }
    }

    Source.fromGraph(GraphDSL.create() { implicit b =>
      import GraphDSL.Implicits._
      val merge = b.add(Merge[Set[A]](2))
      val bcast = b.add(Broadcast[Set[A]](2))
      val completed = b.add(waitUntilCompletion)
      // initial batch + retried leftovers -> flow -> (feedback loop, completion signal)
      Source.single(batch) ~> merge ~> flow.async ~> bcast
      bcast.out(1) ~> completed
      bcast.out(0) ~> backoffDelay ~> merge
      SourceShape(completed.out)
    })
  }

  /**
   * Calculates an exponential backoff delay for a given restart count, min/max backoff combined with a random factor.
   * Note: this logic is the same as what is used by internal Akka backoff calculations.
   *
   * The base delay `minBackoff * 2^restartCount` is capped at `maxBackoff` and then multiplied by a random value in
   * `[1.0, 1.0 + randomFactor)`; because the jitter is applied after the cap, the result can exceed `maxBackoff`.
   *
   * @param restartCount zero-based attempt number; `0` yields a base delay of `minBackoff`
   * @param minBackoff   base delay for the first attempt
   * @param maxBackoff   cap for the base delay, and the fallback when the computation fails or is not finite
   * @param randomFactor jitter factor applied to the capped delay
   * @return the delay to wait before the next attempt
   */
  private def calculateDelay(
      restartCount: Int,
      minBackoff: FiniteDuration,
      maxBackoff: FiniteDuration,
      randomFactor: Double
  ): FiniteDuration = {
    val rnd = 1.0 + ThreadLocalRandom.current().nextDouble() * randomFactor
    // Duration arithmetic may throw (e.g. on overflow) or yield a non-finite Duration; both fall back to maxBackoff.
    val calculatedDuration = Try(maxBackoff.min(minBackoff * math.pow(2, restartCount)) * rnd).getOrElse(maxBackoff)
    calculatedDuration match {
      case f: FiniteDuration => f
      case _                 => maxBackoff
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment