Skip to content

Instantly share code, notes, and snippets.

@searler
Created June 21, 2015 19:47
Show Gist options
  • Save searler/124ba55566e4cbc3c0cd to your computer and use it in GitHub Desktop.
Save searler/124ba55566e4cbc3c0cd to your computer and use it in GitHub Desktop.
Akka reactive streams Bidi example from documentation
/**
* Extracted from the Bidi example at http://doc.akka.io/docs/akka-stream-and-http-experimental/1.0-RC3/scala/stream-graphs.html
*/
package bidi
import akka.util.ByteString
import akka.stream.scaladsl.BidiFlow
import akka.stream.scaladsl.Flow
import akka.stream.BidiShape
import java.nio.ByteOrder
import akka.stream.stage.Context
import akka.stream.stage.PushPullStage
import akka.stream.stage.SyncDirective
import akka.stream.scaladsl.Source
import akka.stream.scaladsl.Sink
import scala.concurrent.Await
import akka.stream.ActorFlowMaterializer
import akka.actor.ActorSystem
object ProtocolStacker extends App {

  /**
   * Messages exchanged by the ping/pong protocol.
   *
   * Sealed so that the pattern matches in the codec are checked for exhaustiveness.
   */
  sealed trait Message
  final case class Ping(id: Int) extends Message
  final case class Pong(id: Int) extends Message

  /**
   * Encodes a message as one tag byte (1 = [[Ping]], 2 = [[Pong]]) followed by
   * its id as a little-endian 4-byte Int, i.e. exactly 5 bytes.
   */
  def toBytes(msg: Message): ByteString = {
    // Kept local (not an object-level val): App's delayed initialization would
    // leave an object-level implicit null if this is called before main runs.
    implicit val order: ByteOrder = ByteOrder.LITTLE_ENDIAN
    msg match {
      case Ping(id) => ByteString.newBuilder.putByte(1).putInt(id).result()
      case Pong(id) => ByteString.newBuilder.putByte(2).putInt(id).result()
    }
  }

  /**
   * Decodes one message in the format produced by [[toBytes]].
   *
   * @param bytes a single complete frame
   * @return the decoded [[Ping]] or [[Pong]]
   * @throws RuntimeException if the frame is not exactly 5 bytes long or its
   *                          tag byte is neither 1 nor 2
   */
  def fromBytes(bytes: ByteString): Message = {
    implicit val order: ByteOrder = ByteOrder.LITTLE_ENDIAN
    // Validate the size up front; reading past the end of the iterator would
    // otherwise fail with an iterator exception instead of a parse error.
    val expectedLength = 1 + 4 // tag byte + Int id
    if (bytes.length != expectedLength)
      throw new RuntimeException(s"parse error: expected $expectedLength bytes got ${bytes.length}")
    val it = bytes.iterator
    it.getByte match {
      case 1 => Ping(it.getInt)
      case 2 => Pong(it.getInt)
      case other => throw new RuntimeException(s"parse error: expected 1|2 got $other")
    }
  }

  /**
   * Message <-> ByteString codec assembled explicitly with the BidiFlow builder:
   * the top (outbound) flow encodes with [[toBytes]], the bottom (inbound) flow
   * decodes with [[fromBytes]].
   */
  val codecVerbose = BidiFlow() { b =>
    // construct and add the top flow, going outbound
    val outbound = b.add(Flow[Message].map(toBytes))
    // construct and add the bottom flow, going inbound
    val inbound = b.add(Flow[ByteString].map(fromBytes))
    // fuse them together into a BidiShape
    BidiShape(outbound, inbound)
  }

  /** The same codec as [[codecVerbose]], built directly from the two functions. */
  val codec = BidiFlow(toBytes _, fromBytes _)

  /**
   * Length-prefix framing.
   *
   * Outbound, every ByteString is prefixed with its length as a little-endian
   * 4-byte Int. Inbound, the arbitrarily chunked byte stream is cut back into
   * exactly the frames that were written; a negative length header fails the stream.
   */
  val framing = BidiFlow() { b =>
    implicit val order: ByteOrder = ByteOrder.LITTLE_ENDIAN

    def addLengthHeader(bytes: ByteString): ByteString =
      ByteString.newBuilder.putInt(bytes.length).append(bytes).result()

    /** Reassembles length-prefixed frames from the incoming byte chunks. */
    class FrameParser extends PushPullStage[ByteString, ByteString] {
      // this holds the received but not yet parsed bytes
      var stash = ByteString.empty
      // this holds the current message length or -1 if at a boundary
      var needed = -1

      override def onPush(bytes: ByteString, ctx: Context[ByteString]) = {
        stash ++= bytes
        run(ctx)
      }

      override def onPull(ctx: Context[ByteString]) = run(ctx)

      override def onUpstreamFinish(ctx: Context[ByteString]) =
        if (stash.isEmpty) ctx.finish()
        else ctx.absorbTermination() // we still have bytes to emit

      @scala.annotation.tailrec
      private def run(ctx: Context[ByteString]): SyncDirective =
        if (needed == -1) {
          // At a frame boundary: decode the next length header once all 4 bytes arrived.
          if (stash.length < 4) pullOrFinish(ctx)
          else {
            val length = stash.iterator.getInt
            if (length < 0)
              // Only corrupt input produces this; take/drop with a negative count
              // would emit empty frames (and -1 would clash with the sentinel).
              ctx.fail(new RuntimeException(s"framing error: negative frame length $length"))
            else {
              needed = length
              stash = stash.drop(4)
              run(ctx) // cycle back to possibly already emit the next chunk
            }
          }
        } else if (stash.length < needed) {
          // we are in the middle of a message, need more bytes
          pullOrFinish(ctx)
        } else {
          // we have enough to emit at least one message, so do it
          val emit = stash.take(needed)
          stash = stash.drop(needed)
          needed = -1
          ctx.push(emit)
        }

      /*
       * After having called absorbTermination() we cannot pull any more, so if we need
       * more data we will just have to give up.
       */
      private def pullOrFinish(ctx: Context[ByteString]): SyncDirective =
        if (ctx.isFinishing) ctx.finish()
        else ctx.pull()
    }

    val outbound = b.add(Flow[ByteString].map(addLengthHeader))
    val inbound = b.add(Flow[ByteString].transform(() => new FrameParser))
    BidiShape(outbound, inbound)
  }

  //---------------------------------
  // Demo: stack the codec on top of the framing, plug the stack into its own
  // inverse and close the far end with a Ping -> Pong responder.
  implicit val system: ActorSystem = ActorSystem()
  implicit val materializer: ActorFlowMaterializer = ActorFlowMaterializer()
  import scala.concurrent.duration._

  val stack = codec.atop(framing)
  // test it by plugging it into its own inverse and closing the right end
  val pingpong = Flow[Message].collect { case Ping(id) => Pong(id) }
  val flow = stack.atop(stack.reversed).join(pingpong)
  val result = Source((0 to 9).map(Ping)).via(flow).grouped(20).runWith(Sink.head)

  // Always shut the ActorSystem down -- also when the await times out or the
  // stream fails -- otherwise the demo process does not exit. The timeout leaves
  // room for actor system start-up and first materialization on a cold JVM.
  try println(Await.result(result, 3.seconds)) // should ===((0 to 9).map(Pong))
  finally system.shutdown()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment