Created
May 12, 2021 13:36
-
-
Save hcoa/7c51f07ee88abaa3d74d93e4926d9826 to your computer and use it in GitHub Desktop.
Scala lock-free rate limiter example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.concurrent.atomic.{AtomicLong, AtomicReference} | |
import scala.concurrent.duration._ | |
/**
 * Abstract time source for the rate limiter.
 *
 * Readings are treated as nanoseconds by the token-bucket refill arithmetic, and only the
 * difference between two readings is ever used.
 */
trait Clock {

  /** Current reading, in nanoseconds. */
  def now: Long
}
/** Production [[Clock]] that reports the JVM's `System.nanoTime` reading. */
class SystemClock extends Clock {
  override def now: Long = System.nanoTime
}
/**
 * Token-bucket rate limiter that never takes a lock.
 *
 * The current bucket state lives in an `AtomicReference`. Every update is computed on a private
 * copy of that state and published with compare-and-set, retrying if another thread won the race.
 * A `TokenBucket` instance, once published, is never mutated again.
 *
 * @param bucketConfig capacity and per-token refill interval of the bucket
 * @param clock        nanosecond time source used when refilling
 */
class LockFreeRateLimiter(bucketConfig: BucketConfig, clock: Clock) {
  private val bucketStateRef: AtomicReference[TokenBucket] =
    new AtomicReference(TokenBucket(bucketConfig, clock.now))

  /**
   * Attempts to take `tokens` tokens from the bucket, refilling it for elapsed time first.
   *
   * @param tokens number of tokens to consume; must be non-negative
   * @return `true` if the tokens were consumed, `false` if too few were available
   * @throws IllegalArgumentException if `tokens` is negative. A negative request would otherwise
   *                                  add tokens and push the bucket above its capacity.
   */
  def tryConsume(tokens: Int): Boolean = {
    require(tokens >= 0, s"tokens must be non-negative, got $tokens")

    @scala.annotation.tailrec
    def attempt(): Boolean = {
      val prevState = bucketStateRef.get()
      val newState = prevState.copy()
      newState.refill(bucketConfig, clock.now)
      if (newState.availableTokens < tokens) {
        false
      } else {
        newState.availableTokens -= tokens
        // Another thread may have published a new state since `get()`; if so, start over
        // from that state rather than overwriting its update.
        if (bucketStateRef.compareAndSet(prevState, newState)) true
        else attempt()
      }
    }

    attempt()
  }

  /**
   * Number of tokens that could be consumed right now.
   *
   * Includes tokens regenerated since the last successful `tryConsume`. The published state is
   * not modified: the refill is applied to a private copy. The value is informational only and
   * may already be stale when used, because other threads can consume concurrently.
   */
  def getAvailableTokens: Long = {
    val snapshot = bucketStateRef.get().copy()
    snapshot.refill(bucketConfig, clock.now)
    snapshot.availableTokens
  }
}
/**
 * Static configuration of a token bucket.
 *
 * @param capacity   maximum number of tokens the bucket can hold; must be positive
 * @param refillRate nanoseconds needed to regenerate one token; must be positive because
 *                   refilling divides elapsed time by it
 * @param period     window over which `capacity` tokens are regenerated
 */
case class BucketConfig(
  capacity: Long,
  refillRate: Long,
  period: FiniteDuration
) {
  require(capacity > 0, s"capacity must be positive, got $capacity")
  require(refillRate > 0, s"refillRate must be positive, got $refillRate")
}

object BucketConfig {

  /**
   * Builds a config that regenerates `capacity` tokens evenly over `period`.
   *
   * The per-token interval is `period / capacity` in nanoseconds, rounded down. It is clamped to
   * at least 1 ns, so a capacity larger than the period's nanosecond count cannot produce a zero
   * divisor. In that case the effective rate is capped at one token per nanosecond.
   *
   * @throws IllegalArgumentException if `capacity` or `period` is not positive
   */
  def from(capacity: Long, period: FiniteDuration): BucketConfig = {
    // Checked before dividing so a zero capacity fails with a clear message instead of
    // an ArithmeticException.
    require(capacity > 0, s"capacity must be positive, got $capacity")
    require(period > Duration.Zero, s"period must be positive, got $period")
    BucketConfig(capacity, math.max(1L, period.toNanos / capacity), period)
  }
}
/**
 * Mutable token-bucket state: tokens left and the clock reading up to which refills were credited.
 *
 * The class has no internal synchronization. `LockFreeRateLimiter` only mutates fresh `copy()`s
 * before publishing them, so a shared instance is never written concurrently.
 *
 * @param availableTokens    tokens currently in the bucket
 * @param lastRefillNanoTime clock reading (nanoseconds) up to which elapsed time has been credited
 */
class TokenBucket private (
  var availableTokens: Long,
  var lastRefillNanoTime: Long
) {

  /**
   * Credits tokens for the time elapsed since `lastRefillNanoTime`, never exceeding capacity.
   *
   * Only whole refill intervals are credited. `lastRefillNanoTime` advances by exactly the time
   * those tokens account for, so a partial interval carries over to the next call.
   *
   * @param bucketConfig supplies `capacity` and `refillRate` (nanoseconds per token)
   * @param nanoTime     current clock reading in nanoseconds
   */
  def refill(bucketConfig: BucketConfig, nanoTime: Long): Unit = {
    val sinceLastRefillNanos = nanoTime - lastRefillNanoTime
    if (sinceLastRefillNanos >= bucketConfig.refillRate) {
      val tokensSinceLastRefill = sinceLastRefillNanos / bucketConfig.refillRate
      // Compare against the remaining headroom instead of adding first: after a very long
      // idle period `availableTokens + tokensSinceLastRefill` could overflow Long.
      val headroom = bucketConfig.capacity - availableTokens
      availableTokens =
        if (tokensSinceLastRefill >= headroom) bucketConfig.capacity
        else availableTokens + tokensSinceLastRefill
      lastRefillNanoTime += tokensSinceLastRefill * bucketConfig.refillRate
    }
  }

  /** Independent copy with the same field values; mutating it does not affect this instance. */
  def copy(): TokenBucket = new TokenBucket(availableTokens, lastRefillNanoTime)
}

object TokenBucket {

  /** Creates a full bucket (`bucketConfig.capacity` tokens) whose refill clock starts at `nowNanos`. */
  def apply(bucketConfig: BucketConfig, nowNanos: Long): TokenBucket =
    new TokenBucket(bucketConfig.capacity, nowNanos)
}
/** Factories that derive the bucket configuration from a capacity and a time window. */
object LockFreeRateLimiter {

  /** Limiter admitting `capacity` tokens per `window`, timed by a fresh [[SystemClock]]. */
  def apply(capacity: Long, window: FiniteDuration): LockFreeRateLimiter = {
    val systemClock: Clock = new SystemClock
    apply(capacity, window, systemClock)
  }

  /** Limiter admitting `capacity` tokens per `window`, timed by the supplied `clock`. */
  def apply(capacity: Long, window: FiniteDuration, clock: Clock): LockFreeRateLimiter = {
    val config = BucketConfig.from(capacity, window)
    new LockFreeRateLimiter(config, clock)
  }
}
/**
 * Demo entry point. Calls `tryConsume(1)` in a tight single-threaded loop for about two seconds
 * against a 10-tokens-per-second limiter, then prints how many calls were admitted and rejected.
 */
object LockFreeRateLimiterApp {
  def main(args: Array[String]): Unit = {
    val lf = LockFreeRateLimiter(10, 1.second)
    // Time the run with the monotonic nanoTime clock. currentTimeMillis follows the wall clock,
    // which can jump (e.g. NTP adjustments) and would stretch or cut the run.
    val runForNanos = 2.seconds.toNanos
    val start = System.nanoTime()
    val consumed = new AtomicLong()
    val rejected = new AtomicLong()
    while (System.nanoTime() - start < runForNanos) {
      if (lf.tryConsume(1)) consumed.incrementAndGet()
      else rejected.incrementAndGet()
    }
    println(s"consumed $consumed, rejected $rejected")
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.scalatest.freespec.AnyFreeSpec | |
import org.scalatest.matchers.should.Matchers | |
import scala.concurrent.duration._ | |
/** Manually driven [[Clock]] for tests: the reading only changes when `forward` is called. */
class TestClock extends Clock {
  // Arbitrary non-zero starting reading, in nanoseconds.
  private var nanos: Long = 1000000000L

  override def now: Long = nanos

  /** Advances the clock by `by` nanoseconds. */
  def forward(by: Long): Unit = {
    nanos = nanos + by
  }
}
/** Behavioural tests for LockFreeRateLimiter, driven by a manually advanced TestClock. */
class LockFreeRateLimiterSpec extends AnyFreeSpec with Matchers {
  "LockFreeRateLimiter" - {
    // Leaf clauses must use `in`. Code placed directly inside a `-` clause runs while the suite
    // is being constructed and is never reported as a test.
    "should consume available tokens" in {
      val cap = 1
      val period = 100.milliseconds
      val testClock = new TestClock
      val lf = LockFreeRateLimiter(cap, period, testClock)
      lf.tryConsume(cap) shouldBe true
      lf.getAvailableTokens shouldBe 0
      testClock.forward(period.toNanos)
      lf.tryConsume(cap) shouldBe true
      lf.getAvailableTokens shouldBe 0
    }
    "should reject when not enough tokens are available" in {
      val cap = 1
      val period = 100.milliseconds
      val testClock = new TestClock
      val lf = LockFreeRateLimiter(cap, period, testClock)
      lf.tryConsume(cap) shouldBe true
      lf.tryConsume(1) shouldBe false
      lf.getAvailableTokens shouldBe 0
    }
    "after half period, half of tokens should be available" in {
      val cap = 10
      val period = 100.milliseconds
      val testClock = new TestClock
      val lf = LockFreeRateLimiter(cap, period, testClock)
      lf.tryConsume(cap) shouldBe true
      lf.getAvailableTokens shouldBe 0
      testClock.forward(period.toNanos / 2)
      lf.tryConsume(cap / 2) shouldBe true
      lf.getAvailableTokens shouldBe 0
      testClock.forward(period.toNanos)
      lf.tryConsume(10) shouldBe true
    }
  }
}
/** Unit tests for TokenBucket.refill, driven by a manually advanced TestClock. */
class TokenBucketSpec extends AnyFreeSpec with Matchers {
  "TokenBucket" - {
    "successfully refill" - {
      "after consume" in {
        val cap = 1
        val period = 10.milliseconds
        val conf = BucketConfig.from(cap, period)
        val testClock = new TestClock
        val tb = TokenBucket(conf, testClock.now)
        tb.availableTokens shouldBe cap
        tb.availableTokens = 0
        testClock.forward(period.toNanos)
        tb.refill(conf, testClock.now)
        tb.availableTokens shouldBe cap
      }
      "a portion of the capacity" in {
        val cap = 20
        val period = 100.milliseconds
        val conf = BucketConfig.from(cap, period)
        val testClock = new TestClock
        val tb = TokenBucket(conf, testClock.now)
        tb.availableTokens shouldBe cap
        tb.availableTokens = 0
        // A tenth of the period in whole nanoseconds; integer arithmetic avoids
        // floating-point rounding in both the clock step and the expectation.
        testClock.forward(period.toNanos / 10)
        tb.refill(conf, testClock.now)
        tb.availableTokens shouldBe cap / 10
      }
    }
    "check for big numbers and corner cases" - {
      "never exceeds capacity after a very long idle period" in {
        val cap = 10
        val period = 100.milliseconds
        val conf = BucketConfig.from(cap, period)
        val testClock = new TestClock
        val tb = TokenBucket(conf, testClock.now)
        tb.availableTokens = 0
        testClock.forward(Long.MaxValue / 4)
        tb.refill(conf, testClock.now)
        tb.availableTokens shouldBe cap
      }
      "carries a partial refill interval over to the next refill" in {
        val cap = 10
        val period = 100.milliseconds
        val conf = BucketConfig.from(cap, period)
        val testClock = new TestClock
        val tb = TokenBucket(conf, testClock.now)
        tb.availableTokens = 0
        val perToken = period.toNanos / cap
        // 1.5 intervals: one token now, the remaining half interval must not be lost.
        testClock.forward(perToken + perToken / 2)
        tb.refill(conf, testClock.now)
        tb.availableTokens shouldBe 1
        testClock.forward(perToken / 2)
        tb.refill(conf, testClock.now)
        tb.availableTokens shouldBe 2
      }
    }
    "no refill" - {
      "if no time has passed" in {
        val cap = 10
        val period = 100.milliseconds
        val conf = BucketConfig.from(cap, period)
        val testClock = new TestClock
        val tb = TokenBucket(conf, testClock.now)
        tb.availableTokens = 0
        tb.refill(conf, testClock.now)
        tb.availableTokens shouldBe 0
      }
      "if passed less time than refill rate" in {
        val cap = 10
        val period = 100.milliseconds
        val conf = BucketConfig.from(cap, period)
        val testClock = new TestClock
        val tb = TokenBucket(conf, testClock.now)
        tb.availableTokens = 0
        testClock.forward(period.toNanos / 100)
        // Without this call the test would pass trivially, since nothing refills on its own.
        tb.refill(conf, testClock.now)
        tb.availableTokens shouldBe 0
      }
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment