Skip to content

Instantly share code, notes, and snippets.

@jroper
Last active March 12, 2018 14:22
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jroper/f28ecf79f4a4be70e3f499a672d8d6b5 to your computer and use it in GitHub Desktop.
Save jroper/f28ecf79f4a4be70e3f499a672d8d6b5 to your computer and use it in GitHub Desktop.
Akka streams Source.restartWithBackoff
package streams.utils
import java.util.concurrent.ThreadLocalRandom
import akka.NotUsed
import akka.stream._
import akka.stream.scaladsl.Source
import akka.stream.stage._
import scala.concurrent.duration._
object SourceWithBackoffSupervision {

  implicit class EnrichedSupervisableSource[T](source: Source[T, _]) {

    /**
     * Restart the source with the given backoff parameters when it completes or fails.
     *
     * The wrapped source is materialized lazily, on the first downstream pull. Whenever a
     * materialization completes or fails, a new one is started after a delay computed by
     * `calculateDelay`. That delay doubles with each consecutive restart, starting from
     * `minBackoff` and capped at `maxBackoff` before jitter is applied. Completion and
     * failure of the wrapped source are therefore not propagated downstream. The returned
     * source stops when downstream cancels.
     *
     * @param minBackoff   The minimum (initial) delay before a restart. It is also the reset
     *                     window: if a restarted source completes or fails after more than
     *                     `minBackoff`, the backoff starts again from `minBackoff`.
     * @param maxBackoff   The upper bound of the exponentially growing delay, before jitter.
     * @param randomFactor A random factor. The (capped) delay is multiplied by a random value
     *                     in `[1.0, 1.0 + randomFactor)`.
     * @return a source emitting the elements of every successive materialization of `source`.
     */
    def restartWithBackoff(
      minBackoff: FiniteDuration,
      maxBackoff: FiniteDuration,
      randomFactor: Double
    ): Source[T, NotUsed] =
      Source.fromGraph(new RestartWithBackoff[T](minBackoff, maxBackoff, randomFactor, source))
  }

  /**
   * Source stage that runs `thisSource` as a sub-stream.
   *
   * Each time the sub-stream completes or fails, it is re-materialized after a backoff delay.
   */
  private final class RestartWithBackoff[T](
    minBackoff: FiniteDuration,
    maxBackoff: FiniteDuration,
    randomFactor: Double,
    thisSource: Graph[SourceShape[T], _]
  ) extends GraphStage[SourceShape[T]] {

    private val out = Outlet[T]("RestartWithBackoff.out")

    override def shape: SourceShape[T] = SourceShape(out)

    override def initialAttributes: Attributes = Attributes.name("RestartWithBackoff")

    override def createLogic(attr: Attributes): GraphStageLogic = new TimerGraphStageLogicWithLogging(shape) {
      // Number of consecutive restarts; the exponent of the backoff delay.
      private var restartCount = 0
      // If a sub-stream completes or fails after this deadline, it is treated as having run
      // "long enough", and the backoff starts again from minBackoff.
      private var resetDeadline = minBackoff.fromNow

      /** Materializes a fresh copy of `thisSource` and connects it to `out`. */
      def startSource(): Unit = {
        val sinkIn = new SubSinkInlet[T]("RestartWithBackoffSink")
        sinkIn.setHandler(new InHandler {
          override def onPush(): Unit = push(out, sinkIn.grab())

          override def onUpstreamFinish(): Unit = {
            log.debug("Source finished")
            onCompleteOrFailure()
          }

          override def onUpstreamFailure(ex: Throwable): Unit = {
            log.error(ex, "Restarting source due to failure")
            onCompleteOrFailure()
          }
        })
        // Forward downstream demand and cancellation to the current sub-stream.
        setHandler(out, new OutHandler {
          override def onPull(): Unit = sinkIn.pull()

          override def onDownstreamFinish(): Unit = sinkIn.cancel()
        })
        Source.fromGraph(thisSource).runWith(sinkIn.sink)(subFusingMaterializer)
        // Demand may have arrived while we were backing off; pass it on now.
        if (isAvailable(out)) sinkIn.pull()
      }

      /**
       * While waiting for the restart timer, swallow pulls.
       *
       * The pending demand is picked up by `startSource` via `isAvailable(out)`.
       */
      def backoff(): Unit = {
        setHandler(out, new OutHandler {
          override def onPull(): Unit = ()
        })
      }

      /** Schedules the next restart after the current sub-stream has completed or failed. */
      def onCompleteOrFailure(): Unit = {
        if (resetDeadline.isOverdue()) {
          restartCount = 0
        }
        val restartDelay = calculateDelay(restartCount, minBackoff, maxBackoff, randomFactor)
        log.debug("Restarting stream in {}", restartDelay)
        scheduleOnce("RestartTimer", restartDelay)
        restartCount += 1
        backoff()
      }

      override protected def onTimer(timerKey: Any): Unit = {
        startSource()
        resetDeadline = minBackoff.fromNow
      }

      // The first materialization happens lazily, on the first downstream pull.
      setHandler(out, new OutHandler {
        override def onPull(): Unit = startSource()
      })
    }

    override def toString: String = "RestartWithBackoff"
  }

  /**
   * Computes the delay before the restart numbered `restartCount` (0-based).
   *
   * The delay is `min(maxBackoff, minBackoff * 2^restartCount) * rnd`, where `rnd` is drawn
   * from `[1.0, 1.0 + randomFactor)`. The jitter can therefore push the result above
   * `maxBackoff`. If the computation cannot be represented as a `FiniteDuration`, the
   * result is `maxBackoff`.
   *
   * Adapted from akka.pattern.BackoffSupervisor.
   */
  private def calculateDelay(
    restartCount: Int,
    minBackoff: FiniteDuration,
    maxBackoff: FiniteDuration,
    randomFactor: Double): FiniteDuration = {
    val rnd = 1.0 + ThreadLocalRandom.current().nextDouble() * randomFactor
    if (restartCount >= 30) // stop doubling after 30 restarts (returns maxBackoff, without jitter)
      maxBackoff
    else {
      // FiniteDuration * Double throws IllegalArgumentException when the result exceeds
      // Long.MaxValue nanoseconds (~292 years). That can happen for large minBackoff, even
      // below 30 restarts, or for large maxBackoff with jitter. Fall back to maxBackoff
      // rather than failing the stage.
      scala.util.Try(maxBackoff.min(minBackoff * math.pow(2, restartCount)) * rnd)
        .getOrElse(maxBackoff) match {
          case f: FiniteDuration => f
          case _                 => maxBackoff
        }
    }
  }
}
package stream.utils
import akka.actor.ActorSystem
import akka.stream.{ActorMaterializer, Materializer}
import akka.stream.scaladsl.Source
import akka.stream.testkit.scaladsl.TestSink
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpec}
import scala.concurrent.duration._
class SourceWithBackoffSupervisionSpec extends WordSpec with Matchers with BeforeAndAfterAll {
  import streams.utils.SourceWithBackoffSupervision._

  // Suite-wide fixtures, created in beforeAll and torn down in afterAll. They are implicit
  // so that ActorMaterializer() and runWith pick them up.
  implicit var system: ActorSystem = _
  implicit var materializer: Materializer = _

  "RestartWithBackoff source" should {
    "run normally" in {
      val probe = Source.repeat("a")
        .restartWithBackoff(500.millis, 1.second, 0)
        .runWith(TestSink.probe)
      probe.requestNext("a")
      probe.requestNext("a")
      probe.requestNext("a")
      probe.requestNext("a")
      probe.requestNext("a")
      probe.cancel()
    }

    "restart on completion" in {
      val probe = Source(List("a", "b"))
        .restartWithBackoff(10.millis, 100.millis, 0)
        .runWith(TestSink.probe)
      probe.requestNext("a")
      probe.requestNext("b")
      probe.requestNext("a")
      probe.requestNext("b")
      probe.requestNext("a")
      probe.requestNext("b")
      probe.cancel()
    }

    "restart on failure" in {
      val probe = Source(List("a", "b", "c"))
        .map {
          case "c"   => sys.error("failed")
          case other => other
        }
        .restartWithBackoff(10.millis, 100.millis, 0)
        .runWith(TestSink.probe)
      probe.requestNext("a")
      probe.requestNext("b")
      probe.requestNext("a")
      probe.requestNext("b")
      probe.requestNext("a")
      probe.requestNext("b")
      probe.cancel()
    }

    "backoff before restart" in {
      val probe = Source(List("a", "b"))
        .restartWithBackoff(1.second, 2.seconds, 0)
        .runWith(TestSink.probe)
      probe.requestNext("a")
      probe.requestNext("b")
      probe.request(1)
      // With minBackoff = 1s and no jitter, nothing may arrive within the first 500ms...
      probe.expectNoMsg(500.milliseconds)
      // ...but the restarted source must deliver within the remaining window.
      probe.expectNext(1.second, "a")
      probe.requestNext("b")
      probe.cancel()
    }
  }

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    system = ActorSystem("Test")
    materializer = ActorMaterializer()
  }

  override protected def afterAll(): Unit = {
    // Terminate the ActorSystem (which also shuts down the materializer); otherwise its
    // dispatcher threads outlive the suite. super.afterAll() runs even if termination fails.
    try Option(system).foreach(s => scala.concurrent.Await.result(s.terminate(), 10.seconds))
    finally super.afterAll()
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment