Skip to content

Instantly share code, notes, and snippets.

@x0a1b
Created April 12, 2020 07:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save x0a1b/320031de7075fb47009c6bb90a2191f8 to your computer and use it in GitHub Desktop.
Save x0a1b/320031de7075fb47009c6bb90a2191f8 to your computer and use it in GitHub Desktop.
Rolling window rate limiter
import io.lettuce.core.ScriptOutputType
import io.lettuce.core.cluster.api.StatefulRedisClusterConnection
import kotlinx.coroutines.future.asDeferred
import java.security.MessageDigest
import java.time.Duration
import java.util.UUID
/**
 * Sliding-window ("sliding log") rate limiter backed by one Redis sorted set per action.
 *
 * Every admitted call is stored as a member of the sorted set named by `action`. Its score is
 * the caller's wall-clock time in milliseconds. A call is admitted when fewer than `maxPermits`
 * members fall inside the trailing `bucketWindow`. Eviction, counting and insertion run inside a
 * single Lua script, so concurrent callers sharing a key observe a consistent count.
 *
 * NOTE(review): timestamps come from each JVM's wall clock. Instances sharing a key are assumed
 * to have reasonably synchronized clocks (e.g. NTP); skew between them shifts the effective window.
 *
 * @param redisClient cluster connection; only its async command API is used.
 */
class ClusteredSlidingWindowRateLimiter(
    redisClient: StatefulRedisClusterConnection<String, String>
) : RateLimiter {
    companion object {
        /**
         * Single script serving both [acquire] (4 args, records the call if admitted) and
         * [hasPermits] (3 args, read-only).
         *
         * It returns the number of live entries in the window *before* any insertion. The caller
         * grants a permit when that number is below `max_count`. Kept pure ASCII so its UTF-8
         * bytes, and therefore its SHA-1, do not depend on the connection charset.
         */
        private val ACQUIRE_LUA_SCRIPT = """
            -- KEYS[1]: sorted set with one member per admitted call, scored by its time in ms
            -- ARGV[1]: max_count   - permits allowed per window
            -- ARGV[2]: duration_ms - window length in ms
            -- ARGV[3]: now_ms      - caller's current wall-clock time in ms
            -- ARGV[4]: interaction_id (optional) - unique member to record; omit to only count
            local key = KEYS[1]
            local max_count = tonumber(ARGV[1])
            local duration_ms = tonumber(ARGV[2])
            local now_ms = tonumber(ARGV[3])
            local interaction_id = ARGV[4]

            -- Live entries have score > start_ms; anything at or before start_ms has expired.
            local start_ms = now_ms - duration_ms

            -- Read-only probe: count live entries without mutating the key. The window edges
            -- match the eviction below, so this agrees with what an acquire would see.
            if interaction_id == nil then
                return redis.call('ZCOUNT', key, '(' .. start_ms, '+inf')
            end

            -- Evict expired entries, then record this call only if a permit is free, so rejected
            -- calls neither consume a slot nor push admitted calls out of the window.
            redis.call('ZREMRANGEBYSCORE', key, '-inf', start_ms)
            local count = redis.call('ZCARD', key)
            if count < max_count then
                redis.call('ZADD', key, now_ms, interaction_id)
                -- The member just added is the newest, so the whole key may expire with it.
                redis.call('PEXPIRE', key, duration_ms)
            end
            return count
        """.trimIndent()

        /**
         * Lower-case hex SHA-1 of [ACQUIRE_LUA_SCRIPT]. This is the digest EVALSHA expects.
         * It is computed locally, so there is no SCRIPT LOAD round trip whose failure could be
         * cached.
         */
        private val ACQUIRE_SCRIPT_SHA: String = MessageDigest.getInstance("SHA-1")
            .digest(ACQUIRE_LUA_SCRIPT.toByteArray(Charsets.UTF_8))
            .joinToString("") { byte -> (byte.toInt() and 0xff).toString(16).padStart(2, '0') }
    }

    private val cmd by lazy {
        redisClient.async()
    }

    /**
     * Tries to take one permit for [action] within the trailing [bucketWindow].
     *
     * @return `true` if the call was admitted (and recorded), `false` if [maxPermits] calls were
     *   already admitted in the window. Rejected calls are not recorded.
     */
    override suspend fun acquire(action: String, maxPermits: Long, bucketWindow: Duration): Boolean {
        // A random member id keeps concurrent calls with identical timestamps distinct in the set.
        val admittedInWindow = evalAcquireScript(action, maxPermits, bucketWindow, UUID.randomUUID().toString())
        return admittedInWindow < maxPermits
    }

    /**
     * Reports whether an [acquire] issued now would be admitted, without recording anything.
     * Another caller may still take the last permit between this check and a later [acquire].
     */
    override suspend fun hasPermits(action: String, maxPermits: Long, bucketWindow: Duration): Boolean =
        evalAcquireScript(action, maxPermits, bucketWindow) < maxPermits

    /**
     * Runs [ACQUIRE_LUA_SCRIPT] against the key [action] and returns the script's live-entry count.
     *
     * It tries EVALSHA first. If the node replies NOSCRIPT (script cache flushed, node restarted,
     * or a failover promoted a node that never loaded the script), it falls back to EVAL. EVAL
     * also caches the script on that node, so later EVALSHA calls succeed again.
     *
     * @param interactionId member to record when the call is admitted; `null` runs a read-only count.
     */
    private suspend fun evalAcquireScript(
        action: String,
        maxPermits: Long,
        bucketWindow: Duration,
        interactionId: String? = null
    ): Long {
        val keys = arrayOf(action)
        // Wall-clock time: System.nanoTime() has an arbitrary per-JVM origin and cannot be
        // compared across the processes sharing this key.
        val nowMs = System.currentTimeMillis()
        val args = listOfNotNull(
            "$maxPermits",
            "${bucketWindow.toMillis()}",
            "$nowMs",
            interactionId
        ).toTypedArray()
        return try {
            cmd.evalsha<Long>(ACQUIRE_SCRIPT_SHA, ScriptOutputType.INTEGER, keys, *args).asDeferred().await()
        } catch (e: Exception) {
            // Match on the server's error prefix rather than a Lettuce exception subtype so this
            // works across client versions. Everything else, including cancellation, is rethrown.
            if (e.message?.startsWith("NOSCRIPT") != true) throw e
            cmd.eval<Long>(ACQUIRE_LUA_SCRIPT, ScriptOutputType.INTEGER, keys, *args).asDeferred().await()
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment