Skip to content

Instantly share code, notes, and snippets.

@Timshel
Last active January 24, 2019 20:29
Show Gist options
  • Save Timshel/5d3c6e56734043525e9a64acc03a4806 to your computer and use it in GitHub Desktop.
Save Timshel/5d3c6e56734043525e9a64acc03a4806 to your computer and use it in GitHub Desktop.
RateLimiter (limits calls per time window and the number of concurrent calls)
import akka.stream.{ActorMaterializer, OverflowStrategy}
import akka.stream.scaladsl.{Keep, Sink, Source}
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.concurrent.duration.FiniteDuration
/**
 * Runs asynchronous calls through a sliding-window throttle with bounded parallelism.
 *
 * Calls are queued in an Akka Streams `Source.actorRef`, released by [[SlidingThrottle]]
 * (at most `limit` per `time` window) and executed by `mapAsync`, so at most `parallelism`
 * of them are in flight at once.
 *
 * A call's `Future` is only created when the throttle releases it (the argument of
 * `enqueue` is by-name).
 *
 * @param limit       maximum number of calls started within any window of length `time`
 * @param time        length of the sliding window
 * @param parallelism maximum number of call futures running concurrently
 * @param bufferSize  number of calls that may wait in the queue. When it overflows, the OLDEST
 *                    queued call is dropped (`OverflowStrategy.dropHead`) and the future
 *                    returned for it never completes.
 */
class RateLimiter(
  val limit: Int,
  val time: FiniteDuration,
  val parallelism: Int,
  val bufferSize: Int
)(
  implicit
  materializer: ActorMaterializer
) {
  import scala.util.control.NonFatal

  /**
   * `input`: actor that accepts queued calls.
   * `out`: materialized value of `Sink.ignore`, completed when the stream ends.
   */
  val (input, out) = Source.actorRef[() => Future[_]](bufferSize, OverflowStrategy.dropHead)
    .via(new SlidingThrottle(limit, time))
    .mapAsync(parallelism) { call => call() }
    .toMat(Sink.ignore)(Keep.both)
    .run()

  /**
   * Queues `call` for throttled execution.
   *
   * @param call evaluated only once the throttle releases it
   * @return a future completed with the outcome of `call`
   *         (never completed if the call is dropped on buffer overflow)
   */
  def enqueue[T](call: => Future[T]): Future[T] = {
    val promise = Promise[T]
    val wrapped: () => Future[_] = { () =>
      // A by-name argument may throw before producing a Future; turn that into a failed future
      // so the caller's promise is still completed.
      val result: Future[T] =
        try call
        catch { case NonFatal(e) => Future.failed[T](e) }
      promise.completeWith(result)
      // mapAsync fails the whole stream on a failed future, which would silently stall every
      // later call. The caller already observes the failure through `promise`, so the stream
      // only needs to know that the call finished.
      result.recover { case NonFatal(_) => () }(materializer.executionContext)
    }
    input ! wrapped
    promise.future
  }
}
/**
 * A [[RateLimiter]] whose returned futures are additionally bounded by `timeout`.
 *
 * The bound is applied by `FutureHelpers.RichFuture.withTimeout`, which is defined outside this
 * file. The accompanying spec expects a `java.util.concurrent.TimeoutException` when the call
 * does not finish in time (TODO confirm against FutureHelpers).
 *
 * NOTE(review): the timeout only affects the returned future. The call itself stays queued and
 * still runs once the throttle releases it.
 *
 * NOTE(review): `bufferSize` is an implicit `Int`, so any implicit `Int` in the caller's scope
 * would be picked up. It defaults to [[RateLimiterWithTimeout.bufferSize]].
 *
 * @param limit       maximum number of calls started per `time` window
 * @param time        length of the sliding window
 * @param timeout     maximum time a caller waits for a result
 * @param parallelism maximum number of concurrently running calls
 */
class RateLimiterWithTimeout(
  limit: Int,
  time: FiniteDuration,
  timeout: FiniteDuration,
  parallelism: Int
)(
  implicit
  bufferSize: Int = RateLimiterWithTimeout.bufferSize(limit, time, timeout),
  ex: ExecutionContext,
  system: akka.actor.ActorSystem,
  materializer: ActorMaterializer
) extends RateLimiter(limit, time, parallelism, bufferSize)(materializer) {

  /** Queues `call` like [[RateLimiter.enqueue]], then bounds the resulting future by `timeout`. */
  override def enqueue[T](call: => Future[T]): Future[T] = {
    import FutureHelpers.RichFuture
    val queued: Future[T] = super.enqueue(call)
    queued.withTimeout(timeout)
  }
}
object RateLimiterWithTimeout {

  /**
   * Default queue size for a [[RateLimiterWithTimeout]].
   *
   * It is 90% of the number of calls the limiter can start during one `timeout` period, that is
   * `(timeout / time) * limit * 0.9`, truncated. It is never smaller than `limit`. Because the
   * queue drops its oldest element on overflow, this keeps roughly only calls that can still
   * start before their caller times out.
   *
   * Example: limit = 100, time = 1.second, timeout = 1.minute gives 60 * 100 * 0.9 = 5400.
   */
  def bufferSize(limit: Int, time: FiniteDuration, timeout: FiniteDuration): Int = {
    val windowsPerTimeout: Double = timeout / time
    val capacity: Int = (windowsPerTimeout * limit * 0.9d).toInt
    if (capacity > limit) capacity else limit
  }
}
import java.util.concurrent.TimeoutException
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpec}
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.time.SpanSugar._
import scala.concurrent.{Await, Future}
/** Behavioural tests for [[RateLimiter]] and [[RateLimiterWithTimeout]]; several rely on wall-clock timing. */
class RateLimiterSpec extends WordSpec with Matchers with ScalaFutures with BeforeAndAfterAll {
  import scala.concurrent.ExecutionContext.Implicits.global
  import scala.concurrent.blocking

  implicit val system = akka.actor.ActorSystem()
  implicit val materializer = akka.stream.ActorMaterializer()

  /** Converts a ScalaTest `Span` into the `FiniteDuration` the rate limiters take. */
  def fd(span: org.scalatest.time.Span): scala.concurrent.duration.FiniteDuration = {
    new scala.concurrent.duration.FiniteDuration(span.length, span.unit)
  }

  /**
   * Asserts that every pair of consecutive timestamps (ms) is less than `delay` apart.
   * The pairs are checked eagerly with `foreach`: an `Iterator.map` is lazy and would never run
   * the assertions.
   */
  private def assertConsecutiveWithin(timestamps: Seq[Long], delay: Long): Unit =
    timestamps.zip(timestamps.drop(1)).foreach { case (a, b) =>
      assert(b - a < delay)
    }

  "RateLimiter" should {
    "enqueue" in {
      val queue = new RateLimiter(10, fd(1.second), 20, 100)
      val res = queue.enqueue(Future(12))
      assert(res.isReadyWithin(200.millis))
      assert(res.futureValue === 12)
    }
    "limit" in {
      val queue = new RateLimiter(1, fd(2.seconds), 2, 100)
      (0 until 1).foreach { _ => queue.enqueue(Future { 1 }) }
      val res = queue.enqueue(Future { 2 })
      // The window is full, so the second call must not start within the first second.
      val _ = intercept[TimeoutException] { Await.result(res, 1.second) }
      assert(res.isReadyWithin(2.seconds))
      assert(res.futureValue === 2)
    }
    "limit burst" in {
      val queue = new RateLimiter(10, fd(2.seconds), 20, 100)
      val start = System.currentTimeMillis
      val delay = 20
      val values = Future.sequence(
        (0 until 3 * queue.limit).map { i => queue.enqueue(Future { i -> (System.currentTimeMillis - start) }) }
      )
      assert(values.isReadyWithin(queue.time * 3))
      values.futureValue.sortBy(_._1).grouped(queue.limit).zipWithIndex.foreach { case (grouped, index) =>
        // Each batch of `limit` calls starts one window after the previous one...
        assert(grouped.head._2 < index * queue.time.toMillis + (index + 1) * delay)
        // ...and the calls within a batch start together.
        assertConsecutiveWithin(grouped.map(_._2), delay)
      }
    }
    "limit parallelism" in {
      val queue = new RateLimiter(10, fd(3.seconds), 5, 100)
      val start = System.currentTimeMillis
      val sleep = 1.second
      val delay = 20
      val values = Future.sequence(
        (0 until 3 * queue.limit).map { i =>
          queue.enqueue(Future {
            // `blocking` lets the global pool add threads, so `parallelism` calls really sleep
            // concurrently regardless of the machine's core count.
            blocking(Thread.sleep(sleep.toMillis))
            i -> (System.currentTimeMillis - start)
          })
        }
      )
      assert(values.isReadyWithin(queue.time * 3))
      values.futureValue.sortBy(_._1).grouped(queue.limit).zipWithIndex.foreach { case (groupedL, indexL) =>
        groupedL.grouped(queue.parallelism).zipWithIndex.foreach { case (groupedP, indexP) =>
          val limit = indexL * queue.time.toMillis + (indexP + 1) * sleep.toMillis + (indexL + indexP + 1) * delay
          assert(groupedP.head._2 < limit)
          assertConsecutiveWithin(groupedP.map(_._2), delay)
        }
      }
    }
  }

  "RateLimiterWithTimeout" should {
    "timeout" in {
      val queue = new RateLimiterWithTimeout(10, fd(1.minute), fd(1.second), 100)
      val res = queue.enqueue(Future { Thread.sleep(1.minute.toMillis) })
      val _ = intercept[TimeoutException] {
        Await.result(res, 1.hour)
      }
    }
  }

  override def afterAll(): Unit = {
    materializer.shutdown()
    // Wait explicitly: the default `futureValue` patience (150 ms) can be shorter than actor
    // system termination.
    Await.result(system.terminate(), fd(10.seconds))
    ()
  }
}
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
import akka.stream.stage._
import akka.stream._
import scala.concurrent.duration.{ FiniteDuration, _ }
/**
 * A throttle backed by a sliding window: at most `max` elements are pushed downstream during
 * any interval of length `per`.
 *
 * The `System.nanoTime` of every push is recorded. When an element arrives and `max` pushes
 * already happened within the last `per`, the element is held back. A timer is then scheduled
 * for the moment the oldest recorded push leaves the window, and the element is re-checked.
 * At most one element is held back at a time, because upstream is only pulled when downstream
 * pulls.
 *
 * Inspired by an external source whose link was lost. The layout presumably follows Akka's
 * built-in `Throttle` stage (TODO confirm / restore link).
 *
 * NOTE(review): `SimpleLinearGraphStage` comes from `akka.stream.impl`, Akka's internal API.
 *
 * @param max maximum number of elements per window; must be > 0
 * @param per window length; must be > 0 and at least `max` nanoseconds
 * @tparam T element type, passed through unchanged
 */
class SlidingThrottle[T](max: Int, per: FiniteDuration) extends SimpleLinearGraphStage[T] {
  require(max > 0, "max must be > 0")
  require(per.toNanos > 0, "per time must be > 0")
  require(per.toNanos >= max, "Rates larger than 1 unit / nanosecond are not supported")

  private val nanosPer = per.toNanos
  private val timerName: String = "ThrottleTimer"

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) {
    // Set when upstream finishes while an element is held back; the stage completes right after
    // that element is pushed.
    var willStop = false
    // nanoTime of each push still inside the window, oldest first.
    var emittedTimes = scala.collection.immutable.Queue.empty[Long]
    // Element waiting for the timer; only meaningful while the timer is active.
    var currentElement: T = _

    /** Pushes `elem` downstream, records the push time and completes if a stop is pending. */
    def pushThenLog(elem: T): Unit = {
      push(out, elem)
      emittedTimes = emittedTimes :+ System.nanoTime
      if (willStop) completeStage()
    }

    /** Holds `elem` back and re-checks it after `nanos` nanoseconds. */
    def schedule(elem: T, nanos: Long): Unit = {
      currentElement = elem
      scheduleOnce(timerName, nanos.nanos)
    }

    /** Pushes `elem` if the window has room, otherwise waits for the oldest push to expire. */
    def receive(elem: T): Unit = {
      val now = System.nanoTime
      emittedTimes = emittedTimes.dropWhile { t => t + nanosPer < now }
      if (emittedTimes.length < max) pushThenLog(elem)
      // The head survived the eviction above, so head + nanosPer >= now and the delay is
      // never negative. Reusing `now` keeps eviction and delay on the same clock reading.
      else schedule(elem, emittedTimes.head + nanosPer - now)
    }

    // This scope is here just to not retain an extra reference to the handler below.
    // We can't put this code into preRestart() because setHandler() must be called before that.
    {
      val handler = new InHandler with OutHandler {
        // With an element held back (timer active), defer completion until it has been pushed.
        override def onUpstreamFinish(): Unit =
          if (isAvailable(out) && isTimerActive(timerName)) willStop = true
          else completeStage()
        override def onPush(): Unit = receive(grab(in))
        override def onPull(): Unit = pull(in)
      }
      setHandler(in, handler)
      setHandler(out, handler)
      // After this point, we no longer need the `handler` so it can just fall out of scope.
    }

    override protected def onTimer(key: Any): Unit = {
      val elem = currentElement
      // Drop the reference so the held element is not retained after it is handed on.
      currentElement = null.asInstanceOf[T]
      receive(elem)
    }
  }

  override def toString = "Throttle"
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment