Skip to content

Instantly share code, notes, and snippets.

@jroper
Created July 15, 2017 11:06
Show Gist options
  • Save jroper/4f8108f8fa00a7251919e08d4cf9eb71 to your computer and use it in GitHub Desktop.
Save jroper/4f8108f8fa00a7251919e08d4cf9eb71 to your computer and use it in GitHub Desktop.
Akka streams LazyBroadcastHub - a broadcast hub that only keeps its source materialized as long as there are consumers
package streams.utils
import akka.NotUsed
import akka.stream._
import akka.stream.scaladsl.{BroadcastHub, Keep, RunnableGraph, Source}
import akka.stream.stage._
import scala.concurrent.duration.{Duration, FiniteDuration}
/**
 * Provides a broadcast hub that only runs the source when there are sinks connected to it.
 */
object LazyBroadcastHub {

  /** Buffer size used by the overloads that don't take one. */
  private val DefaultBufferSize = 256

  /**
   * Create a broadcast hub for the given source.
   *
   * The hub will only run the source when there are consumers attached to the hub. When all consumers disconnect,
   * after the given idle timeout, if no more consumers connect it will shut the source down.
   *
   * The source will be rematerialized whenever it's not running but a new consumer attaches to the hub.
   *
   * The materialization value is a tuple of a source as produced by BroadcastHub, and a KillSwitch to kill the hub.
   *
   * @param source      The source to broadcast.
   * @param idleTimeout The time to wait when there are no consumers before shutting the source down. A zero (or
   *                    negative) timeout shuts the source down as soon as the last consumer disconnects.
   * @param bufferSize  The buffer size handed to BroadcastHub.sink: it bounds how far the fastest consumer may get
   *                    ahead of the slowest one before the source is backpressured (see the BroadcastHub docs for
   *                    its constraints on this value).
   */
  def forSource[T](source: Source[T, _], idleTimeout: FiniteDuration, bufferSize: Int): RunnableGraph[(Source[T, NotUsed], KillSwitch)] = {
    Source.fromGraph(new LazySourceStage[T](source, idleTimeout))
      .viaMat(KillSwitches.single)(Keep.both)
      .toMat(BroadcastHub.sink[T](bufferSize)) {
        case ((callbacks, killSwitch), broadcastSource) =>
          // Each materialization of the consumer source is wrapped in a RecordingStage, which reports to the lazy
          // source stage so it can count how many consumers are attached.
          val consumerSource = broadcastSource.via(new RecordingStage[T](callbacks))
          (consumerSource, killSwitch)
      }
  }

  /**
   * Create a lazy broadcast hub with a buffer size of 256.
   *
   * @param source      The source to broadcast.
   * @param idleTimeout The time to wait when there are no consumers before shutting the source down.
   */
  def forSource[T](source: Source[T, _], idleTimeout: FiniteDuration): RunnableGraph[(Source[T, NotUsed], KillSwitch)] =
    forSource(source, idleTimeout, bufferSize = DefaultBufferSize)

  /**
   * Create a lazy broadcast hub that shuts the source down as soon as the last consumer disconnects.
   *
   * @param source     The source to broadcast.
   * @param bufferSize The buffer size handed to BroadcastHub.sink.
   */
  def forSource[T](source: Source[T, _], bufferSize: Int): RunnableGraph[(Source[T, NotUsed], KillSwitch)] =
    forSource(source, Duration.Zero, bufferSize)

  /**
   * Create a lazy broadcast hub with a buffer size of 256 that shuts the source down as soon as the last consumer
   * disconnects.
   *
   * @param source The source to broadcast.
   */
  def forSource[T](source: Source[T, _]): RunnableGraph[(Source[T, NotUsed], KillSwitch)] =
    forSource(source, Duration.Zero)

  /** Notifications sent by every hub consumer (a RecordingStage) to the LazySourceStage. */
  private trait MaterializationCallbacks {
    /** A consumer has started. */
    def materialized(): Unit
    /** A consumer that previously reported `materialized()` has stopped. */
    def completed(): Unit
  }

  /**
   * Pass-through stage placed in front of every consumer of the hub. It reports its own start and stop to the
   * LazySourceStage, so that stage knows how many consumers are currently attached.
   */
  private class RecordingStage[T](callbacks: MaterializationCallbacks) extends GraphStage[FlowShape[T, T]] {
    private val in = Inlet[T]("RecordingStage.in")
    private val out = Outlet[T]("RecordingStage.out")
    override def shape = FlowShape(in, out)

    override def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) {
      // True once materialized() has been reported, so completed() is reported at most once and never before it.
      private var registered = false

      setHandler(in, new InHandler {
        override def onPush(): Unit = push(out, grab(in))
      })

      // onDownstreamFinish is deliberately NOT overridden: the default completes this stage, which cancels the
      // hub subscription and then runs postStop. Overriding it without stopping the stage would leave a hub
      // consumer that never pulls again and eventually backpressures every other consumer.
      setHandler(out, new OutHandler {
        override def onPull(): Unit = pull(in)
      })

      override def preStart(): Unit = {
        // This must be done in preStart, if done during materialization then there's a race for the LazySourceStage
        // to finish materializing before this gets invoked.
        callbacks.materialized()
        registered = true
      }

      override def postStop(): Unit = {
        // Runs however this consumer stops (downstream cancel, upstream completion or failure), so the consumer
        // count in LazySourceStage is always decremented.
        if (registered) {
          registered = false
          callbacks.completed()
        }
      }
    }
  }

  /**
   * Source stage that runs `source` as a sub-stream only while at least one consumer is attached.
   *
   * Its materialized value is the stage logic itself, exposed as MaterializationCallbacks so that consumers can
   * report when they start and stop.
   */
  private class LazySourceStage[T](source: Source[T, _], idleTimeout: FiniteDuration)
    extends GraphStageWithMaterializedValue[SourceShape[T], MaterializationCallbacks] {

    private val out = Outlet[T]("LazySourceStage.out")
    override def shape = SourceShape(out)

    override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = {
      val logic = new TimerGraphStageLogic(shape) with MaterializationCallbacks {
        // Number of consumers currently attached to the hub.
        var materializedSources = 0
        // Inlet connected to the currently running source, if one is running.
        var activeIn: Option[SubSinkInlet[T]] = None
        // Key of the most recently scheduled idle timer; timers with older keys are ignored when they fire.
        var stopSourceRequest = 0

        val materializedCallback = createAsyncCallback[Unit] { _ =>
          materializedSources += 1
          if (activeIn.isEmpty) {
            startSource()
          }
        }

        val completedCallback = createAsyncCallback[Unit] { _ =>
          materializedSources -= 1
          if (materializedSources == 0) {
            if (idleTimeout <= Duration.Zero) {
              stopSource()
            } else {
              stopSourceRequest += 1
              scheduleOnce(stopSourceRequest, idleTimeout)
            }
          }
        }

        /** Materialize `source` into a fresh sub-sink and wire it to `out`. */
        def startSource(): Unit = {
          assert(activeIn.isEmpty, "startSource called while a source is already running")
          val in = new SubSinkInlet[T]("LazySourceStage.in")
          // Completion/failure of the underlying source falls through to the default InHandler behaviour.
          in.setHandler(new InHandler {
            override def onPush(): Unit = push(out, in.grab())
          })
          setHandler(out, new OutHandler {
            override def onPull(): Unit = in.pull()
          })
          source.runWith(in.sink)(subFusingMaterializer)
          // Downstream may already have pulled while no source was running (ignoreOut swallowed that pull).
          if (isAvailable(out)) {
            in.pull()
          }
          activeIn = Some(in)
        }

        /** Cancel the running source and go back to ignoring downstream pulls. */
        def stopSource(): Unit = {
          assert(activeIn.nonEmpty, "stopSource called while no source is running")
          activeIn.foreach(_.cancel())
          activeIn = None
          ignoreOut()
        }

        override protected def onTimer(timerKey: Any): Unit = timerKey match {
          // Only the most recent idle timer counts, and only if no consumer attached in the meantime.
          case key: Int if key == stopSourceRequest && materializedSources == 0 && activeIn.nonEmpty =>
            stopSource()
          case _ =>
        }

        /** Install an out handler that does nothing on pull; used whenever no source is running. */
        def ignoreOut(): Unit = {
          setHandler(out, new OutHandler {
            override def onPull(): Unit = ()
          })
        }

        override def materialized(): Unit = materializedCallback.invoke(())

        override def completed(): Unit = completedCallback.invoke(())

        override def postStop(): Unit = {
          super.postStop()
          // Don't let the underlying source outlive the hub, e.g. when the hub is shut down through its KillSwitch
          // while consumers are still attached.
          activeIn.foreach { in =>
            if (!in.isClosed) in.cancel()
          }
          activeIn = None
        }

        // No source runs until the first consumer attaches.
        ignoreOut()
      }
      (logic, logic)
    }
  }
}
package streams.utils
import java.util.concurrent.atomic.AtomicBoolean
import akka.Done
import akka.actor.ActorSystem
import akka.stream.scaladsl.{Sink, Source}
import akka.stream.testkit.javadsl.TestSink
import akka.stream.{ActorMaterializer, Materializer}
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpec}
import play.api.{Environment, LoggerConfigurator}
import streams.utils.LazyBroadcastHub
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration._
class LazyBroadcastHubSpec extends WordSpec with Matchers with BeforeAndAfterAll {

  implicit var system: ActorSystem = _
  implicit var materializer: Materializer = _

  val environment = Environment.simple()
  LoggerConfigurator(environment.classLoader).foreach(_.configure(environment))

  /**
   * An endless source of "a" elements that completes `shutdown` when the source terminates, i.e. when the hub
   * cancels it.
   */
  private def watchedRepeat(shutdown: Promise[Done]): Source[String, _] =
    Source.repeat("a").watchTermination() { (_, term) =>
      shutdown.completeWith(term)
    }

  "LazyBroadcastHub" should {

    "not start the source if there are no consumers" in {
      val materialized = new AtomicBoolean()
      val (_, killSwitch) =
        LazyBroadcastHub.forSource(Source.empty.mapMaterializedValue(_ => materialized.set(true))).run()
      Thread.sleep(200)
      materialized.get() should be (false)
      killSwitch.shutdown()
    }

    "start the source when a consumer attaches" in {
      val (source, killSwitch) = LazyBroadcastHub.forSource(Source.repeat("a")).run()
      val sink = source.runWith(TestSink.probe(system))
      sink.requestNext("a")
      sink.cancel()
      killSwitch.shutdown()
    }

    "shut down the source when a single consumer disconnects" in {
      val shutdown = Promise[Done]()
      val (source, killSwitch) = LazyBroadcastHub.forSource(watchedRepeat(shutdown)).run()
      source.runWith(Sink.head)
      Await.ready(shutdown.future, 10.seconds)
      killSwitch.shutdown()
    }

    "not shutdown when there is still a consumer" in {
      val shutdown = Promise[Done]()
      val (source, killSwitch) = LazyBroadcastHub.forSource(watchedRepeat(shutdown)).run()
      val sink1 = source.runWith(TestSink.probe(system))
      val sink2 = source.runWith(TestSink.probe(system))
      sink1.requestNext("a")
      sink2.requestNext("a")
      sink2.cancel()
      Thread.sleep(200)
      shutdown.isCompleted should be (false)
      sink1.cancel()
      killSwitch.shutdown()
    }

    "shut down when multiple consumers disconnect" in {
      val shutdown = Promise[Done]()
      val (source, killSwitch) = LazyBroadcastHub.forSource(watchedRepeat(shutdown)).run()
      val sink1 = source.runWith(TestSink.probe(system))
      val sink2 = source.runWith(TestSink.probe(system))
      sink1.requestNext("a")
      sink2.requestNext("a")
      sink1.cancel()
      sink2.cancel()
      Await.ready(shutdown.future, 10.seconds)
      killSwitch.shutdown()
    }

    "wait until a timeout before disconnecting" in {
      val shutdown = Promise[Done]()
      val (source, killSwitch) = LazyBroadcastHub.forSource(watchedRepeat(shutdown), 300.millis).run()
      source.runWith(Sink.head)
      Thread.sleep(200)
      shutdown.isCompleted should be (false)
      Await.ready(shutdown.future, 10.seconds)
      killSwitch.shutdown()
    }

    "not disconnect if a new sink connects within the timeout" in {
      val shutdown = Promise[Done]()
      val (source, killSwitch) = LazyBroadcastHub.forSource(watchedRepeat(shutdown), 300.millis).run()
      source.runWith(Sink.head)
      Thread.sleep(200)
      val sink = source.runWith(TestSink.probe(system))
      sink.requestNext("a")
      Thread.sleep(200)
      shutdown.isCompleted should be (false)
      sink.cancel()
      killSwitch.shutdown()
    }
  }

  override protected def beforeAll(): Unit = {
    system = ActorSystem("Test")
    materializer = ActorMaterializer()
  }

  override protected def afterAll(): Unit = {
    // Wait for the actor system to terminate so it doesn't outlive the suite.
    Await.ready(system.terminate(), 10.seconds)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment