Skip to content

Instantly share code, notes, and snippets.

@randomstatistic
Created June 15, 2016 17:38
Show Gist options
  • Save randomstatistic/6dc13fc80ebf30b97c09ae52002895b8 to your computer and use it in GitHub Desktop.
Save randomstatistic/6dc13fc80ebf30b97c09ae52002895b8 to your computer and use it in GitHub Desktop.
Leaky bucket implementation in Scala
import java.util.concurrent.locks.ReentrantLock
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.concurrent.duration._
/**
 * Blocking rate limiter that hands out tokens at a fixed rate.
 *
 * One token is added every `dripEvery`, and at most `maxSize` unused tokens are retained (drips
 * arriving while the bucket is full are discarded). Despite the name, this is the token-bucket
 * scheme: bursts of up to `maxSize` tokens are served immediately, after which callers are
 * throttled to one token per `dripEvery`. The bucket starts empty.
 *
 * All mutable state is guarded by a single lock. A caller waiting for tokens sleeps while holding
 * that lock, so concurrent callers are served one request at a time.
 *
 * @param dripEvery interval at which a single token is added; must be positive
 * @param maxSize   maximum number of stored tokens; must be > 0
 */
class LeakyBucket(dripEvery: FiniteDuration, maxSize: Int) {
  require(maxSize > 0, "A bucket must have a size > 0")
  // A zero interval would otherwise only surface later as a division by zero in refillBucket.
  require(dripEvery.toNanos > 0, "A bucket must drip at a positive interval")

  private val dripEveryNanos = dripEvery.toNanos

  private val lock = new ReentrantLock()

  /** Runs `f` while holding `lock`. The lock is reentrant, so nested calls are safe. */
  private def withLock[T](f: => T): T = {
    lock.lock()
    try f finally lock.unlock()
  }

  // Guarded by `lock`: number of tokens currently available.
  private var bucket = 0
  // Guarded by `lock`: System.nanoTime() of the most recent drip that has been credited.
  private var lastFilled = System.nanoTime()

  private def nextFill: Long = lastFilled + dripEveryNanos

  /**
   * Credits the whole drips that occurred since `lastFilled`, capping the bucket at `upTo`.
   * Partial progress toward the next drip is kept. Must be called with `lock` held.
   */
  private def refillBucket(upTo: Int = maxSize): Unit = {
    val now = System.nanoTime()
    if (now >= nextFill) {
      // Kept as a Long. With a tiny interval the drip count can exceed Int.MaxValue, and narrowing
      // it with `.toInt` would wrap negative and push `bucket` below zero.
      val drips = (now - lastFilled) / dripEveryNanos
      lastFilled += drips * dripEveryNanos
      bucket = math.min(upTo.toLong, bucket + drips).toInt
    }
  }

  /**
   * Sleeps until `num` drips have accrued after `lastFilled`. Returns immediately if they
   * already have. Must be called with `lock` held.
   */
  private def waitForRefills(num: Int): Unit = {
    val waitNanos = lastFilled + dripEveryNanos * num - System.nanoTime()
    if (waitNanos > 0) {
      // Use the (millis, nanos) overload: truncating to whole milliseconds woke up too early
      // and forced an extra refill/sleep round.
      Thread.sleep(waitNanos / 1000000L, (waitNanos % 1000000L).toInt)
    }
  }

  /** Empties the bucket, discarding stored tokens. Partial progress toward the next drip is kept. */
  def drain(): Unit = withLock {
    refillBucket()
    bucket = 0
  }

  // -- Sync apis --

  /**
   * Blocks the calling thread until `num` tokens have been taken from the bucket.
   *
   * Available tokens are taken immediately. The rest are collected as they drip in, so a request
   * larger than `maxSize` is satisfied over several refills. The lock is held throughout.
   *
   * @param num number of tokens to take; must be > 0
   * @throws IllegalArgumentException if `num <= 0`
   */
  def awaitToken(num: Int = 1): Unit = {
    require(num > 0)

    // Tail-recursive (compiled to a loop), so large requests relative to maxSize cannot
    // overflow the stack.
    @scala.annotation.tailrec
    def take(needed: Int): Unit = {
      refillBucket()
      if (bucket >= needed) {
        bucket -= needed
      } else {
        val stillNeeded = needed - bucket
        bucket = 0
        // The bucket never holds more than maxSize, so waiting for more drips is pointless.
        waitForRefills(stillNeeded min maxSize)
        take(stillNeeded)
      }
    }

    withLock {
      assert(bucket >= 0)
      take(num)
    }
  }

  /** Takes `num` tokens (blocking as needed via [[awaitToken]]), then evaluates `f`. */
  def rateLimited[T](num: Int = 1)(f: => T): T = {
    awaitToken(num)
    f
  }

  /** Lazily yields `0 until size`, taking `tokens` tokens before producing each element. */
  def iterator(size: Int, tokens: Int = 1): Iterator[Int] =
    (0 until size).iterator.map(i => rateLimited(tokens)(i))

  // -- Async apis --

  /**
   * Takes `num` tokens on `ec` and completes once they have been obtained.
   *
   * The wait occupies a thread of `ec`. It is wrapped in `blocking` so that execution contexts
   * which support it can compensate with extra threads.
   */
  def getToken(num: Int = 1)(implicit ec: ExecutionContext): Future[Unit] =
    Future(scala.concurrent.blocking(awaitToken(num)))(ec)

  /** Evaluates `f` on `ec` once `num` tokens have been taken. */
  def rateLimitedAsync[T](num: Int = 1)(f: => T)(implicit ec: ExecutionContext): Future[T] =
    getToken(num).map(_ => f)

  /**
   * Completes with `f`'s result, but only after `num` tokens have been taken.
   *
   * Note: `f` is taken by value, so it is not held back. Only continuations chained off the
   * returned future are rate limited.
   */
  def rateLimitedFuture[T](num: Int = 1)(f: Future[T])(implicit ec: ExecutionContext): Future[T] =
    getToken(num).flatMap(_ => f)
}
import java.util.concurrent.Executors
import org.scalatest.{FunSpec, Matchers}
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._
/**
 * Timing-based specs for [[LeakyBucket]]. Assertions compare wall-clock durations against the
 * expected drip schedule, so the tolerances assume a lightly loaded machine.
 */
class TestLeakyBucket extends FunSpec with Matchers with org.scalatest.BeforeAndAfterAll {
  // Dedicated pool for the async API tests. It is shut down in afterAll so its threads do not
  // outlive the suite.
  private val pool = Executors.newFixedThreadPool(10)
  implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(pool)

  override protected def afterAll(): Unit = {
    try pool.shutdown()
    finally super.afterAll()
  }

  /** Wall-clock duration recorded by each named [[time]] call. */
  val times = scala.collection.mutable.HashMap[String, FiniteDuration]()

  /** Evaluates `f`, recording its elapsed time under `name` even if `f` throws. */
  def time[T](name: String)(f: => T): T = {
    val t1 = System.nanoTime()
    try f
    finally times += (name -> (System.nanoTime() - t1).nanos)
  }

  /**
   * Builds a bucket and runs `block` against it.
   *
   * A positive `initialDelay` lets that much time's worth of tokens drip in first. Otherwise the
   * bucket is drained so it starts empty.
   */
  def bucketTest(bucketRate: FiniteDuration, bucketSize: Int, initialDelay: FiniteDuration)(
      block: LeakyBucket => Unit): Unit = {
    val bucket = new LeakyBucket(bucketRate, bucketSize)
    if (initialDelay > 0.seconds) Thread.sleep(initialDelay.toMillis) else bucket.drain()
    block(bucket)
  }

  describe("100ms rate") {
    it("should consume a half-full bucket quickly") {
      bucketTest(100.millis, 10, 500.millis) { bucket =>
        time("half-full bucket") {
          bucket.awaitToken(5)
        }
        times("half-full bucket").toMillis should be < 10L
      }
    }
    it("should consume a full bucket quickly") {
      bucketTest(100.millis, 10, 1100.millis) { bucket =>
        time("full bucket") {
          bucket.awaitToken(10)
        }
        times("full bucket").toMillis should be < 10L
      }
    }
    it("should consume from an empty bucket at the expected rate") {
      bucketTest(100.millis, 10, 0.millis) { bucket =>
        for (i <- 1 until 10) {
          val timerName = s"empty bucket run $i"
          time(timerName) {
            bucket.awaitToken()
          }
          times(timerName).toMillis should be(100L +- 15)
        }
      }
    }
    it("asking a full bucket for more than it has should not be quick") {
      bucketTest(100.millis, 10, 1100.millis) { bucket =>
        time("full bucket overflow") {
          bucket.awaitToken(11)
        }
        times("full bucket overflow").toMillis should be(100L +- 15)
      }
    }
  }

  describe("10ms rate") {
    it("should consume from an empty bucket at the expected rate") {
      bucketTest(10.millis, 100, 0.millis) { bucket =>
        time("throughput") {
          (0 until 100).foreach(_ => bucket.awaitToken())
        }
        times("throughput").toMillis should be(1000L +- 25)
      }
    }
    it("gets the expected rate with concurrent consumers") {
      bucketTest(10.millis, 100, 0.millis) { bucket =>
        time("concurrent throughput") {
          val futures = List.fill(100)(bucket.getToken(1))
          // Await.result (not Await.ready) so that a failed token request fails the test.
          Await.result(Future.sequence(futures), 2.seconds)
        }
        times("concurrent throughput").toMillis should be(1000L +- 25)
      }
    }
  }

  describe("rate limited actions") {
    it("should limit synchronous rate") {
      bucketTest(10.millis, 1, 0.millis) { bucket =>
        val results = time("rate limited loop") {
          (0 until 10).map { i =>
            bucket.rateLimited() {
              i // some thrilling computation
            }
          }
        }
        times("rate limited loop").toMillis should be(100L +- 5)
        results should contain theSameElementsInOrderAs Range(0, 10)
      }
    }
    it("should rate limit async blocks") {
      bucketTest(50.millis, 1, 0.millis) { bucket =>
        val f = bucket.rateLimitedAsync(1) {
          1
        }
        val result = time("future completion") {
          Await.result(f, 1.seconds)
        }
        times("future completion").toMillis should be(50L +- 5)
        result should be(1)
      }
    }
    it("should rate limit future completion") {
      bucketTest(10.millis, 1, 0.millis) { bucket =>
        val f = bucket.rateLimitedFuture(5)(Future.successful(1))
        val result = time("future completion") {
          Await.result(f, 1.seconds)
        }
        times("future completion").toMillis should be(50L +- 5)
        result should be(1)
      }
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment