Skip to content

Instantly share code, notes, and snippets.

@hcoa
Created May 12, 2021 13:36
Show Gist options
  • Save hcoa/7c51f07ee88abaa3d74d93e4926d9826 to your computer and use it in GitHub Desktop.
Save hcoa/7c51f07ee88abaa3d74d93e4926d9826 to your computer and use it in GitHub Desktop.
Scala lock-free rate limiter example
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
import scala.concurrent.duration._
/** Source of the current time, abstracted so tests can drive it by hand.
  *
  * Readings are treated as nanosecond timestamps by [[TokenBucket]] (they are
  * compared against its `lastRefillNanoTime`).
  */
trait Clock { def now: Long }
/** Production [[Clock]] backed by `System.nanoTime()`, which is intended for
  * measuring elapsed time rather than reporting wall-clock time.
  */
class SystemClock extends Clock {
  override def now: Long = System.nanoTime()
}
/** Token-bucket rate limiter whose whole state lives in one [[AtomicReference]].
  *
  * Each consumption attempt copies the currently published [[TokenBucket]],
  * refills the copy up to `clock.now`, debits it, and publishes it with
  * `compareAndSet`. If another thread published first, the CAS fails and the
  * attempt is retried against the fresh state. This class never mutates a
  * bucket after it has been stored in the reference.
  *
  * @param bucketConfig capacity and per-token refill time of the bucket
  * @param clock        time source; its readings are compared with the bucket's
  *                     `lastRefillNanoTime`, so they are expected in nanoseconds
  */
class LockFreeRateLimiter(bucketConfig: BucketConfig, clock: Clock) {
  private val bucketStateRef: AtomicReference[TokenBucket] =
    new AtomicReference(TokenBucket(bucketConfig, clock.now))

  /** Atomically takes `tokens` tokens if at least that many are available.
    *
    * @param tokens number of tokens to take; must be non-negative
    * @return `true` if the tokens were taken, `false` if the bucket (after being
    *         refilled up to `clock.now`) holds fewer than `tokens`
    * @throws IllegalArgumentException if `tokens` is negative
    */
  def tryConsume(tokens: Int): Boolean = {
    // A negative request would *add* tokens and could push the bucket past capacity.
    require(tokens >= 0, s"tokens must be non-negative, got $tokens")

    @scala.annotation.tailrec
    def attempt(): Boolean = {
      val current = bucketStateRef.get()
      // Work on a private copy: `current` may be read concurrently by other threads.
      val next = current.copy()
      next.refill(bucketConfig, clock.now)
      if (next.availableTokens < tokens) false
      else {
        next.availableTokens -= tokens
        // Losing the race means another thread published first: retry on its state.
        if (bucketStateRef.compareAndSet(current, next)) true else attempt()
      }
    }

    attempt()
  }

  /** Token count of the most recently published bucket state.
    *
    * No refill is applied here, so time elapsed since the last successful
    * `tryConsume` is not reflected; a subsequent `tryConsume` may find more tokens.
    */
  def getAvailableTokens: Long = bucketStateRef.get().availableTokens
}
/** Static configuration of a token bucket.
  *
  * @param capacity   maximum number of tokens the bucket can hold
  * @param refillRate nanoseconds needed to earn one token; must be positive
  * @param period     nominal window; `BucketConfig.from` derives `refillRate`
  *                   as `period.toNanos / capacity`
  */
case class BucketConfig(
    capacity: Long,
    refillRate: Long,
    period: FiniteDuration
) {
  // TokenBucket.refill divides elapsed nanoseconds by refillRate: zero fails with
  // division by zero on the first refill, and a negative rate breaks the arithmetic.
  require(refillRate > 0, s"refillRate must be positive, got $refillRate")
}

object BucketConfig {

  /** Builds a config that replenishes `capacity` tokens per `period`.
    *
    * The per-token rate uses integer division (`period.toNanos / capacity`), so
    * the effective refill can be marginally faster than requested when
    * `period.toNanos` is not a multiple of `capacity`.
    *
    * @throws IllegalArgumentException if `capacity` is not positive, or if
    *         `period` is shorter than `capacity` nanoseconds (the per-token rate
    *         would otherwise round down to zero)
    */
  def from(capacity: Long, period: FiniteDuration): BucketConfig = {
    require(capacity > 0, s"capacity must be positive, got $capacity")
    require(
      period.toNanos >= capacity,
      s"period must be at least $capacity ns to refill $capacity tokens, got $period"
    )
    BucketConfig(capacity, period.toNanos / capacity, period)
  }
}
/** Mutable bucket snapshot: tokens currently held plus the clock reading up to
  * which tokens have been credited.
  *
  * [[LockFreeRateLimiter]] mutates only a fresh `copy()` of the published
  * instance and then publishes that copy.
  */
class TokenBucket private (
    var availableTokens: Long,
    var lastRefillNanoTime: Long
) {

  /** Credits the whole tokens earned between `lastRefillNanoTime` and `nanoTime`.
    *
    * The token count is capped at `bucketConfig.capacity`. The credit timestamp
    * advances only by whole multiples of `refillRate`, so elapsed time worth
    * less than one token carries over to the next call.
    */
  def refill(bucketConfig: BucketConfig, nanoTime: Long): Unit = {
    val perToken = bucketConfig.refillRate
    val elapsed = nanoTime - lastRefillNanoTime
    if (elapsed >= perToken) {
      val earned = elapsed / perToken
      availableTokens = math.min(availableTokens + earned, bucketConfig.capacity)
      lastRefillNanoTime = lastRefillNanoTime + earned * perToken
    }
  }

  /** Returns an independent bucket with the same token count and timestamp. */
  def copy(): TokenBucket =
    new TokenBucket(
      availableTokens = availableTokens,
      lastRefillNanoTime = lastRefillNanoTime
    )
}
object TokenBucket {

  /** Creates a full bucket (holding `bucketConfig.capacity` tokens) whose credit
    * timestamp is `nowNanos`.
    */
  def apply(bucketConfig: BucketConfig, nowNanos: Long): TokenBucket = {
    val startingTokens = bucketConfig.capacity
    new TokenBucket(startingTokens, nowNanos)
  }
}
/** Factory methods for [[LockFreeRateLimiter]]. */
object LockFreeRateLimiter {

  /** Limiter allowing `capacity` tokens per `window`, timed by a [[SystemClock]]. */
  def apply(capacity: Long, window: FiniteDuration): LockFreeRateLimiter = {
    val systemClock = new SystemClock
    apply(capacity, window, systemClock)
  }

  /** Limiter allowing `capacity` tokens per `window`, timed by `clock`. */
  def apply(capacity: Long, window: FiniteDuration, clock: Clock): LockFreeRateLimiter = {
    val config = BucketConfig.from(capacity, window)
    new LockFreeRateLimiter(config, clock)
  }
}
/** Demo: hammers a 10-tokens-per-second limiter from a single thread for two
  * seconds, then prints how many requests were admitted and rejected.
  */
object LockFreeRateLimiterApp {
  def main(args: Array[String]): Unit = {
    val lf = LockFreeRateLimiter(10, 1.second)
    // Deadline is based on System.nanoTime, so wall-clock adjustments during the
    // run cannot shorten or stretch it (System.currentTimeMillis could).
    val deadline = 2.seconds.fromNow
    val consumed = new AtomicLong()
    val rejected = new AtomicLong()
    while (deadline.hasTimeLeft()) {
      if (lf.tryConsume(1)) consumed.incrementAndGet()
      else rejected.incrementAndGet()
    }
    println(s"consumed $consumed, rejected $rejected")
  }
}
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers
import scala.concurrent.duration._
/** Hand-driven [[Clock]] for tests: it starts at 1_000_000_000 and changes only
  * when `forward` is called.
  */
class TestClock extends Clock {
  private var ticks: Long = 1_000_000_000L

  override def now: Long = ticks

  /** Moves the clock ahead by `by` ticks (the specs pass nanoseconds). */
  def forward(by: Long): Unit = {
    ticks += by
  }
}
class LockFreeRateLimiterSpec extends AnyFreeSpec with Matchers {
  "LockFreeRateLimiter" - {
    // Test cases must use `in`: a `-` clause is only a container whose body runs
    // while the suite is constructed, so assertions placed there are not tests.
    "should consume available tokens" in {
      val cap = 1
      val period = 100.milliseconds
      val testClock = new TestClock
      val lf = LockFreeRateLimiter(cap, period, testClock)
      lf.tryConsume(cap) shouldBe true
      lf.getAvailableTokens shouldBe 0
      // No time has passed, so the drained bucket must reject further requests.
      lf.tryConsume(cap) shouldBe false
      testClock.forward(period.toNanos)
      lf.tryConsume(cap) shouldBe true
      lf.getAvailableTokens shouldBe 0
    }
    "after half period, half of tokens should be available" in {
      val cap = 10
      val period = 100.milliseconds
      val testClock = new TestClock
      val lf = LockFreeRateLimiter(cap, period, testClock)
      lf.tryConsume(cap) shouldBe true
      lf.getAvailableTokens shouldBe 0
      testClock.forward(period.toNanos / 2)
      lf.tryConsume(cap / 2) shouldBe true
      lf.getAvailableTokens shouldBe 0
      // Exactly half was refilled and taken; nothing more is available yet.
      lf.tryConsume(1) shouldBe false
      testClock.forward(period.toNanos)
      lf.tryConsume(cap) shouldBe true
    }
  }
}
class TokenBucketSpec extends AnyFreeSpec with Matchers {
  "TokenBucket" - {
    "successfully refill" - {
      "after consume" in {
        val cap = 1
        val period = 10.milliseconds
        val conf = BucketConfig.from(cap, period)
        val testClock = new TestClock
        val tb = TokenBucket(conf, testClock.now)
        tb.availableTokens shouldBe cap
        tb.availableTokens = 0
        testClock.forward(period.toNanos)
        tb.refill(conf, testClock.now)
        tb.availableTokens shouldBe cap
      }
      "a portion of the capacity" in {
        val cap = 20
        val period = 100.milliseconds
        val conf = BucketConfig.from(cap, period)
        val testClock = new TestClock
        val tb = TokenBucket(conf, testClock.now)
        tb.availableTokens shouldBe cap
        tb.availableTokens = 0
        // Integer arithmetic keeps floating-point rounding out of the expected value.
        val fraction = 10
        testClock.forward(period.toNanos / fraction)
        tb.refill(conf, testClock.now)
        tb.availableTokens shouldBe cap / fraction
      }
    }
    "check for big numbers and corner cases" - {
      "never exceeds capacity after a very long idle period" in {
        val cap = 10
        val period = 100.milliseconds
        val conf = BucketConfig.from(cap, period)
        val testClock = new TestClock
        val tb = TokenBucket(conf, testClock.now)
        tb.availableTokens = 3
        testClock.forward(period.toNanos * 1_000_000)
        tb.refill(conf, testClock.now)
        tb.availableTokens shouldBe cap
      }
      "keeps elapsed time that did not add up to a whole token" in {
        val cap = 10
        val period = 100.milliseconds
        val conf = BucketConfig.from(cap, period)
        val testClock = new TestClock
        val tb = TokenBucket(conf, testClock.now)
        tb.availableTokens = 0
        // 1.5 tokens' worth of time: one token now, the half carries over.
        testClock.forward(conf.refillRate * 3 / 2)
        tb.refill(conf, testClock.now)
        tb.availableTokens shouldBe 1
        testClock.forward(conf.refillRate / 2)
        tb.refill(conf, testClock.now)
        tb.availableTokens shouldBe 2
      }
    }
    "no refill" - {
      "if no time has passed" in {
        val cap = 10
        val period = 100.milliseconds
        val conf = BucketConfig.from(cap, period)
        val testClock = new TestClock
        val tb = TokenBucket(conf, testClock.now)
        tb.availableTokens = 0
        tb.refill(conf, testClock.now)
        tb.availableTokens shouldBe 0
      }
      "if passed less time than refill rate" in {
        val cap = 10
        val period = 100.milliseconds
        val conf = BucketConfig.from(cap, period)
        val testClock = new TestClock
        val tb = TokenBucket(conf, testClock.now)
        tb.availableTokens = 0
        testClock.forward(period.toNanos / 100)
        // Without this call the test would only check the value it just assigned.
        tb.refill(conf, testClock.now)
        tb.availableTokens shouldBe 0
      }
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment