'''
Helper classes and functions for expiring groups of related Redis keys
at the same time.

Written July 1-2, 2013 by Josiah Carlson
Released into the public domain
'''
import time | |
import redis | |
class KeyDiscoveryPipeline(object):
    '''
    This class is used as a wrapper around a Redis pipeline in
    Python to discover keys that are being used. This will work
    for commands where the key to be modified is the first argument
    to the command.

    It won't work properly with commands that modify multiple keys
    at the same time, like some forms of DEL.

    Used like::

        >>> conn = redis.Redis()
        >>> pipe = KeyDiscoveryPipeline(conn.pipeline())
        >>> pipe.sadd('foo', 'bar')
        >>> pipe.execute() # will fail, use one of the subclasses
    '''
    def __init__(self, pipeline):
        # the wrapped pipeline that actually queues the commands
        self.pipeline = pipeline
        # first positional argument of every command called through us
        self.keys = set()

    def __getattr__(self, attribute):
        '''
        This is a bit of Python magic to discover the keys that
        are being modified.

        Only called for names not found on the wrapper itself. A callable
        attribute of the wrapped pipeline is returned as a proxy that
        records its first positional argument in ``self.keys`` before
        forwarding the call. A non-callable attribute is returned as-is. A
        missing attribute raises AttributeError right away, so ``hasattr``
        behaves normally.
        '''
        if attribute in ('pipeline', 'keys'):
            # Our own state is missing (e.g. an instance created without
            # __init__, as copy/pickle do); looking it up on self.pipeline
            # would recurse forever.
            raise AttributeError(attribute)
        target = getattr(self.pipeline, attribute)
        if not callable(target):
            return target

        def call(*args, **kwargs):
            if args:
                self.keys.add(args[0])
            result = target(*args, **kwargs)
            # If the wrapped pipeline returns itself for chaining (as
            # redis-py pipelines do), hand back the wrapper instead so
            # chained calls like pipe.sadd(...).hset(...) stay discovered.
            if result is self.pipeline:
                return self
            return result
        return call

    def execute(self):
        # subclasses decide what to do with the discovered keys
        raise NotImplementedError
TTL = 24 * 60 * 60  # one day, in seconds
def get_user(keys):
    '''
    Return the user id prefix of the first key yielded by iterating
    ``keys`` (an arbitrary key when ``keys`` is a set).

    Keys are assumed to be of the form
    ``<user id>:<context>[:<anything else>]``, so the user id is the text
    before the first ':' (the whole key if it has no ':'). Returns an empty
    list when ``keys`` is empty; the execute() methods return that value
    unchanged as their "nothing to do" result.
    '''
    iterator = iter(keys)
    try:
        first_key = next(iterator)
    except StopIteration:
        return []
    return first_key.split(':', 1)[0]
class ExpirePipeline(KeyDiscoveryPipeline):
    '''
    Will automatically call EXPIRE on all keys used.

    Every discovered key, plus the ``<user>:expire`` SET itself, is added
    to that SET. EXPIRE with TTL is then re-applied to every member of the
    SET, so the whole group of keys is given the same fresh TTL.
    '''
    def execute(self):
        '''
        Run the queued commands, then refresh the TTL on every key known
        for the user.

        Returns the replies to the caller's own queued commands (the
        bookkeeping SADD/SMEMBERS replies are stripped). If no user can be
        derived from the discovered keys, returns that falsy value without
        executing anything.
        '''
        user = get_user(self.keys)
        if not user:
            return user
        expire_key = user + ':expire'
        try:
            # add all the keys (including the expire SET itself) to the
            # expire SET
            self.keys.add(expire_key)
            self.pipeline.sadd(expire_key, *list(self.keys))
            # fetch all known keys from the expire SET
            self.pipeline.smembers(expire_key)
            result = self.pipeline.execute()
            # Exactly two bookkeeping replies (SADD, SMEMBERS) were
            # appended; everything before them belongs to the caller.
            ret = result[:-2]
            # update the expiration time for all known keys
            for key in result[-1]:
                self.pipeline.expire(key, TTL)
            self.pipeline.execute()
            return ret
        finally:
            # clear all known keys, even if a pipeline call raised
            self.keys = set()
class SetExpirePipeline(KeyDiscoveryPipeline):
    '''
    Sketch of an API that would need a change to Redis itself (in C):
    EXPIRE taking a third 'keys' argument (presumably "also expire every
    key named in this SET"). It won't work with standard Redis and is
    kept only for illustration.
    '''
    def execute(self):
        owner = get_user(self.keys)
        if not owner:
            return owner
        expire_set = owner + ':expire'
        # remember every discovered key in the user's expiration SET
        self.pipeline.sadd(expire_set, *list(self.keys))
        # hypothetical three-argument EXPIRE; not valid for stock Redis
        self.pipeline.expire(expire_set, TTL, 'keys')
        try:
            replies = self.pipeline.execute()
            # drop the SADD and EXPIRE replies queued above
            return replies[:-2]
        finally:
            self.keys = set()
class LuaExpirePipeline(KeyDiscoveryPipeline):
    '''
    Companion to expire_user(). Instead of issuing EXPIRE, execute() adds
    every discovered key to the ``<user>:expire`` SET and scores the user
    with the current time in the ``:expire`` ZSET.
    '''
    def execute(self):
        # same user lookup and SET bookkeeping as SetExpirePipeline
        uid = get_user(self.keys)
        if not uid:
            return uid
        self.pipeline.sadd(uid + ':expire', *list(self.keys))
        # no EXPIRE here: record "last touched" in the expiration ZSET
        self.pipeline.zadd(':expire', **{uid: time.time()})
        try:
            replies = self.pipeline.execute()
            # drop the SADD and ZADD replies queued above
            return replies[:-2]
        finally:
            self.keys = set()
def script_load(script):
    '''
    This function is borrowed from Redis in Action
    and is MIT licensed. It is provided for convenience.

    Returns ``call(conn, keys=None, args=None, force_eval=False)``, which
    runs ``script`` on ``conn``. On first use the script is loaded with
    SCRIPT LOAD and the returned digest is cached. Later calls use EVALSHA
    and fall back to EVAL if the server replies NOSCRIPT. Passing
    ``force_eval=True`` always uses EVAL.
    '''
    sha = [None]  # mutable cell shared by calls (no ``nonlocal`` in Python 2)

    def call(conn, keys=None, args=None, force_eval=False):
        # None defaults instead of shared mutable [] defaults; list() also
        # lets callers mix lists and tuples for keys/args
        keys = list(keys) if keys is not None else []
        args = list(args) if args is not None else []
        if not force_eval:
            if not sha[0]:
                sha[0] = conn.execute_command(
                    "SCRIPT", "LOAD", script, parse="LOAD")
            try:
                return conn.execute_command(
                    "EVALSHA", sha[0], len(keys), *(keys + args))
            except redis.exceptions.ResponseError as msg:
                # only a flushed/unknown script falls back to EVAL
                if not msg.args[0].startswith("NOSCRIPT"):
                    raise
        return conn.execute_command(
            "EVAL", script, len(keys), *(keys + args))
    return call
def expire_user(conn, cutoff=None):
    '''
    Expire a single user that was updated as part of calls to
    LuaExpirePipeline.

    Deletes the keys of at most one user whose score in the ``:expire``
    ZSET is between 0 and ``cutoff`` inclusive. ``cutoff`` defaults to TTL
    seconds before now. Returns 1 if a user was expired, 0 otherwise.
    '''
    if cutoff is None:
        # an explicit cutoff of 0 is honored rather than replaced
        cutoff = time.time() - TTL
    # warning: this is not Redis Cluster compatible
    return expire_user_lua(conn, [], [cutoff])
# Lua script behind expire_user(); ARGV[1] is the cutoff time.
# It picks the first member of the ':expire' ZSET scored between 0 and the
# cutoff. It then DELs every key listed in '<user>:expire' plus that SET
# itself, removes the user from the ZSET, and returns 1. If no user
# qualifies it returns 0. At most one user is expired per call.
# NOTE(review): unpack() of a very large key list may exceed Lua's stack
# limit ("too many results to unpack") -- confirm this is acceptable for
# users with many keys.
expire_user_lua = script_load('''
-- fetch the first user with a score before our cutoff
local key = redis.call('zrangebyscore', ':expire', 0, ARGV[1], 'LIMIT', 0, 1)
if #key == 0 then
    return 0
end
-- fetch the known keys to delete
local keys = redis.call('smembers', key[1] .. ':expire')
keys[#keys+1] = key[1] .. ':expire'
-- delete the keys and remove the entry from the zset
redis.call('del', unpack(keys))
redis.call('zrem', ':expire', key[1])
return 1
''')
class LuaExpirePipeline2(KeyDiscoveryPipeline):
    '''
    Companion to expire_user2(). Instead of re-scoring the user in the
    ``:expire`` ZSET on every execute(), it records the write time in the
    ``updated`` field of the ``<user>:info`` HASH. expire_user2() checks
    that field before deleting, which somewhat reduces execution time.
    '''
    def execute(self):
        # same user lookup as LuaExpirePipeline
        uid = get_user(self.keys)
        if not uid:
            return uid
        info_key = uid + ':info'
        # Deliberately called on self (not self.pipeline) so the ':info'
        # HASH is itself discovered and lands in the expire SET below.
        self.hset(info_key, 'updated', time.time())
        self.pipeline.sadd(uid + ':expire', *list(self.keys))
        try:
            replies = self.pipeline.execute()
            # drop the HSET and SADD replies queued above
            return replies[:-2]
        finally:
            self.keys = set()
def expire_user2(conn, cutoff=None):
    '''
    Expire a single user that was updated as part of calls to
    LuaExpirePipeline2.

    Examines at most one user whose score in the ``:expire`` ZSET is
    between 0 and ``cutoff`` inclusive. ``cutoff`` defaults to TTL seconds
    before now. If the user's ``<user>:info`` ``updated`` time is newer
    than the cutoff, the user is re-scored instead of deleted. Returns 1
    if a user was examined, 0 if none qualified.
    '''
    if cutoff is None:
        # an explicit cutoff of 0 is honored rather than replaced
        cutoff = time.time() - TTL
    # warning: this is also not Redis Cluster compatible
    return expire_user_lua2(conn, [], [cutoff])
# Lua script behind expire_user2(); ARGV[1] is the cutoff time.
# It takes the first user in the ':expire' ZSET scored between 0 and the
# cutoff. If that user's '<user>:info' 'updated' value is newer than the
# cutoff, the user is re-scored with it and kept. Otherwise every key in
# '<user>:expire' plus that SET is deleted and the user leaves the ZSET.
# It returns 1 when a user was examined and 0 when none qualified.
# A missing or non-numeric 'updated' field no longer aborts the script
# with a Lua compare error. The user is expired instead, because its ZSET
# score is already past the cutoff.
expire_user_lua2 = script_load('''
-- same as before
local key = redis.call('zrangebyscore', ':expire', 0, ARGV[1], 'LIMIT', 0, 1)
if #key == 0 then
    return 0
end
-- verify that the user data should expire; HGET gives false when the
-- field is missing, so only compare when we actually have a number
local last = redis.call('hget', key[1] .. ':info', 'updated')
local updated = last and tonumber(last)
if updated and updated > tonumber(ARGV[1]) then
    -- shouldn't expire, so update the ZSET
    redis.call('zadd', ':expire', last, key[1])
    return 1
end
local keys = redis.call('smembers', key[1] .. ':expire')
keys[#keys+1] = key[1] .. ':expire'
redis.call('del', unpack(keys))
redis.call('zrem', ':expire', key[1])
return 1
''')
def crappy_test():
    '''
    Quick manual smoke test against a Redis server at redis.Redis(db=15).

    WARNING: runs FLUSHDB on db 15 several times. After exercising each
    pipeline class it prints the keys with their TTLs and, for the Lua
    variants, the '12:expire' SET and the ':expire' ZSET, for inspection
    by eye.
    '''
    conn = redis.Redis(db=15)

    def show_ttls():
        # one "<key> <ttl>" line per key (same text as `print k, ttl`)
        for k in conn.keys('*'):
            print('%s %s' % (k, conn.ttl(k)))

    def show_state():
        show_ttls()
        print(': %s' % (conn.smembers('12:expire'),))
        print(conn.zrange(':expire', 0, -1, withscores=True))

    conn.flushdb()
    c1 = ExpirePipeline(conn.pipeline(True))
    c1.sadd('12:foo', 'bar')
    c1.hset('12:goo', 'goo', 'baz')
    c1.execute()
    show_ttls()

    conn.flushdb()
    c2 = LuaExpirePipeline(conn.pipeline(True))
    c2.sadd('12:foo', 'bar')
    c2.hset('12:goo', 'goo', 'baz')
    c2.execute()
    show_state()
    print(expire_user(conn, time.time() + 1))
    show_state()

    conn.flushdb()
    # this should be done during user login
    conn.zadd(':expire', '12', time.time())
    c3 = LuaExpirePipeline2(conn.pipeline(True))
    c3.sadd('12:foo', 'bar')
    c3.hset('12:goo', 'goo', 'baz')
    c3.execute()
    show_state()
    # data written by LuaExpirePipeline2 is expired with expire_user2(),
    # which checks the '12:info' 'updated' timestamp
    print(expire_user2(conn, time.time() + 1))
    show_state()
# Run crappy_test() (which FLUSHDBs Redis db 15) only when this file is
# executed as a script, not when it is imported.
if __name__ == '__main__':
    crappy_test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment