Last active
July 27, 2018 21:12
-
-
Save jchapuis/6b3ced8adb9fedcd0263864c628e8087 to your computer and use it in GitHub Desktop.
Akka stream ZipLatest GraphStage
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import akka.stream._ | |
import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler } | |
/**
 * Zips two streams, always pairing the latest element seen on each input.
 *
 * No element is emitted until at least one element has arrived on each input. After that, whenever a
 * new element arrives on either input, a tuple is emitted combining it with the most recently seen
 * element of the other input. A newly arrived element produces a tuple even if it is equal to the
 * previous one, while a downstream pull never re-emits a tuple unless a new element arrived since the
 * last push.
 *
 * While downstream backpressures, each input is pulled at most once, so if both inputs deliver an
 * element in the meantime they are combined into a single tuple on the next pull.
 *
 * '''Emits when''' all of the inputs have at least an element available, and then each time an element becomes
 * available on either of the inputs
 *
 * '''Backpressures when''' downstream backpressures
 *
 * '''Completes when''' any of the upstreams completes (eagerly: an update not yet emitted is dropped)
 *
 * '''Cancels when''' downstream cancels
 *
 * @tparam A element type of the first input (`in0`)
 * @tparam B element type of the second input (`in1`)
 */
class ZipLatest[A, B] extends GraphStage[FanInShape2[A, B, (A, B)]] {
  val in0: Inlet[A] = Inlet[A]("ZipLatest.in1")
  val in1: Inlet[B] = Inlet[B]("ZipLatest.in2")
  val out: Outlet[(A, B)] = Outlet[(A, B)]("ZipLatest.out")
  override val shape: FanInShape2[A, B, (A, B)] = new FanInShape2[A, B, (A, B)](in0, in1, out)

  // scalastyle:off method.length
  override def createLogic(attr: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) {
      // Most recent element received on each input; None until that input has pushed once.
      private var lastA = Option.empty[A]
      private var lastB = Option.empty[B]
      // True when an element arrived after the last tuple was pushed. Tracked explicitly instead of
      // comparing against the last pushed tuple, so a repeated (equal) element still yields a tuple.
      private var hasUnsentUpdate = false
      // True when downstream has pulled but no unsent pair was available to push at that time.
      private var waitingForPair = false

      setHandler(
        out,
        new OutHandler {
          override def onPull(): Unit = {
            latestPair match {
              case Some(pair) if hasUnsentUpdate => pushPair(pair)
              case _                             => waitingForPair = true
            }
            pullBoth()
          }
        }
      )

      // Upstream completion and failure use the InHandler defaults, which complete / fail the stage.
      setHandler(
        in0,
        new InHandler {
          override def onPush(): Unit = {
            lastA = Some(grab(in0))
            onNewElement()
          }
        }
      )

      setHandler(
        in1,
        new InHandler {
          override def onPush(): Unit = {
            lastB = Some(grab(in1))
            onNewElement()
          }
        }
      )

      /** The current (latest A, latest B) pair, once both inputs have delivered at least one element. */
      private def latestPair: Option[(A, B)] =
        for {
          a <- lastA
          b <- lastB
        } yield (a, b)

      /** Records a fresh element and, if downstream is already waiting, emits the pair right away. */
      private def onNewElement(): Unit = {
        hasUnsentUpdate = true
        if (waitingForPair) {
          latestPair.foreach { pair =>
            pushPair(pair)
            waitingForPair = false
            pullBoth()
          }
        }
      }

      /** Pushes `pair` downstream and marks the current state as emitted. */
      private def pushPair(pair: (A, B)): Unit = {
        push(out, pair)
        hasUnsentUpdate = false
      }

      /** Requests the next element from every input that does not already have a pull outstanding. */
      private def pullBoth(): Unit = {
        if (!hasBeenPulled(in0)) pull(in0)
        if (!hasBeenPulled(in1)) pull(in1)
      }
    }
  // scalastyle:on method.length
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import akka.actor.ActorSystem | |
import akka.stream.scaladsl.{ GraphDSL, RunnableGraph } | |
import akka.stream.testkit.TestPublisher.Probe | |
import akka.stream.testkit.scaladsl.{ TestSink, TestSource } | |
import akka.stream.{ ActorMaterializer, ClosedShape } | |
import akka.testkit.TestKit | |
import org.scalacheck.Gen | |
import org.scalatest.concurrent.ScalaFutures | |
import org.scalatest.prop.PropertyChecks | |
import org.scalatest.{ BeforeAndAfterAll, GivenWhenThen, Matchers, WordSpecLike } | |
import scala.concurrent.duration._ | |
/**
 * Behavioural tests for [[ZipLatest]], driven through TestSource / TestSink probes.
 */
class ZipLatestSpec
    extends TestKit(ActorSystem("ZipLatestSpec"))
    with WordSpecLike
    with Matchers
    with BeforeAndAfterAll
    with PropertyChecks
    with GivenWhenThen
    with ScalaFutures {

  implicit val materializer: ActorMaterializer = ActorMaterializer()

  implicit val patience: PatienceConfig = PatienceConfig(10.seconds)

  override def afterAll(): Unit = TestKit.shutdownActorSystem(system)

  "ZipLatest" must {
    "only emit when at least one pair is available" in {
      val (probe, bools, ints) = testGraph[Boolean, Int]
      Given("request for one element")
      probe.request(1)
      And("one element pushed on each source")
      bools.sendNext(true)
      ints.sendNext(1)
      Then("emits a single pair")
      probe.expectNext((true, 1))
    }

    "does not emit the same pair upon two pulls" in {
      val (probe, bools, ints) = testGraph[Boolean, Int]
      Given("request for one element")
      probe.request(1)
      And("one element pushed on each source")
      bools.sendNext(true)
      ints.sendNext(1)
      Then("emits a single pair")
      probe.expectNext((true, 1))
      And("another request")
      probe.request(1)
      Then("does not emit a duplicate")
      // Completing right away: a duplicate OnNext would arrive before OnComplete and fail expectComplete.
      bools.sendComplete()
      probe.expectComplete()
    }

    // Selectors picking one of the two source probes, used to run each scenario against both inputs.
    val first = (t: (Probe[Boolean], Probe[Int])) => t._1
    val second = (t: (Probe[Boolean], Probe[Int])) => t._2

    "complete when either source completes" in {
      forAll(Gen.oneOf(first, second)) { select =>
        val (probe, bools, ints) = testGraph[Boolean, Int]
        Given("either source completes")
        select((bools, ints)).sendComplete()
        Then("subscribes and completes")
        probe.expectSubscriptionAndComplete()
      }
    }

    "fail when either source has error" in {
      forAll(Gen.oneOf(first, second)) { select =>
        val (probe, bools, ints) = testGraph[Boolean, Int]
        val error = new RuntimeException
        Given("either source errors")
        select((bools, ints)).sendError(error)
        Then("subscribes and error")
        probe.expectSubscriptionAndError(error)
      }
    }

    "emit even if pair is the same" in {
      val (probe, bools, ints) = testGraph[Boolean, Int]
      Given("request for two elements")
      probe.request(2)
      And("one element pushed on each source")
      bools.sendNext(true)
      ints.sendNext(1)
      And("once again the same element on one source")
      ints.sendNext(1)
      Then("emits a two equal pairs")
      probe.expectNext((true, 1))
      probe.expectNext((true, 1))
    }

    "emit combined elements in proper order" in {
      val (probe, firstDigits, secondDigits) = testGraph[Int, Int]
      Given("numbers up to 99 in tuples")
      for (firstDigit <- 0 to 9) {
        When(s"sending first digit $firstDigit")
        firstDigits.sendNext(firstDigit)
        for (secondDigit <- 0 to 9) {
          And(s"sending second digit $secondDigit")
          secondDigits.sendNext(secondDigit)
          probe.request(1)
          Then(s"should receive tuple ($firstDigit,$secondDigit)")
          probe.expectNext((firstDigit, secondDigit))
        }
      }
    }
  }

  /** Materializes `TestSource[A]` and `TestSource[B]` -> ZipLatest -> `TestSink`, returning (sink, sourceA, sourceB). */
  private def testGraph[A, B] =
    RunnableGraph
      .fromGraph(
        GraphDSL
          .create(TestSink.probe[(A, B)], TestSource.probe[A], TestSource.probe[B])(Tuple3.apply) {
            implicit b => (ts, as, bs) =>
              import GraphDSL.Implicits._
              val zipLatest = b.add(new ZipLatest[A, B]())
              as ~> zipLatest.in0
              bs ~> zipLatest.in1
              zipLatest.out ~> ts
              ClosedShape
          }
      )
      .run()
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment