@johanandren
Created July 14, 2017 10:05
Sample for a custom rate limiting directive for Akka HTTP
/**
 * Copyright (C) 2009-2017 Lightbend Inc. <http://www.lightbend.com>
 */
package http

import java.util.concurrent.atomic.AtomicInteger

import akka.actor.{Actor, ActorRef, ActorSystem, Props}
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.{HttpResponse, StatusCodes, Uri}
import akka.http.scaladsl.server.{Directive0, Rejection, RejectionHandler, Route}
import akka.stream.ActorMaterializer
import akka.http.scaladsl.server.Directives._
import akka.pattern.ask
import akka.util.Timeout

import scala.concurrent.duration._
import scala.util.{Failure, Success}

object RateLimit {

  object SlowActor {
    case object Ping
    case object Pong
  }

  class SlowActor extends Actor {
    import SlowActor._
    import context.dispatcher

    def receive = {
      case Ping =>
        // simulate something taking time to respond
        context.system.scheduler.scheduleOnce(10.seconds, sender(), Pong)
    }
  }

  case class PathBusyRejection(path: Uri.Path, max: Int) extends Rejection

  class Limiter(max: Int) {
    // needs to be a thread-safe counter since there can be concurrent requests
    val concurrentRequests = new AtomicInteger(0)

    val limitConcurrentRequests: Directive0 =
      extractRequest.flatMap { request =>
        if (concurrentRequests.incrementAndGet() > max) {
          // we need to decrease it again, and then reject the request
          // this means you can use a rejection handler somewhere else, for
          // example around the entire Route, turning all such rejections
          // into the same kind of actual HTTP response there
          concurrentRequests.decrementAndGet()
          reject(PathBusyRejection(request.uri.path, max))
        } else {
          mapResponse { response =>
            concurrentRequests.decrementAndGet()
            response
          }
        }
      }
  }

  def main(args: Array[String]): Unit = {
    // sample usage
    implicit val system = ActorSystem()
    implicit val materializer = ActorMaterializer()

    val slowActor = system.actorOf(Props[SlowActor])

    val rejectionHandler = RejectionHandler.newBuilder()
      .handle {
        case PathBusyRejection(path, max) =>
          complete((StatusCodes.EnhanceYourCalm, s"Max concurrent requests for $path reached, please try again later"))
      }.result()

    // needs to be created outside of the route tree or else
    // you get separate instances rather than sharing one
    val limiter = new Limiter(max = 2)

    val route =
      handleRejections(rejectionHandler) {
        path("max-2") {
          limiter.limitConcurrentRequests {
            implicit val timeout: Timeout = 20.seconds
            onSuccess(slowActor ? SlowActor.Ping) { _ =>
              complete("Done!")
            }
          }
        }
      }

    import system.dispatcher
    Http().bindAndHandle(route, "127.0.0.1", 8080).onComplete {
      case Success(_) => println("Listening for requests, call http://127.0.0.1:8080/max-2 to try it out")
      case Failure(ex) =>
        println("Failed to bind to 127.0.0.1:8080")
        ex.printStackTrace()
    }
  }
}
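
To see the limit kick in, one quick way is to fire a few requests at the endpoint concurrently and watch at least one of them come back as 420 Enhance Your Calm. A minimal client sketch (assuming the server above is running and Akka HTTP's client support is on the classpath; the object and actor system names are illustrative):

import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.HttpRequest
import akka.stream.ActorMaterializer

import scala.concurrent.Future

object TryItOut extends App {
  implicit val system = ActorSystem("client")
  implicit val materializer = ActorMaterializer()
  import system.dispatcher

  // with max = 2 and a 10 second response time, at least one of these
  // three concurrent requests should be rejected with 420
  val responses = Future.sequence(
    List.fill(3)(Http().singleRequest(HttpRequest(uri = "http://127.0.0.1:8080/max-2")))
  )
  responses.foreach(_.foreach { response =>
    println(response.status)
    response.discardEntityBytes()
  })
}
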
@davidrwood

davidrwood commented Oct 30, 2023

Nice! But it doesn't decrement the counter if an exception is thrown while processing the route.

This is what I came up with:

package org.your

import java.util.concurrent.atomic.AtomicInteger

import akka.http.scaladsl.model._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.{Directive1, Rejection, RejectionHandler, Route}

import scala.concurrent.ExecutionContext

/** A request limiter that will reject concurrently made requests if they exceed a maximum number.
  * Based on https://gist.github.com/johanandren/b87a9ed63b4c3e95432dc0497fd73fdb
  * @param max The maximum number of concurrent requests to allow.
  */
class RequestLimiter(max: Int) {

  private val concurrentRequests = new AtomicInteger()

  /** Wraps [[Route]]s that should be limited, handling the case that an exception is thrown during processing. */
  def limitConcurrentRequests(route: Route)
                             (implicit executionContext: ExecutionContext): Route =
    checkAndLimitRequests { originalRequest => requestContext =>
      val routeResultFuture = route(requestContext)
      routeResultFuture.onComplete(_ => concurrentRequests.decrementAndGet())
      routeResultFuture
    }

  /** Wraps an inner route that should be limited.
    * Note that this does not get called if an exception is thrown during processing.
    */
  private val checkAndLimitRequests: Directive1[HttpRequest] =
    extractRequest.flatMap { request =>
      if (concurrentRequests.incrementAndGet() > max) {
        concurrentRequests.decrementAndGet()
        reject(PathBusyRejection(request.uri.path, max))
      } else {
        provide(request)
      }
    }

}

/** Provides information to a client on why a path was rejected by [[RequestLimiter]]. */
case class PathBusyRejection(path: Uri.Path, maxConcurrentRequests: Int) extends Rejection

object RequestLimiter {

  /** Handles rejections from [[limitConcurrentRequests]] and transforms them into [[StatusCodes.TooManyRequests]]
    * HTTP statuses with a specific message that can be used by the client to detect this kind of rejection.
    */
  val rejectionHandler =
    RejectionHandler.newBuilder.handle {
      case PathBusyRejection(_, max) =>
        complete(StatusCodes.TooManyRequests -> s"Concurrent requests limit exceeded: $max")
    }.result

}
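
Wiring it into a route would look roughly like the following (a minimal sketch; the path name, limit and object name are illustrative, not part of the code above):

import akka.actor.ActorSystem
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.Route

object RequestLimiterExample {

  // create the limiter once, outside the route tree, so the counter is shared
  private val limiter = new RequestLimiter(max = 2)

  def route()(implicit system: ActorSystem): Route = {
    import system.dispatcher
    handleRejections(RequestLimiter.rejectionHandler) {
      path("limited") {
        limiter.limitConcurrentRequests(complete("Done!"))
      }
    }
  }
}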

@johanandren
Author

Good catch!

You can make it slightly less expensive by using the parasitic execution context for the onComplete, so the decrement runs on the thread that completes the request instead of being enqueued on a separate executor.

@davidrwood

Thanks! I've not come across the parasitic execution context.

Seems like it would be this:

import scala.concurrent.ExecutionContext.parasitic
...
routeResultFuture.onComplete(_ => concurrentRequests.decrementAndGet())(parasitic)
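
Putting that into the wrapper itself would look roughly like this (a sketch; it assumes Scala 2.13+, where ExecutionContext.parasitic exists, and the implicit ExecutionContext parameter is no longer needed):

import scala.concurrent.ExecutionContext

// same wrapper as above, but decrementing on the parasitic execution context
def limitConcurrentRequests(route: Route): Route =
  checkAndLimitRequests { _ => requestContext =>
    val routeResultFuture = route(requestContext)
    routeResultFuture.onComplete(_ => concurrentRequests.decrementAndGet())(ExecutionContext.parasitic)
    routeResultFuture
  }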

There's not enough traffic on this endpoint to worry about it, but for others it might be useful.

Thanks again!
