Created
April 12, 2020 07:01
-
-
Save x0a1b/320031de7075fb47009c6bb90a2191f8 to your computer and use it in GitHub Desktop.
Rolling window rate limiter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io.lettuce.core.RedisNoScriptException
import io.lettuce.core.ScriptOutputType
import io.lettuce.core.cluster.api.StatefulRedisClusterConnection
import kotlinx.coroutines.future.asDeferred
import java.time.Duration
import java.util.UUID
/**
 * Sliding-window (sliding-log) rate limiter backed by one Redis sorted set per action key.
 *
 * Each [acquire] call adds a unique member (a random UUID) scored with the caller's current
 * wall-clock time in milliseconds. The number of members inside the trailing `bucketWindow`
 * decides whether the call fits within `maxPermits`.
 *
 * Every [acquire] call is recorded, including denied ones. The set is then trimmed to its
 * `maxPermits` newest members, so denied calls keep occupying the window.
 *
 * NOTE(review): scores come from each client's `System.currentTimeMillis()`. Instances sharing
 * a key are assumed to have reasonably synchronized clocks — confirm for the deployment.
 */
class ClusteredSlidingWindowRateLimiter(
    redisClient: StatefulRedisClusterConnection<String, String>
) : RateLimiter {

    private val cmd by lazy { redisClient.async() }

    /** SHA1 of [ACQUIRE_LUA_SCRIPT], obtained via SCRIPT LOAD on first use and reused afterwards. */
    private val acquireScriptDeferred by lazy { cmd.scriptLoad(ACQUIRE_LUA_SCRIPT).asDeferred() }

    /**
     * Records one interaction for [action] and reports whether it fits within [maxPermits]
     * inside the trailing [bucketWindow].
     *
     * @return `true` if, counting this call, at most [maxPermits] interactions are in the window.
     */
    override suspend fun acquire(action: String, maxPermits: Long, bucketWindow: Duration): Boolean {
        val count = runAcquireScript(
            action,
            "$maxPermits",
            "${bucketWindow.toMillis()}",
            "${System.currentTimeMillis()}",
            UUID.randomUUID().toString()
        )
        // The script's count already includes this call, so reaching exactly maxPermits is allowed.
        return count <= maxPermits
    }

    /**
     * Read-only check: reports whether one more [acquire] for [action] would currently fit.
     * Does not record anything.
     */
    override suspend fun hasPermits(action: String, maxPermits: Long, bucketWindow: Duration): Boolean {
        // No interaction id is passed, so the script only counts entries still in the window.
        val count = runAcquireScript(
            action,
            "$maxPermits",
            "${bucketWindow.toMillis()}",
            "${System.currentTimeMillis()}"
        )
        return count < maxPermits
    }

    /**
     * Runs [ACQUIRE_LUA_SCRIPT] against [key] with [args] as ARGV, preferring EVALSHA.
     *
     * If the node answers NOSCRIPT (its script cache was lost, e.g. after a restart, failover
     * or SCRIPT FLUSH), the full script is sent with EVAL instead of failing the call.
     */
    private suspend fun runAcquireScript(key: String, vararg args: String): Long {
        val keys = arrayOf(key)
        val sha = acquireScriptDeferred.await()
        return try {
            cmd.evalsha<Long>(sha, ScriptOutputType.INTEGER, keys, *args).asDeferred().await()
        } catch (e: RedisNoScriptException) {
            cmd.eval<Long>(ACQUIRE_LUA_SCRIPT, ScriptOutputType.INTEGER, keys, *args).asDeferred().await()
        }
    }

    companion object {
        /**
         * KEYS[1] = sorted-set key.
         *
         * ARGV:
         * 1. max count
         * 2. window length in ms
         * 3. current time in ms
         * 4. optional interaction id
         *
         * Without ARGV[4], the script returns the number of entries with score in (now - window, +inf).
         *
         * With ARGV[4], the script:
         * - drops expired entries,
         * - adds the interaction,
         * - refreshes the TTL,
         * - trims the set to the newest `max count` entries,
         * - and returns the cardinality *including* the new entry, measured before trimming.
         */
        private val ACQUIRE_LUA_SCRIPT = """
            -- Load params
            local key = KEYS[1]
            local max_count = tonumber(ARGV[1])
            local duration_ms = tonumber(ARGV[2])
            local now_ms = tonumber(ARGV[3])
            local interaction_id = ARGV[4]
            -- Entries scored <= start_ms have fallen out of the window
            local start_ms = now_ms - duration_ms
            -- No interaction id: read-only count of entries still inside the window,
            -- using the same (exclusive) lower bound the mutating path trims with
            if interaction_id == nil then
              return redis.call('ZCOUNT', key, '(' .. start_ms, '+inf')
            end
            -- Drop expired entries, record this interaction and refresh the TTL
            redis.call('ZREMRANGEBYSCORE', key, '-inf', start_ms)
            redis.call('ZADD', key, now_ms, interaction_id)
            redis.call('PEXPIRE', key, duration_ms)
            local cardinality = redis.call('ZCARD', key)
            -- Keep only the newest max_count entries so the set stays bounded
            if cardinality > max_count then
              redis.call('ZREMRANGEBYRANK', key, 0, cardinality - max_count - 1)
            end
            -- Count including this interaction; > max_count means over the limit
            return cardinality
        """.trimIndent()
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment