Skip to content

Instantly share code, notes, and snippets.

@josiahcarlson
Last active April 17, 2022 09:22
Show Gist options
  • Save josiahcarlson/80584b49da41549a7d5c to your computer and use it in GitHub Desktop.
Save josiahcarlson/80584b49da41549a7d5c to your computer and use it in GitHub Desktop.
Regular and sliding window rate limiting to accompany two blog posts.
'''
rate_limit2.py
Copyright 2014, Josiah Carlson - josiah.carlson@gmail.com
Released under the MIT license
This module intends to show how to perform standard and sliding-window rate
limits as a companion to the two articles posted on Binpress entitled
"Introduction to rate limiting with Redis", parts 1 and 2:
http://www.binpress.com/tutorial/introduction-to-rate-limiting-with-redis/155
http://www.binpress.com/tutorial/introduction-to-rate-limiting-with-redis-part-2/166
... which will (or have already been) reposted on my personal blog at least 2
weeks after their original binpress.com posting:
http://www.dr-josiah.com
'''
import json
import time
from flask import g, request
def get_identifiers():
    """Build the list of rate-limit identifiers for the current Flask request.

    The first entry is always ``'ip:'`` followed by ``request.remote_addr``.
    When ``g.user.is_authenticated()`` is truthy, ``'user:'`` followed by
    ``g.user.get_id()`` is appended.

    Returns:
        list of str: one identifier per entity to be rate limited.
    """
    identifiers = ['ip:' + request.remote_addr]
    current_user = g.user
    if current_user.is_authenticated():
        identifiers.append('user:' + current_user.get_id())
    return identifiers
def over_limit(conn, duration=3600, limit=240):
    """Fixed-window rate limit check, one INCR/EXPIRE round trip pair per key.

    Args:
        conn: Redis client exposing ``incr`` and ``expire``.
        duration: window length in seconds; also used as the key TTL.
        limit: maximum count allowed per identifier within one window.

    Returns:
        True as soon as one identifier's counter exceeds ``limit`` (later
        identifiers are then not touched), otherwise False.
    """
    # All identifiers share the same ":<duration>:<window number>" suffix.
    window_number = time.time() // duration
    suffix = ':%i:%i' % (duration, window_number)
    for identifier in get_identifiers():
        counter_key = identifier + suffix
        hits = conn.incr(counter_key)
        # Refresh the TTL on every hit so the counter key eventually expires.
        conn.expire(counter_key, duration)
        if hits > limit:
            return True
    return False
def over_limit_multi(conn, limits=[(1, 10), (60, 120), (3600, 240)]):
    """Check several (duration, limit) pairs in order via ``over_limit``.

    Args:
        conn: Redis client, passed straight through to ``over_limit``.
        limits: iterable of ``(duration, limit)`` pairs. The default list is
            only read here, never mutated.

    Returns:
        True at the first pair that is over its limit (remaining pairs are
        not evaluated), otherwise False.
    """
    # any() stops at the first truthy result, like the explicit early return.
    return any(
        over_limit(conn, duration, limit) for duration, limit in limits)
def over_limit(conn, duration=3600, limit=240):
    """Pipelined fixed-window rate limit check.

    Supersedes the earlier ``over_limit`` definition: INCR and EXPIRE for a
    key are queued on a transactional pipeline and sent together, which
    saves a round trip per identifier.

    Args:
        conn: Redis client exposing ``pipeline``.
        duration: window length in seconds; also used as the key TTL.
        limit: maximum count allowed per identifier within one window.

    Returns:
        True as soon as one identifier's counter exceeds ``limit``,
        otherwise False.
    """
    pipeline = conn.pipeline(transaction=True)
    window_number = time.time() // duration
    suffix = ':%i:%i' % (duration, window_number)
    for identifier in get_identifiers():
        counter_key = identifier + suffix
        pipeline.incr(counter_key)
        pipeline.expire(counter_key, duration)
        # The first reply belongs to INCR: the counter value after this hit.
        replies = pipeline.execute()
        if replies[0] > limit:
            return True
    return False
def over_limit_multi_lua(conn, limits=[(1, 10), (60, 120), (3600, 240)]):
    """Run the multi-limit fixed-window check as one server-side Lua script.

    The script object is registered once and cached on ``conn`` as the
    ``over_limit_lua`` attribute.

    Args:
        conn: Redis client exposing ``register_script``.
        limits: list of ``(duration, limit)`` pairs, sent to the script as
            JSON. The default list is only read, never mutated.

    Returns:
        Whatever the script returns: 1 when a limit is exceeded, else 0.
    """
    if hasattr(conn, 'over_limit_lua'):
        script = conn.over_limit_lua
    else:
        script = conn.register_script(over_limit_multi_lua_)
        conn.over_limit_lua = script
    identifiers = get_identifiers()
    script_args = [json.dumps(limits), time.time()]
    return script(keys=identifiers, args=script_args)
# Lua source for the fixed-window, multi-limit check that
# over_limit_multi_lua() registers and calls.
#   KEYS    - identifier strings (key prefixes, not complete Redis keys)
#   ARGV[1] - JSON list of limits; entry [1] is the duration in seconds and
#             entry [2] the maximum count for that duration
#   ARGV[2] - current time in seconds
# For every limit and every identifier the script INCRs the counter named
# "<identifier>:<duration>:<floor(now / duration)>" and resets its TTL to
# <duration>. It returns 1 at the first counter that exceeds its limit, so
# any counters not yet visited are left untouched; it returns 0 otherwise.
# NOTE(review): the counter keys are built inside the script rather than
# passed through KEYS -- confirm this is acceptable for the deployment.
over_limit_multi_lua_ = '''
local limits = cjson.decode(ARGV[1])
local now = tonumber(ARGV[2])
for i, limit in ipairs(limits) do
local duration = limit[1]
local bucket = ':' .. duration .. ':' .. math.floor(now / duration)
for j, id in ipairs(KEYS) do
local key = id .. bucket
local count = redis.call('INCR', key)
redis.call('EXPIRE', key, duration)
if tonumber(count) > limit[2] then
return 1
end
end
end
return 0
'''
def over_limit_sliding_window(conn, weight=1, limits=[(1, 10), (60, 120), (3600, 240, 60)], redis_time=False):
    """Sliding-window rate limit check, executed as a server-side Lua script.

    The script object is registered once and cached on ``conn`` as the
    ``over_limit_sliding_window_lua`` attribute.

    Args:
        conn: Redis client exposing ``register_script`` (and ``time`` when
            ``redis_time`` is true).
        weight: amount this request adds to every counter.
        limits: list of ``(duration, limit)`` or
            ``(duration, limit, precision)`` tuples, sent to the script as
            JSON. The default list is only read, never mutated.
        redis_time: when true, use the seconds field of ``conn.time()``
            instead of the local ``time.time()`` as the current time.

    Returns:
        Whatever the script returns: 1 when the request is rejected, else 0.
    """
    if not hasattr(conn, 'over_limit_sliding_window_lua'):
        conn.over_limit_sliding_window_lua = conn.register_script(over_limit_sliding_window_lua_)
    if redis_time:
        now = conn.time()[0]
    else:
        now = time.time()
    script = conn.over_limit_sliding_window_lua
    identifiers = get_identifiers()
    encoded_limits = json.dumps(limits)
    return script(keys=identifiers, args=[encoded_limits, now, weight])
# Lua source for the sliding-window limiter that over_limit_sliding_window()
# registers and calls.
#   KEYS    - one Redis HASH key per identifier
#   ARGV[1] - JSON list of [duration, limit] or [duration, limit, precision]
#   ARGV[2] - current time in seconds
#   ARGV[3] - weight added by this request (defaults to 1)
# Each HASH holds, per limit: a running total under "<duration>:<precision>:",
# the oldest live block id under "<duration>:<precision>:o", and one counter
# per block under "<duration>:<precision>:<block id>".
# Returns 1 when the request is rejected (nothing is counted), or 0 after the
# request has been counted against every limit of every key.
#
# Changes from the original script:
#   * The "don't write in the past" guard compared old_ts (a block id) with
#     `now` (a timestamp in seconds), so it could only fire when the
#     precision was one second. It now compares against the current block id.
#   * longest_duration starts at 0 instead of limits[1][1], so an empty
#     limits list no longer fails by indexing a nil entry; the loop below
#     already takes the maximum over every duration.
over_limit_sliding_window_lua_ = '''
local limits = cjson.decode(ARGV[1])
local now = tonumber(ARGV[2])
local weight = tonumber(ARGV[3] or '1')
local longest_duration = 0
local saved_keys = {}
-- handle cleanup and limit checks
for i, limit in ipairs(limits) do
    local duration = limit[1]
    longest_duration = math.max(longest_duration, duration)
    local precision = limit[3] or duration
    precision = math.min(precision, duration)
    local blocks = math.ceil(duration / precision)
    local saved = {}
    table.insert(saved_keys, saved)
    saved.block_id = math.floor(now / precision)
    saved.trim_before = saved.block_id - blocks + 1
    saved.count_key = duration .. ':' .. precision .. ':'
    saved.ts_key = saved.count_key .. 'o'
    for j, key in ipairs(KEYS) do
        local old_ts = redis.call('HGET', key, saved.ts_key)
        old_ts = old_ts and tonumber(old_ts) or saved.trim_before
        -- old_ts is a block id, so compare it with the current block id
        -- rather than with the raw timestamp
        if old_ts > saved.block_id then
            -- don't write in the past
            return 1
        end
        -- discover what needs to be cleaned up
        local decr = 0
        local dele = {}
        local trim = math.min(saved.trim_before, old_ts + blocks)
        for old_block = old_ts, trim - 1 do
            local bkey = saved.count_key .. old_block
            local bcount = redis.call('HGET', key, bkey)
            if bcount then
                decr = decr + tonumber(bcount)
                table.insert(dele, bkey)
            end
        end
        -- handle cleanup
        local cur
        if #dele > 0 then
            redis.call('HDEL', key, unpack(dele))
            cur = redis.call('HINCRBY', key, saved.count_key, -decr)
        else
            cur = redis.call('HGET', key, saved.count_key)
        end
        -- check our limits
        if tonumber(cur or '0') + weight > limit[2] then
            return 1
        end
    end
end
-- there is enough resources, update the counts
for i, limit in ipairs(limits) do
    local saved = saved_keys[i]
    for j, key in ipairs(KEYS) do
        -- update the current timestamp, count, and bucket count
        redis.call('HSET', key, saved.ts_key, saved.trim_before)
        redis.call('HINCRBY', key, saved.count_key, weight)
        redis.call('HINCRBY', key, saved.count_key .. saved.block_id, weight)
    end
end
-- We calculated the longest-duration limit so we can EXPIRE
-- the whole HASH for quick and easy idle-time cleanup :)
if longest_duration > 0 then
    for _, key in ipairs(KEYS) do
        redis.call('EXPIRE', key, longest_duration)
    end
end
return 0
'''
@AoiKuiyuyou
Copy link

Hi I have three questions.

Question 1

In over_limit_sliding_window_lua_, should

if old_ts > now then

at here be

if old_ts > saved.block_id then

because old_ts is the oldest block id, not a timestamp?

Question 2

Should

local trim = math.min(saved.trim_before, old_ts + blocks)

at here be

saved.trim_before = math.min(saved.trim_before, old_ts + blocks)

because later when saving the oldest block id the code uses saved.trim_before

redis.call('HSET', key, saved.ts_key, saved.trim_before)

?

Question 3

Is the purpose of the code

local trim = math.min(saved.trim_before, old_ts + blocks)

at here to limit the number of blocks to trim to be at most blocks?

@apmcodes
Copy link

apmcodes commented Nov 14, 2018

@ciokan

How would you return an actual timestamp instead of 1 to be used in a Retry-After header?

Replace line 157 (`return 1`) with the code below. It loops through the blocks of the current window to find the earliest block that recorded a request, then calculates the time at which that block becomes stale and a new request is allowed.

            -- return 1
            local last_attempt
            for last_block = saved.trim_before, saved.block_id, precision do
                local bcount = redis.call('HGET', key, saved.count_key .. last_block)
                if (bcount) then
                    last_attempt = last_block
                    break
                end
            end
            local next_attempt
            if last_attempt then
                next_attempt = (last_attempt + blocks) * precision
            else 
                next_attempt = 0
            end
            return next_attempt

Note: The next_attempt value returned is a UNIX timestamp in seconds, not milliseconds.

@josiahcarlson Please review this code for any improvement or bug

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment