Skip to content

Instantly share code, notes, and snippets.

@julianhowarth
Created November 17, 2016 17:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save julianhowarth/7287a6e6eaf665dd79307aaff6164cd8 to your computer and use it in GitHub Desktop.
Save julianhowarth/7287a6e6eaf665dd79307aaff6164cd8 to your computer and use it in GitHub Desktop.
Enhance Akka LengthFieldFramingStage to collect additional data whilst frames are built
import java.nio.ByteOrder
import akka.NotUsed
import akka.stream.scaladsl.Flow
import akka.stream.scaladsl.Framing.FramingException
import akka.stream.{Attributes, FlowShape, Inlet, Outlet}
import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}
import akka.util.{ByteIterator, ByteString}
object FramingStage {
// Decodes `length` bytes from the iterator as a big-endian integer: every byte read
// shifts the running value left by one octet and is placed in the lowest octet.
// A non-positive `length` reads nothing and yields 0.
private final val bigEndianDecoder: (ByteIterator, Int) ⇒ Int = (bytes, length) ⇒
  (0 until length).foldLeft(0) { (acc, _) ⇒
    (acc << 8) | (bytes.next().toInt & 0xFF)
  }
// Decodes `length` bytes from the iterator as a little-endian integer. Each byte is
// inserted at the highest octet of the field after the running value has been shifted
// right by one octet, so the first byte read ends up in the lowest octet. The result
// is finally masked down to `length` octets.
private final val littleEndianDecoder: (ByteIterator, Int) ⇒ Int = (bytes, length) ⇒ {
  val topOctetShift = (length - 1) << 3
  val fieldMask = ((1L << (length << 3)) - 1).toInt
  val assembled = (0 until length).foldLeft(0) { (acc, _) ⇒
    (acc >>> 8) + ((bytes.next().toInt & 0xFF) << topOctetShift)
  }
  assembled & fieldMask
}
/**
 * Builds a Flow that re-chunks a stream of arbitrary `ByteString` pieces into complete frames, where every
 * frame carries a field stating its own length.
 *
 * This is the plain variant: it delegates to the generic `lengthField[In, Agg, Out]` with a throw-away
 * `Boolean` aggregate, so the emitted elements are just the frame bytes.
 *
 * The stream is failed with a truncated-frame error if upstream completes while a frame is still incomplete.
 *
 * @param fieldLength        Size in bytes of the length field
 * @param fieldOffset        Position of the length field, in bytes from the start of the frame (defaults to 0)
 * @param maximumFrameLength Largest frame accepted while decoding, header included (offset plus length field);
 *                           a larger frame fails the stream
 * @param byteOrder          The ''ByteOrder'' used to decode the length field (defaults to little endian)
 */
def lengthField(
  fieldLength: Int,
  fieldOffset: Int = 0,
  maximumFrameLength: Int,
  byteOrder: ByteOrder = ByteOrder.LITTLE_ENDIAN): Flow[ByteString, ByteString, NotUsed] = {
  // The element is its own payload.
  val payloadOf: ByteString ⇒ ByteString = chunk ⇒ chunk
  // No metadata is tracked: the aggregate is a constant placeholder that is dropped on completion.
  val placeholderSeed: ByteString ⇒ Boolean = _ ⇒ false
  val placeholderAggregate: (Boolean, ByteString) ⇒ Boolean = (_, _) ⇒ false
  val frameOnly: (Boolean, ByteString) ⇒ ByteString = (_, frame) ⇒ frame
  lengthField[ByteString, Boolean, ByteString](
    fieldLength, fieldOffset, maximumFrameLength, byteOrder,
    payloadOf, placeholderSeed, placeholderAggregate, frameOnly)
}
/**
 * Builds a Flow that turns a stream of messages, each carrying an unstructured chunk of bytes, into a stream of
 * complete frames, where every frame carries a field stating its own length. Whatever else travels with the
 * messages can be folded into an aggregate while a frame is being assembled, and that aggregate is combined with
 * the frame bytes when the frame is emitted. For example, if each incoming message pairs a receipt timestamp with
 * its data chunk, the decoded frame can be output together with the timestamp of its first chunk. In the same way
 * an IP address can be associated with each message.
 *
 * The stream is failed with a truncated-frame error if upstream completes while a frame is still incomplete.
 *
 * @param fieldLength        Size in bytes of the length field; must be 1, 2, 3 or 4
 * @param fieldOffset        Position of the length field, in bytes from the start of the frame
 * @param maximumFrameLength Largest frame accepted while decoding, header included (offset plus length field);
 *                           a larger frame fails the stream
 * @param byteOrder          The ''ByteOrder'' used to decode the length field
 * @param extractor          Obtains the ByteString payload of a message
 * @param seed               Creates the initial aggregate from the first message of a frame
 * @param aggregate          Folds a further message into the aggregate built so far
 * @param finish             Combines the aggregate with the bytes of the completed frame into the output element
 * @tparam In  type of the incoming elements
 * @tparam Agg type of the aggregate built while a frame is assembled
 * @tparam Out type of the outgoing elements
 */
def lengthField[In, Agg, Out](
  fieldLength: Int,
  fieldOffset: Int,
  maximumFrameLength: Int,
  byteOrder: ByteOrder,
  extractor: In ⇒ ByteString,
  seed: In ⇒ Agg,
  aggregate: (Agg, In) ⇒ Agg,
  finish: (Agg, ByteString) ⇒ Out): Flow[In, Out, NotUsed] = {
  // The decoders assemble the length into an Int, hence at most four octets.
  require(1 <= fieldLength && fieldLength <= 4, "Length field length must be 1, 2, 3 or 4.")
  val framingStage = new LengthFieldFramingStage[In, Agg, Out](
    fieldLength, fieldOffset, maximumFrameLength, byteOrder,
    extractor, seed, aggregate, finish)
  Flow[In]
    .via(framingStage)
    .named("lengthFieldFraming")
}
/**
 * This is an extended version of the Akka lib stage which allows additional
 * data to be tracked with the incoming data packets e.g. timestamps, connection
 * information etc.
 *
 * Incoming elements are reduced to bytes with `extractor` and appended to an internal buffer. While a
 * frame is being assembled an aggregate is maintained: `seed` creates it from the first element that
 * contributes to the frame and `aggregate` folds in every further element. When the buffer holds a
 * complete frame, `finish` combines the aggregate with the frame bytes and the result is emitted.
 *
 * If bytes are left in the buffer after a frame has been emitted, the aggregate for the next frame is
 * seeded from the element that delivered those bytes (the most recently received one).
 *
 * The stage fails with a `FramingException` when the decoded length is negative, when the frame would
 * exceed `maximumFrameLength`, or when upstream finishes in the middle of a frame.
 *
 * @param lengthFieldLength  size in bytes of the length field
 * @param lengthFieldOffset  offset in bytes of the length field from the beginning of the frame
 * @param maximumFrameLength maximum allowed frame size, header included
 * @param byteOrder          byte order used to decode the length field
 * @param extractor          extracts the ByteString data from an incoming element
 * @param seed               creates the initial aggregate for a frame from an element
 * @param aggregate          folds a further element into the current aggregate
 * @param finish             combines the aggregate and the completed frame into the outgoing element
 */
final class LengthFieldFramingStage[In, Agg, Out](
  val lengthFieldLength: Int,
  val lengthFieldOffset: Int,
  val maximumFrameLength: Int,
  val byteOrder: ByteOrder,
  val extractor: In ⇒ ByteString,
  val seed: In ⇒ Agg,
  val aggregate: (Agg, In) ⇒ Agg,
  val finish: (Agg, ByteString) ⇒ Out) extends GraphStage[FlowShape[In, Out]] {

  // Number of bytes that must be buffered before the length field can be decoded.
  private val minimumChunkSize = lengthFieldOffset + lengthFieldLength
  private val intDecoder = byteOrder match {
    case ByteOrder.BIG_ENDIAN ⇒ bigEndianDecoder
    case ByteOrder.LITTLE_ENDIAN ⇒ littleEndianDecoder
  }
  private val in = Inlet[In]("LengthFieldFramingStage.in")
  private val out = Outlet[Out]("LengthFieldFramingStage.out")
  override val shape: FlowShape[In, Out] = FlowShape(in, out)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) with InHandler with OutHandler {
      // Bytes received but not yet emitted as part of a frame.
      private var buffer = ByteString.empty
      // Total size of the frame being assembled; Int.MaxValue until its header has been decoded.
      private var frameSize = Int.MaxValue
      // Aggregate of the frame being assembled; None while no element has contributed to it.
      // An Option is used instead of a null sentinel so that a null aggregate value is not
      // mistaken for "not seeded yet".
      private var agg: Option[Agg] = None
      // Most recently grabbed element while its bytes may still be buffered. It is needed to seed
      // the next aggregate when a frame is completed from onPull / onUpstreamFinish, where no
      // freshly grabbed element exists.
      private var lastElement: Option[In] = None

      /**
       * Emits the completed frame and resets frameSize, buffer bookkeeping and the aggregate.
       */
      private def pushFrame(): Unit = {
        val emit = buffer.take(frameSize).compact
        buffer = buffer.drop(frameSize)
        frameSize = Int.MaxValue
        val completed = agg
        if (buffer.isEmpty) {
          agg = None
          // Nothing buffered refers to the element any more, so do not keep it alive.
          lastElement = None
        } else {
          // The leftover bytes were delivered by the last grabbed element.
          agg = lastElement.map(seed)
        }
        completed match {
          case Some(value) ⇒
            push(out, finish(value, emit))
            if (buffer.isEmpty && isClosed(in)) {
              completeStage()
            }
          case None ⇒
            // Only reachable if a frame completes without any element having been received.
            failStage(new IllegalStateException("Frame completed before any element was aggregated"))
        }
      }

      /**
       * Tries to push a frame downstream; if no complete frame is buffered, tries to pull upstream.
       */
      private def tryPushFrame(): Unit = {
        val buffSize = buffer.size
        if (buffSize >= frameSize) {
          pushFrame()
        }
        else if (buffSize >= minimumChunkSize) {
          val parsedLength = intDecoder(buffer.iterator.drop(lengthFieldOffset), lengthFieldLength)
          // Computed as Long so that a large decoded length cannot wrap around to a negative Int.
          val decodedFrameSize = parsedLength.toLong + minimumChunkSize
          if (parsedLength < 0) {
            // A 4-byte field with the high bit set decodes to a negative Int; taking/dropping a
            // negative number of bytes would emit empty frames without consuming the buffer.
            failStage(new FramingException(s"Decoded frame header reported negative size $parsedLength"))
          }
          else if (decodedFrameSize > maximumFrameLength) {
            failStage(new FramingException(s"Maximum allowed frame size is $maximumFrameLength but decoded frame header reported size $decodedFrameSize"))
          }
          else {
            frameSize = decodedFrameSize.toInt
            if (buffSize >= frameSize) pushFrame()
            else tryPull()
          }
        }
        else tryPull()
      }

      /**
       * Requests more data, or fails the stage if upstream has already finished mid-frame.
       */
      private def tryPull(): Unit = {
        if (isClosed(in)) {
          failStage(new FramingException("Stream finished but there was a truncated final frame in the buffer"))
        }
        else pull(in)
      }

      override def onPush(): Unit = {
        val el = grab(in)
        lastElement = Some(el)
        buffer ++= extractor(el)
        agg = Some(agg.fold(seed(el))(aggregate(_, el)))
        tryPushFrame()
      }

      override def onPull(): Unit = tryPushFrame()

      override def onUpstreamFinish(): Unit = {
        if (buffer.isEmpty) {
          completeStage()
        }
        else if (isAvailable(out)) {
          tryPushFrame()
        } // else swallow the termination and wait for pull
      }

      setHandlers(in, out, this)
    }
}
}
// Usage example: framing a stream of (Instant, ByteString) pairs. The aggregate is the Instant of the
// first packet of each frame, and every emitted element pairs that Instant with the frame bytes.
Flow[(Instant, ByteString)]
  .via(
    FramingStage.lengthField[(Instant, ByteString), Instant, (Instant, ByteString)](
      fieldLength = 4,
      fieldOffset = 0,
      maximumFrameLength = 50000,
      byteOrder = ByteOrder.LITTLE_ENDIAN,
      extractor = packet ⇒ packet._2,
      seed = packet ⇒ packet._1,
      aggregate = (firstSeen, _) ⇒ firstSeen,
      finish = (firstSeen, frame) ⇒ (firstSeen, frame)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment