Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
'''
This is a module that defines some helper classes and functions for
expiring groups of related keys at the same time.
Written July 1-2, 2013 by Josiah Carlson
Released into the public domain
'''
import time
import redis
class KeyDiscoveryPipeline(object):
    '''
    This class is used as a wrapper around a Redis pipeline in
    Python to discover keys that are being used. This will work
    for commands where the key to be modified is the first
    positional argument to the command.

    It won't work properly with commands that modify multiple keys
    at the same time (such as DEL called with several keys), and it
    does not see keys that are passed only as keyword arguments.

    Used like::

        >>> conn = redis.Redis()
        >>> pipe = KeyDiscoveryPipeline(conn.pipeline())
        >>> pipe.sadd('foo', 'bar')
        >>> pipe.execute() # will fail, use one of the subclasses

    Subclasses implement execute() to act on the keys collected in
    ``self.keys``.
    '''
    def __init__(self, pipeline):
        # the wrapped pipeline that actually queues the commands
        self.pipeline = pipeline
        # first positional argument of every command forwarded so far
        self.keys = set()

    def __getattr__(self, attribute):
        '''
        This is a bit of Python magic to discover the keys that
        are being modified.

        Python only calls this for attributes not found on the wrapper
        itself, so pipeline methods (sadd, hset, ...) end up here. The
        returned function records its first positional argument as a
        key, then forwards the call to the wrapped pipeline and returns
        that call's result.

        :raises AttributeError: for dunder names, which are Python
            protocol lookups (copy, pickle, ...), not Redis commands
        '''
        if attribute.startswith('__') and attribute.endswith('__'):
            # Without this, hasattr(obj, '__setstate__') & co. would
            # return a fake "command" and break copy/pickle.
            raise AttributeError(attribute)

        def call(*args, **kwargs):
            if args:
                self.keys.add(args[0])
            return getattr(self.pipeline, attribute)(*args, **kwargs)
        return call

    def execute(self):
        '''
        Must be overridden by subclasses.

        :raises NotImplementedError: always, on this base class
        '''
        raise NotImplementedError
# How long, in seconds, tracked keys are kept around: one day.
TTL = 24 * 60 * 60
def get_user(keys):
    '''
    Return the user id that prefixes the keys in ``keys``.

    Keys are assumed to look like ``<user id>:<context>[:<anything else>]``
    and to all belong to the same user, so only the first key produced
    by iterating ``keys`` is examined (for a set that order is
    arbitrary). The result is everything before the first ':' of that
    key (the whole key if it has no ':'), or an empty list when
    ``keys`` is empty.
    '''
    remaining = iter(keys)
    try:
        first_key = next(remaining)
    except StopIteration:
        return []
    user_id, _separator, _rest = first_key.partition(':')
    return user_id
class ExpirePipeline(KeyDiscoveryPipeline):
    '''
    Will automatically call EXPIRE on all keys used.

    On execute(), every discovered key (plus the ``<user>:expire`` SET
    itself) is added to the user's ``<user>:expire`` SET. A second
    round trip then sets the expiration of every member of that SET
    to ``TTL`` seconds.
    '''
    def execute(self):
        '''
        Run the queued commands and refresh the expiration of every key
        known for the user.

        :returns: the results of the caller's queued commands, without
                  the results of the bookkeeping commands added here
        '''
        try:
            user = get_user(self.keys)
            if not user:
                # No user could be derived from the keys, so there is
                # nothing to track. Still run whatever was queued so it
                # isn't left behind in the pipeline for the next call.
                return self.pipeline.execute()
            expire_key = user + ':expire'
            # add all the keys (the expire SET included) to the expire SET
            self.keys.add(expire_key)
            self.pipeline.sadd(expire_key, *list(self.keys))
            # fetch all known keys from the expire SET
            self.pipeline.smembers(expire_key)
            result = self.pipeline.execute()
            # The last two results belong to the SADD and SMEMBERS queued
            # above; everything before them belongs to the caller.
            ret = result[:-2]
            # update the expiration time for all known keys
            for key in result[-1]:
                self.pipeline.expire(key, TTL)
            self.pipeline.execute()
            return ret
        finally:
            # forget the discovered keys even if a round trip failed
            self.keys = set()
class SetExpirePipeline(KeyDiscoveryPipeline):
    '''
    Supposed to be used by a Redis-level C change, but won't
    work with standard Redis.

    execute() records every discovered key in the user's
    ``<user>:expire`` SET and then issues EXPIRE on that SET with an
    extra ``'keys'`` argument. That argument presumably stands in for a
    server-side option that would also expire the SET's members.
    '''
    def execute(self):
        '''
        Run the queued commands plus the (hypothetical) expiration
        bookkeeping.

        :returns: the results of the caller's queued commands, without
                  the SADD and EXPIRE results added here
        '''
        try:
            user = get_user(self.keys)
            if not user:
                # nothing to track, but don't leave queued commands behind
                return self.pipeline.execute()
            expire_key = user + ':expire'
            # add all of the keys to the expiration SET
            self.pipeline.sadd(expire_key, *list(self.keys))
            # this won't work, EXPIRE doesn't take a 3rd argument like
            # this - only for show.
            # NOTE(review): newer redis-py releases treat expire()'s third
            # positional parameter as a flag (nx), so 'keys' may be sent
            # as an EXPIRE option instead - confirm against the client.
            self.pipeline.expire(expire_key, TTL, 'keys')
            # drop the SADD and EXPIRE results
            return self.pipeline.execute()[:-2]
        finally:
            self.keys = set()
class LuaExpirePipeline(KeyDiscoveryPipeline):
    '''
    This is supposed to be used with the expire_user() function to
    expire user data.

    Instead of setting TTLs, execute() records the discovered keys in
    the user's ``<user>:expire`` SET. It also stores the current time as
    the user's score in the global ``:expire`` ZSET, which the
    expire_user() Lua script scans for users with old scores.
    '''
    def execute(self):
        '''
        Run the queued commands and record the user's update time.

        :returns: the results of the caller's queued commands, without
                  the SADD and ZADD results added here
        '''
        try:
            # This first part is the same as SetExpirePipeline
            user = get_user(self.keys)
            if not user:
                # nothing to track, but don't leave queued commands behind
                return self.pipeline.execute()
            self.pipeline.sadd(user + ':expire', *list(self.keys))
            # Instead of calling EXPIRE, we'll just add it to the
            # expiration ZSET. The raw "ZADD key score member" form is used
            # because the zadd() helper's signature differs between
            # redis-py versions (keyword pairs in 2.x, a mapping in 3.x+).
            self.pipeline.execute_command('ZADD', ':expire', time.time(), user)
            # drop the SADD and ZADD results
            return self.pipeline.execute()[:-2]
        finally:
            self.keys = set()
def script_load(script):
    '''
    This function is borrowed from Redis in Action
    and is MIT licensed. It is provided for convenience.

    Return a callable ``call(conn, keys=None, args=None, force_eval=False)``
    that runs the Lua ``script`` on ``conn``. On first use the script is
    sent with SCRIPT LOAD and the reply (its SHA1) is cached. Later calls
    use EVALSHA. They fall back to EVAL when the server replies NOSCRIPT
    (e.g. after a restart or SCRIPT FLUSH), or when ``force_eval`` is
    true.
    '''
    sha = [None]  # mutable cell so call() can cache the SHA1

    def call(conn, keys=None, args=None, force_eval=False):
        # Fresh lists: avoids shared mutable defaults, and lets callers
        # pass any sequences (e.g. a tuple of keys with a list of args).
        keys = list(keys or ())
        args = list(args or ())
        if not force_eval:
            if not sha[0]:
                sha[0] = conn.execute_command(
                    "SCRIPT", "LOAD", script, parse="LOAD")
            try:
                return conn.execute_command(
                    "EVALSHA", sha[0], len(keys), *(keys+args))
            except redis.exceptions.ResponseError as err:
                # redis-py raises NoScriptError (a ResponseError subclass)
                # for NOSCRIPT replies and strips the "NOSCRIPT" prefix from
                # the message. Check both the class and the message so older
                # clients keep working too.
                no_script = getattr(redis.exceptions, 'NoScriptError', ())
                if not (isinstance(err, no_script)
                        or str(err).startswith("NOSCRIPT")):
                    raise
        return conn.execute_command(
            "EVAL", script, len(keys), *(keys+args))
    return call
def expire_user(conn, cutoff=None):
    '''
    Expire a single user that was updated as part of calls to
    LuaExpirePipeline.

    :param conn: Redis connection the expiration script runs on
    :param cutoff: timestamp; the first user whose score in the
                   ``:expire`` ZSET lies between 0 and this value
                   (inclusive) is expired. Defaults to ``TTL`` seconds
                   before now; an explicit 0 is honored.
    :returns: 1 if a user's keys were deleted, 0 if no user qualified
    '''
    # warning: this is not Redis Cluster compatible
    if cutoff is None:
        cutoff = time.time() - TTL
    return expire_user_lua(conn, [], [cutoff])
# Lua script behind expire_user(); ARGV[1] is the cutoff timestamp.
# It picks one user whose ':expire' ZSET score is between 0 and the
# cutoff. It deletes every key recorded in that user's '<user>:expire'
# SET (and the SET itself), then removes the user from the ZSET.
# Returns 1 when a user was expired, 0 when none qualified.
expire_user_lua = script_load('''
-- fetch the first user with a score before our cutoff
local key = redis.call('zrangebyscore', ':expire', 0, ARGV[1], 'LIMIT', 0, 1)
if #key == 0 then
    return 0
end
-- fetch the known keys to delete
local keys = redis.call('smembers', key[1] .. ':expire')
keys[#keys+1] = key[1] .. ':expire'
-- delete the keys in batches: unpack() of a very large table fails
-- with "too many results to unpack"
for i = 1, #keys, 1000 do
    redis.call('del', unpack(keys, i, math.min(i + 999, #keys)))
end
-- remove the entry from the zset
redis.call('zrem', ':expire', key[1])
return 1
''')
class LuaExpirePipeline2(KeyDiscoveryPipeline):
    '''
    This is supposed to be used with the expire_user2() function to
    expire user data, and is modified to somewhat reduce execution
    time.

    Rather than writing to the ``:expire`` ZSET on every execute(), it
    stores the update time in the ``updated`` field of the user's
    ``<user>:info`` HASH. The user's ZSET entry is expected to be
    created elsewhere (e.g. during user login).
    '''
    def execute(self):
        '''
        Run the queued commands and record the user's update time.

        :returns: the results of the caller's queued commands, without
                  the HSET and SADD results added here
        '''
        try:
            # This first part is the same as LuaExpirePipeline
            user = get_user(self.keys)
            if not user:
                # nothing to track, but don't leave queued commands behind
                return self.pipeline.execute()
            # Instead of adding this to the ZSET, update the user metadata.
            # Going through the wrapper (self.hset, not self.pipeline.hset)
            # records '<user>:info' so it lands in the expire SET too.
            self.hset(user + ':info', 'updated', time.time())
            self.pipeline.sadd(user + ':expire', *list(self.keys))
            # drop the HSET and SADD results
            return self.pipeline.execute()[:-2]
        finally:
            self.keys = set()
def expire_user2(conn, cutoff=None):
    '''
    Expire a single user that was updated as part of calls to
    LuaExpirePipeline2.

    :param conn: Redis connection the expiration script runs on
    :param cutoff: timestamp; the first user whose score in the
                   ``:expire`` ZSET lies between 0 and this value
                   (inclusive) is examined. Defaults to ``TTL`` seconds
                   before now; an explicit 0 is honored.
    :returns: 0 if no user qualified. 1 if a user was either expired, or
              found to be recently updated and had its ZSET score
              refreshed.
    '''
    # warning: this is also not Redis Cluster compatible
    if cutoff is None:
        cutoff = time.time() - TTL
    return expire_user_lua2(conn, [], [cutoff])
# Lua script behind expire_user2(); ARGV[1] is the cutoff timestamp.
# Like expire_user_lua, but first checks the user's '<user>:info'
# 'updated' field. If it is newer than the cutoff, the user is
# rescheduled in the ZSET instead of being deleted.
expire_user_lua2 = script_load('''
-- same as before
local key = redis.call('zrangebyscore', ':expire', 0, ARGV[1], 'LIMIT', 0, 1)
if #key == 0 then
    return 0
end
-- verify that the user data should expire; HGET returns false when
-- no update was ever recorded, in which case the ZSET score alone
-- decides
local last = redis.call('hget', key[1] .. ':info', 'updated')
local updated = last and tonumber(last)
if updated and updated > tonumber(ARGV[1]) then
    -- shouldn't expire, so update the ZSET
    redis.call('zadd', ':expire', last, key[1])
    return 1
end
local keys = redis.call('smembers', key[1] .. ':expire')
keys[#keys+1] = key[1] .. ':expire'
-- delete in batches to stay under Lua's unpack() limit
for i = 1, #keys, 1000 do
    redis.call('del', unpack(keys, i, math.min(i + 999, #keys)))
end
redis.call('zrem', ':expire', key[1])
return 1
''')
def crappy_test():
    '''
    Manual smoke test for ExpirePipeline, LuaExpirePipeline and
    LuaExpirePipeline2 against a local Redis. It prints keys, TTLs and
    the expiration bookkeeping before and after expiring user '12'.

    WARNING: flushes Redis database 15.
    '''
    conn = redis.Redis(db=15)

    def show(*parts):
        # a single-argument print() behaves the same on Python 2 and 3
        print(' '.join(str(part) for part in parts))

    def dump_ttls():
        # every key in the db with its remaining TTL
        for k in conn.keys('*'):
            show(k, conn.ttl(k))

    def dump_state():
        # TTLs plus the user's expire SET and the global expire ZSET
        dump_ttls()
        show(':', conn.smembers('12:expire'))
        show(conn.zrange(':expire', 0, -1, withscores=True))

    conn.flushdb()
    c1 = ExpirePipeline(conn.pipeline(True))
    c1.sadd('12:foo', 'bar')
    c1.hset('12:goo', 'goo', 'baz')
    c1.execute()
    dump_ttls()
    show()

    conn.flushdb()
    c2 = LuaExpirePipeline(conn.pipeline(True))
    c2.sadd('12:foo', 'bar')
    c2.hset('12:goo', 'goo', 'baz')
    c2.execute()
    dump_state()
    show(expire_user(conn, time.time() + 1))
    dump_state()
    show()

    conn.flushdb()
    # this should be done during user login; the raw command sidesteps
    # the zadd() signature differences between redis-py versions
    conn.execute_command('ZADD', ':expire', time.time(), '12')
    c3 = LuaExpirePipeline2(conn.pipeline(True))
    c3.sadd('12:foo', 'bar')
    c3.hset('12:goo', 'goo', 'baz')
    c3.execute()
    dump_state()
    # data written through LuaExpirePipeline2 is expired by expire_user2()
    show(expire_user2(conn, time.time() + 1))
    dump_state()
if '__main__' == __name__:
    # run crappy_test() only when executed as a script, never on import
    crappy_test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment