Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Regular and sliding window rate limiting to accompany two blog posts.
'''
rate_limit2.py
Copyright 2014, Josiah Carlson - josiah.carlson@gmail.com
Released under the MIT license
This module intends to show how to perform standard and sliding-window rate
limits as a companion to the two articles posted on Binpress entitled
"Introduction to rate limiting with Redis", parts 1 and 2:
http://www.binpress.com/tutorial/introduction-to-rate-limiting-with-redis/155
http://www.binpress.com/tutorial/introduction-to-rate-limiting-with-redis-part-2/166
... which will be (or have already been) reposted on my personal blog at least 2
weeks after their original binpress.com posting:
http://www.dr-josiah.com
'''
import json
import time
from flask import g, request
def get_identifiers():
    """Return the rate-limit identifiers for the current Flask request.

    The list always starts with ``'ip:<remote address>'``; when ``g.user``
    reports itself as authenticated, ``'user:<user id>'`` is appended.
    """
    identifiers = ['ip:' + request.remote_addr]
    if not g.user.is_authenticated():
        # Anonymous request: only the IP address is rate limited.
        return identifiers
    identifiers.append('user:' + g.user.get_id())
    return identifiers
def over_limit(conn, duration=3600, limit=240):
    """Fixed-window rate-limit check for every identifier of the request.

    Each identifier gets a counter key suffixed with the window length and
    the index of the current window.  The counter is incremented, its
    expiration is set to ``duration`` seconds, and the function returns
    True as soon as one counter exceeds ``limit``; otherwise False.
    """
    window = time.time() // duration
    suffix = ':%i:%i' % (duration, window)
    for identifier in get_identifiers():
        counter_key = identifier + suffix
        hits = conn.incr(counter_key)
        conn.expire(counter_key, duration)
        if hits > limit:
            return True
    return False
def over_limit_multi(conn, limits=[(1, 10), (60, 120), (3600, 240)]):
    """Check several ``(duration, limit)`` pairs with over_limit().

    Limits are evaluated in order and evaluation stops at the first one that
    is exceeded.  Returns True if any limit is exceeded, False otherwise.
    """
    # any() short-circuits, so later limits are not touched once one trips.
    return any(
        over_limit(conn, duration, limit) for duration, limit in limits)
def over_limit(conn, duration=3600, limit=240):
    """Pipelined fixed-window rate-limit check.

    Replaces the earlier over_limit() function and reduces round trips with
    pipelining: the INCR and EXPIRE for each identifier are queued on a
    transactional pipeline and sent together.  Returns True as soon as an
    identifier's counter exceeds ``limit``; otherwise False.
    """
    window = time.time() // duration
    suffix = ':%i:%i' % (duration, window)
    pipe = conn.pipeline(transaction=True)
    for identifier in get_identifiers():
        counter_key = identifier + suffix
        pipe.incr(counter_key)
        pipe.expire(counter_key, duration)
        # The first reply belongs to INCR, i.e. the updated counter value.
        replies = pipe.execute()
        if replies[0] > limit:
            return True
    return False
def over_limit_multi_lua(conn, limits=[(1, 10), (60, 120), (3600, 240)]):
    """Check several fixed-window limits with one server-side Lua call.

    The script is registered on first use and cached on ``conn`` as
    ``over_limit_lua``.  It receives the request identifiers as KEYS and the
    JSON-encoded limits plus the current time as ARGV; its reply is returned
    unchanged.
    """
    if not hasattr(conn, 'over_limit_lua'):
        # Register once per connection object and reuse afterwards.
        conn.over_limit_lua = conn.register_script(over_limit_multi_lua_)
    identifiers = get_identifiers()
    script_args = [json.dumps(limits), time.time()]
    return conn.over_limit_lua(keys=identifiers, args=script_args)
# Lua source registered by over_limit_multi_lua() via conn.register_script().
# It checks several fixed-window limits in a single call.
#   ARGV[1]: JSON-encoded list of [duration, limit] pairs
#   ARGV[2]: current time; divided by each duration to select the bucket
#   KEYS:    identifiers; each is suffixed with ':<duration>:<bucket>' to
#            form the counter key that is actually touched
# For every limit and every identifier the counter is INCRemented and its
# EXPIRE is set to the duration.  The script returns 1 as soon as a counter
# exceeds its limit, and 0 when no limit is exceeded.
over_limit_multi_lua_ = '''
local limits = cjson.decode(ARGV[1])
local now = tonumber(ARGV[2])
for i, limit in ipairs(limits) do
local duration = limit[1]
local bucket = ':' .. duration .. ':' .. math.floor(now / duration)
for j, id in ipairs(KEYS) do
local key = id .. bucket
local count = redis.call('INCR', key)
redis.call('EXPIRE', key, duration)
if tonumber(count) > limit[2] then
return 1
end
end
end
return 0
'''
def over_limit_sliding_window(conn, weight=1, limits=[(1, 10), (60, 120), (3600, 240, 60)], redis_time=False):
    """Sliding-window rate-limit check performed by a server-side Lua script.

    ``limits`` holds ``(duration, limit)`` or ``(duration, limit, precision)``
    tuples and ``weight`` is the cost of this request.  When ``redis_time``
    is true the timestamp comes from ``conn.time()``; otherwise the local
    clock is used.  The script is registered on first use and cached on
    ``conn``; its reply is returned unchanged.
    """
    if not hasattr(conn, 'over_limit_sliding_window_lua'):
        conn.over_limit_sliding_window_lua = conn.register_script(over_limit_sliding_window_lua_)

    if redis_time:
        # First element of the TIME reply (whole seconds).
        now = conn.time()[0]
    else:
        now = time.time()

    identifiers = get_identifiers()
    script_args = [json.dumps(limits), now, weight]
    return conn.over_limit_sliding_window_lua(keys=identifiers, args=script_args)
# Lua source registered by over_limit_sliding_window() via
# conn.register_script().  It implements sliding-window rate limiting on one
# HASH per identifier.
#   ARGV[1]: JSON-encoded list of [duration, limit] or
#            [duration, limit, precision] entries
#   ARGV[2]: current time
#   ARGV[3]: weight of this request (defaults to 1)
#   KEYS:    one HASH key per identifier
# Hash fields used for each limit:
#   '<duration>:<precision>:'          total count inside the window
#   '<duration>:<precision>:<block>'   count for a single block
#   '<duration>:<precision>:o'         oldest block id still retained
# Returns 1 when the request must be rejected and 0 when it was counted.
#
# Changes from the original script:
#   * The "don't write in the past" guard now compares the stored oldest
#     block id with the current block id.  It used to compare a block id
#     with the raw timestamp, which never fired for precision >= 1 and
#     always fired for sub-second precision.
#   * longest_duration starts at 0 instead of indexing limits[1], so an
#     empty limits list no longer raises inside the script.
over_limit_sliding_window_lua_ = '''
local limits = cjson.decode(ARGV[1])
local now = tonumber(ARGV[2])
local weight = tonumber(ARGV[3] or '1')
-- raised to the largest duration inside the loop below
local longest_duration = 0
local saved_keys = {}
-- handle cleanup and limit checks
for i, limit in ipairs(limits) do
    local duration = limit[1]
    longest_duration = math.max(longest_duration, duration)
    local precision = limit[3] or duration
    precision = math.min(precision, duration)
    local blocks = math.ceil(duration / precision)
    local saved = {}
    table.insert(saved_keys, saved)
    saved.block_id = math.floor(now / precision)
    saved.trim_before = saved.block_id - blocks + 1
    saved.count_key = duration .. ':' .. precision .. ':'
    saved.ts_key = saved.count_key .. 'o'
    for j, key in ipairs(KEYS) do
        local old_ts = redis.call('HGET', key, saved.ts_key)
        old_ts = old_ts and tonumber(old_ts) or saved.trim_before
        -- old_ts is a block id, so it is compared with the current block id
        if old_ts > saved.block_id then
            -- don't write in the past
            return 1
        end
        -- discover what needs to be cleaned up
        local decr = 0
        local dele = {}
        -- never scan more than one window worth of blocks per call
        local trim = math.min(saved.trim_before, old_ts + blocks)
        for old_block = old_ts, trim - 1 do
            local bkey = saved.count_key .. old_block
            local bcount = redis.call('HGET', key, bkey)
            if bcount then
                decr = decr + tonumber(bcount)
                table.insert(dele, bkey)
            end
        end
        -- handle cleanup
        local cur
        if #dele > 0 then
            redis.call('HDEL', key, unpack(dele))
            cur = redis.call('HINCRBY', key, saved.count_key, -decr)
        else
            cur = redis.call('HGET', key, saved.count_key)
        end
        -- check our limits
        if tonumber(cur or '0') + weight > limit[2] then
            return 1
        end
    end
end
-- there is enough resources, update the counts
for i, limit in ipairs(limits) do
    local saved = saved_keys[i]
    for j, key in ipairs(KEYS) do
        -- update the current timestamp, count, and bucket count
        redis.call('HSET', key, saved.ts_key, saved.trim_before)
        redis.call('HINCRBY', key, saved.count_key, weight)
        redis.call('HINCRBY', key, saved.count_key .. saved.block_id, weight)
    end
end
-- We calculated the longest-duration limit so we can EXPIRE
-- the whole HASH for quick and easy idle-time cleanup :)
if longest_duration > 0 then
    for _, key in ipairs(KEYS) do
        redis.call('EXPIRE', key, longest_duration)
    end
end
return 0
'''

rauhs commented Mar 2, 2016

Would be nice if over_limit_sliding_window_lua returned which limit was in effect, useful for different actions on different limits ("require captcha for this limit, reject on that limit"). For this you can just return i instead of 1.

ciokan commented Jan 21, 2017 edited

How would you return an actual timestamp instead of 1 to be used in a Retry-After header?

Owner

josiahcarlson commented Aug 1, 2017

The answer for you @ciokan is you need to modify the Lua script to calculate the delay. Right now it just returns whether you need to wait. https://gist.github.com/josiahcarlson/80584b49da41549a7d5c#file-rate_limit2-py-L157 is the line you are looking for.

Hi I have three questions.

Question 1

In over_limit_sliding_window_lua_, should

if old_ts > now then

here be

if old_ts > saved.block_id then

because old_ts is the oldest block id, not a timestamp?

Question 2

Should

local trim = math.min(saved.trim_before, old_ts + blocks)

here be

saved.trim_before = math.min(saved.trim_before, old_ts + blocks)

because later when saving the oldest block id the code uses saved.trim_before

redis.call('HSET', key, saved.ts_key, saved.trim_before)

?

Question 3

Is the purpose of the code

local trim = math.min(saved.trim_before, old_ts + blocks)

here to cap the number of blocks trimmed in one call at `blocks`?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment