Skip to content

Instantly share code, notes, and snippets.

@tg44
Created April 29, 2021 10:17
Show Gist options
  • Save tg44/aa1d279b247d74e0ebca1489dc643410 to your computer and use it in GitHub Desktop.
Save tg44/aa1d279b247d74e0ebca1489dc643410 to your computer and use it in GitHub Desktop.
AdaptiveQueueSource akka streams extension
package akka.streams
import akka.Done
import akka.stream.OverflowStrategies.{Backpressure, DropBuffer, DropHead, DropNew, DropTail, Fail}
import akka.stream.{
Attributes,
BufferOverflowException,
Outlet,
OverflowStrategy,
QueueOfferResult,
SourceShape,
StreamDetachedException,
}
import akka.stream.impl.Buffer
import akka.stream.impl.Stages.DefaultAttributes
import akka.stream.scaladsl.{Source, SourceQueueWithComplete}
import akka.stream.stage.{GraphStageLogic, GraphStageWithMaterializedValue, OutHandler, StageLogging}
import akka.streams.AdaptiveQueueSource.SourceQueueWithCompleteAndSize
import scala.concurrent.{Future, Promise}
/**
 * Factories and supporting types for the [[AdaptiveQueueSource]] stage.
 */
object AdaptiveQueueSource {

  /** Messages handed to the stage through its async callback. */
  sealed trait Input[+T]

  /** Request to enqueue `elem`; `promise` is completed with the outcome of the offer. */
  final case class Offer[+T](elem: T, promise: Promise[QueueOfferResult]) extends Input[T]

  /** Request to complete the stream. */
  case object Completion extends Input[Nothing]

  /** Request to fail the stream with `ex`. */
  final case class Failure(ex: Throwable) extends Input[Nothing]

  // Second argument passed to `Buffer.apply` by `queue`.
  // NOTE(review): presumably chosen large enough that Akka always selects its fixed-size
  // buffer implementation — confirm against the Akka version in use.
  private val BufferMaxThreshold = 1000000000

  /**
   * Creates a queue source whose buffered elements are emitted in priority order
   * (as defined by the implicit `Ordering[T]`) instead of insertion order.
   *
   * @param bufferSize       capacity of the backing [[FixedSizePriorityBuffer]]
   * @param overflowStrategy what to do when an element is offered while the buffer is full
   */
  def priorityQueue[T: Ordering](
      bufferSize: Int,
      overflowStrategy: OverflowStrategy,
  ): Source[T, SourceQueueWithCompleteAndSize[T]] =
    Source.fromGraph(
      new AdaptiveQueueSource[T](
        () => new FixedSizePriorityBuffer[T](bufferSize),
        overflowStrategy,
      ).withAttributes(DefaultAttributes.queueSource)
    )

  /**
   * Creates a queue source backed by Akka's own `Buffer` implementation.
   *
   * @param bufferSize       capacity requested from `Buffer.apply`
   * @param overflowStrategy what to do when an element is offered while the buffer is full
   */
  def queue[T](bufferSize: Int, overflowStrategy: OverflowStrategy): Source[T, SourceQueueWithCompleteAndSize[T]] =
    Source.fromGraph(
      new AdaptiveQueueSource[T](
        () => Buffer[T](bufferSize, BufferMaxThreshold),
        overflowStrategy,
      ).withAttributes(DefaultAttributes.queueSource)
    )

  /**
   * `Buffer` implementation on top of a mutable `PriorityQueue`.
   *
   * The "head" is the highest-priority element (what `dequeue()` returns) and the
   * "tail" is the lowest-priority one. `capacity` is only reported through
   * `isFull` / `nonFull` / `remainingCapacity`; `enqueue` itself does not enforce it,
   * so callers are expected to check `isFull` first.
   *
   * @param capacity number of elements after which the buffer reports itself as full
   */
  class FixedSizePriorityBuffer[T: Ordering](val capacity: Int) extends Buffer[T] {
    private val buffer = collection.mutable.PriorityQueue.empty[T]

    // Works on a clone so that rendering the buffer does not drain it.
    override def toString = s"PriorityBuffer($capacity)(${buffer.clone.dequeueAll})"

    def used: Int = buffer.size
    def isFull: Boolean = used >= capacity
    def nonFull: Boolean = used < capacity

    /** Free slots left, never negative even if the buffer was filled past `capacity`. */
    def remainingCapacity: Int = math.max(capacity - used, 0)

    def isEmpty: Boolean = used == 0
    def nonEmpty: Boolean = used != 0
    def enqueue(elem: T): Unit = buffer.enqueue(elem)
    def peek(): T = buffer.head
    def dequeue(): T = buffer.dequeue()
    def clear(): Unit = buffer.clear()

    /** Removes the highest-priority element; does nothing when the buffer is empty. */
    def dropHead(): Unit =
      if (buffer.nonEmpty) {
        buffer.dequeue()
        ()
      }

    /**
     * Removes the lowest-priority element; does nothing when the buffer is empty.
     *
     * `PriorityQueue` has no in-place "remove minimum", and `dropRight` only returns a
     * new collection without touching the queue, so the queue is drained in priority
     * order and everything except the last (lowest-priority) element is put back.
     */
    def dropTail(): Unit =
      if (buffer.nonEmpty) {
        val byPriority: Seq[T] = buffer.dequeueAll
        buffer ++= byPriority.init
      }
  }

  /** A `SourceQueueWithComplete` that additionally exposes the state of the underlying buffer. */
  trait SourceQueueWithCompleteAndSize[T] extends SourceQueueWithComplete[T] {

    /** Number of elements currently buffered. */
    def used: Int

    /** Capacity reported by the underlying buffer. */
    def capacity: Int

    /** Whether the underlying buffer reports itself as full. */
    def isFull: Boolean

    /** Whether the underlying buffer holds no elements. */
    def isEmpty: Boolean
  }
}
/**
 * Source stage fed through a materialized queue, with a pluggable buffer implementation.
 *
 * Offered elements are stored in the `Buffer` produced by `queueCreator` and pushed
 * downstream on demand. When an element is offered while the buffer is full,
 * `overflowStrategy` decides what happens.
 *
 * @param queueCreator     factory invoked once per materialization to create the buffer
 * @param overflowStrategy strategy applied when an offer arrives and the buffer is full
 */
final class AdaptiveQueueSource[T](queueCreator: () => Buffer[T], overflowStrategy: OverflowStrategy)
    extends GraphStageWithMaterializedValue[SourceShape[T], SourceQueueWithCompleteAndSize[T]] {
  import AdaptiveQueueSource._

  val out = Outlet[T]("queueSource.out")
  override val shape: SourceShape[T] = SourceShape.of(out)

  /**
   * Builds the stage logic; the same object is returned as the materialized queue handle.
   */
  override def createLogicAndMaterializedValue(
      inheritedAttributes: Attributes
  ): (GraphStageLogic, SourceQueueWithCompleteAndSize[T]) = {
    // Backs `watchCompletion()`.
    val completion = Promise[Done]()

    val stageLogic =
      new GraphStageLogic(shape) with OutHandler with SourceQueueWithCompleteAndSize[T] with StageLogging {
        override protected def logSource: Class[_] = classOf[AdaptiveQueueSource[_]]

        val buffer: Buffer[T] = queueCreator()
        // At most one offer can wait for free space (Backpressure strategy only).
        var pendingOffer: Option[Offer[T]] = None
        // Set once completion was requested while elements were still buffered.
        var terminating = false

        override def postStop(): Unit = {
          val exception = new StreamDetachedException()
          // A backpressured offer that is still waiting would otherwise never be resolved
          // when the stage stops through `fail` or abrupt termination.
          pendingOffer.foreach(_.promise.tryFailure(exception))
          pendingOffer = None
          completion.tryFailure(exception)
        }

        private def enqueueAndSuccess(offer: Offer[T]): Unit = {
          buffer.enqueue(offer.elem)
          offer.promise.success(QueueOfferResult.Enqueued)
        }

        /** Stores the offered element, or applies the overflow strategy if the buffer is full. */
        private def bufferElem(offer: Offer[T]): Unit =
          if (!buffer.isFull) {
            enqueueAndSuccess(offer)
          } else {
            overflowStrategy match {
              case s: DropHead =>
                log.log(
                  s.logLevel,
                  "Dropping the head element because buffer is full and overflowStrategy is: [DropHead]",
                )
                buffer.dropHead()
                enqueueAndSuccess(offer)
              case s: DropTail =>
                log.log(
                  s.logLevel,
                  "Dropping the tail element because buffer is full and overflowStrategy is: [DropTail]",
                )
                buffer.dropTail()
                enqueueAndSuccess(offer)
              case s: DropBuffer =>
                log.log(
                  s.logLevel,
                  "Dropping all the buffered elements because buffer is full and overflowStrategy is: [DropBuffer]",
                )
                buffer.clear()
                enqueueAndSuccess(offer)
              case s: DropNew =>
                log.log(
                  s.logLevel,
                  "Dropping the new element because buffer is full and overflowStrategy is: [DropNew]",
                )
                offer.promise.success(QueueOfferResult.Dropped)
              case s: Fail =>
                log.log(s.logLevel, "Failing because buffer is full and overflowStrategy is: [Fail]")
                val bufferOverflowException =
                  BufferOverflowException(s"Buffer overflow (max capacity was: ${buffer.capacity})!")
                offer.promise.success(QueueOfferResult.Failure(bufferOverflowException))
                completion.failure(bufferOverflowException)
                failStage(bufferOverflowException)
              case s: Backpressure =>
                log.log(s.logLevel, "Backpressuring because buffer is full and overflowStrategy is: [Backpressure]")
                pendingOffer match {
                  case Some(_) =>
                    offer.promise.failure(
                      new IllegalStateException(
                        "You have to wait for the previous offer to be resolved to send another request"
                      )
                    )
                  case None =>
                    pendingOffer = Some(offer)
                }
            }
          }

        private val callback = getAsyncCallback[Input[T]] {
          case Offer(_, promise) if terminating =>
            // Completion was already requested: new elements are no longer accepted.
            promise.success(QueueOfferResult.Dropped)
          case offer @ Offer(_, _) =>
            bufferElem(offer)
            // `bufferElem` may not have stored anything (element dropped or parked as a
            // pending offer), so only push when there really is a buffered element.
            if (isAvailable(out) && buffer.nonEmpty) push(out, buffer.dequeue())
          case Completion =>
            // Let buffered / pending elements drain first; `onPull` finishes the stage.
            if (buffer.nonEmpty || pendingOffer.nonEmpty) terminating = true
            else {
              completion.success(Done)
              completeStage()
            }
          case Failure(ex) =>
            completion.failure(ex)
            failStage(ex)
        }

        setHandler(out, this)

        override def onDownstreamFinish(): Unit = {
          pendingOffer.foreach(_.promise.success(QueueOfferResult.QueueClosed))
          pendingOffer = None
          completion.success(Done)
          completeStage()
        }

        override def onPull(): Unit =
          if (buffer.nonEmpty) {
            push(out, buffer.dequeue())
            // The push freed a slot, so a backpressured offer can now be accepted.
            pendingOffer.foreach(enqueueAndSuccess)
            pendingOffer = None
            if (terminating && buffer.isEmpty) {
              completion.success(Done)
              completeStage()
            }
          }

        override def watchCompletion(): Future[Done] = completion.future

        override def offer(element: T): Future[QueueOfferResult] = {
          val p = Promise[QueueOfferResult]()
          callback
            .invokeWithFeedback(Offer(element, p))
            .onComplete {
              case scala.util.Success(_) => // the stage resolves `p` itself
              case scala.util.Failure(e) => p.tryFailure(e)
            }(akka.dispatch.ExecutionContexts.sameThreadExecutionContext)
          p.future
        }

        override def complete(): Unit = callback.invoke(Completion)
        override def fail(ex: Throwable): Unit = callback.invoke(Failure(ex))

        // NOTE(review): these read the buffer directly rather than going through the async
        // callback; if they are called from outside the stage the values may be stale —
        // confirm that is acceptable for callers.
        override def used: Int = buffer.used
        override def capacity: Int = buffer.capacity
        override def isFull: Boolean = buffer.isFull
        override def isEmpty: Boolean = buffer.isEmpty
      }

    (stageLogic, stageLogic)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment