# Source: GitHub gist luke/11f14281d9d62209f2dd, created April 16, 2015 12:53.
# GitHub flagged this file as possibly containing bidirectional Unicode text,
# which may be interpreted or compiled differently than it appears; review it
# in an editor that reveals hidden Unicode characters.
import math | |
import logging | |
import time | |
import random | |
from multiprocessing import Process | |
from redis import StrictRedis, WatchError | |
class RedisSortedSetDelayedQueue():
    """RedisSortedSetDelayedQueue - uses a zset to maintain a time-ordered queue of items.

    Layout under the ``key`` prefix:

    * ``<key>/zset`` -- sorted set of items, scored by the Redis server time
      (whole seconds, rounded up) at which each item becomes due.
    * ``<key>/items/<item>`` -- companion string key per item, given an
      EXPIREAT of the same timestamp. Its ``expired`` keyspace notification
      is the cue ``start_listening`` uses to try popping.

    ``pop`` removes an item inside a WATCH/MULTI/EXEC transaction on the
    sorted set, so when several clients pop concurrently each item is handed
    to only one of them.
    """

    def __init__(self, redis=None, key=None, logger=None,
                 enable_keyspace_events=True):
        """
        :param redis: client to use; a default ``StrictRedis()`` when None.
        :param key: key prefix for this queue; a random
            ``delayed_queue/<unix time>.<n>`` prefix when None.
        :param logger: logger; ``logging.getLogger('RedisSortedSetDelayedQueue')``
            when None.
        :param enable_keyspace_events: when true and the server is not
            publishing expired keyspace events, set ``notify-keyspace-events``
            to ``KEA``. This issues CONFIG GET (and possibly CONFIG SET); pass
            False if the server rejects CONFIG commands.
        """
        if redis is None:
            redis = StrictRedis()
        self.redis = redis
        if key is None:
            key = "delayed_queue/" + str(int(time.time())) + "." + str(int(random.random() * 1000000))
        self.key = key
        if logger is None:
            logger = logging.getLogger('RedisSortedSetDelayedQueue')
        self.logger = logger
        # The check must be *called*: the bound method object is always
        # truthy, so testing it uncalled never enabled anything.
        if enable_keyspace_events and not self._is_keyspace_events_enabled():
            self.logger.debug("turning on keyspace events")
            self._enabled_keyspace_events()

    def _now(self):
        """Return the seconds component of the Redis server's TIME.

        Server time is used so every client scores items against one clock.
        """
        # TODO: optimize this so it only calls server once
        # then uses local time to work out server time
        return self.redis.time()[0]

    def clear(self):
        """Delete every key under this queue's prefix (``<key>/*``)."""
        keys = self.redis.keys(self.key + '/*')
        self.logger.debug('%r', keys)
        # DEL takes each key as its own argument and rejects an empty list,
        # so unpack the list and skip the call when there is nothing to delete.
        if keys:
            self.redis.delete(*keys)

    def add(self, item, delay):
        """Schedule ``item`` to become poppable ``delay`` seconds from now.

        The due time is server time plus ``delay``, rounded up to a whole
        second. ``item`` is concatenated into a key name, so it must be a str.
        """
        ts = int(math.ceil(self._now() + delay))
        item_key = self.key + '/items/' + item
        with self.redis.pipeline() as pipe:
            # NOTE(review): positional (score, member) matches redis-py 2.x
            # StrictRedis.zadd; redis-py >= 3.0 expects a mapping -- confirm
            # the pinned redis-py version.
            pipe.zadd(self.key + '/zset', ts, item)
            # Companion key whose expiry produces the keyspace notification
            # consumed by start_listening.
            pipe.set(item_key, ts)
            pipe.expireat(item_key, ts)
            pipe.execute()

    def pop(self):
        """Remove and return the earliest due item, or None if nothing is due.

        The sorted set is WATCHed while the due item is read and then removed
        in MULTI/EXEC. If another client modifies the set in between, EXEC
        raises WatchError and the read is retried, so a given item is
        returned to at most one caller.
        """
        zset_key = self.key + '/zset'
        with self.redis.pipeline() as pipe:
            while True:
                try:
                    # Watch the sorted set we actually read and modify.
                    pipe.watch(zset_key)
                    due = pipe.zrangebyscore(zset_key, '-inf', self._now(), start=0, num=1)
                    if not due:
                        pipe.unwatch()
                        return None
                    candidate = due[0]
                    pipe.multi()
                    # remove item from sorted set
                    pipe.zrem(zset_key, candidate)
                    pipe.execute()
                except WatchError:
                    # Lost the race: back off briefly and re-read. Nothing
                    # from the failed attempt is kept, so a stale candidate
                    # can never be returned.
                    time.sleep(0.01)  # 10ms delay
                else:
                    self.logger.debug('popped: %r', candidate)
                    return candidate

    def start_polling(self, callback, poll_interval=2):
        """Start a child process that pops due items and passes each to ``callback``.

        When nothing is due the child sleeps ``poll_interval`` seconds before
        polling again. ``start_listening`` does the same job driven by
        keyspace notifications instead of a fixed interval.

        :returns: the started ``multiprocessing.Process``; it loops until
            terminated.
        """
        def loop(self, callback):
            while True:
                item = self.pop()
                # Compare with None: a falsy item (e.g. '') has already been
                # removed from the queue and must still be delivered.
                if item is not None:
                    callback(item)
                else:
                    time.sleep(poll_interval)
        process = Process(target=loop, args=[self, callback])
        self.logger.debug('starting popping process')
        process.start()
        return process

    def _is_keyspace_events_enabled(self):
        """Return True if the server publishes keyspace events for expired keys.

        ``start_listening`` subscribes to ``__keyspace@<db>__`` channels and
        reacts to ``expired`` events. That needs the ``K`` flag plus either
        ``x`` (expired events) or ``A`` (all event classes). Note that ``e``
        means *evicted*, not expired.
        """
        config = self.redis.config_get('notify-keyspace-events').get('notify-keyspace-events', '')
        return 'K' in config and ('A' in config or 'x' in config)

    def _enabled_keyspace_events(self):
        """Overwrite the server's ``notify-keyspace-events`` setting with ``KEA``."""
        self.redis.config_set('notify-keyspace-events', 'KEA')

    def start_listening(self, callback):
        """Start a child process that pops due items whenever a companion key expires.

        The child subscribes to keyspace events for ``<key>/items/*`` on the
        current database. On every ``expired`` event it drains all due items
        into ``callback``. This avoids polling the server. Several listeners
        may run at once; ``pop`` ensures each item reaches only one of them.

        :raises Exception: if expired keyspace notifications are not enabled
            on the server.
        :returns: the started ``multiprocessing.Process``; it loops until
            terminated.
        """
        if not self._is_keyspace_events_enabled():
            raise Exception("""
Keyspace notifications need to be enabled to use channel mode.
Enabled with 'redis-cli config set notify-keyspace-events KEA'
""")
        # rather long winded way to find current db
        # dont think this will work if select is called after connect
        current_db = self.redis.connection_pool.connection_kwargs.get('db', 0)
        pubsub = self.redis.pubsub()
        keyspace_pattern = '__keyspace@%i__:%s/items/*' % (current_db, self.key)
        self.logger.info('subscribing to %s events', keyspace_pattern)

        def popping(self, callback):
            # Drain every item that is currently due.
            while True:
                item = self.pop()
                if item is None:
                    break
                callback(item)

        def loop(self, pubsub, callback):
            # Subscribe before the initial drain. Otherwise an item expiring
            # between the drain and the subscription would wait until some
            # later expiry event.
            pubsub.psubscribe(keyspace_pattern)
            # clear anything that became due before we subscribed
            popping(self, callback)
            while True:
                message = pubsub.get_message()
                if message is None:
                    time.sleep(0.01)  # 10ms delay
                    continue
                data = message['data']
                # Without decode_responses the event name arrives as bytes.
                # Subscribe confirmations carry an int and fall through to the
                # comparison below.
                if isinstance(data, bytes):
                    data = data.decode('utf-8', 'replace')
                if data != 'expired':
                    continue
                # we got an expired event, try popping
                popping(self, callback)

        process = Process(target=loop, args=[self, pubsub, callback])
        self.logger.debug('starting popping process')
        process.start()
        return process
def test():
    """Manual integration test against a live Redis server.

    Uses ``StrictRedis(decode_responses=True)`` with its default connection
    settings. Exercises add/pop directly, then runs ``start_listening`` for a
    few seconds and logs whatever the callback receives.
    """
    # decode_responses makes pop() return str on Python 3 too, so the
    # equality asserts below compare like with like.
    q = RedisSortedSetDelayedQueue(redis=StrictRedis(decode_responses=True), key='test')
    task = 'foo'
    q.clear()
    q.add(task, delay=2)
    q.add('bar', 3)

    def callback(*args, **kwargs):
        q.logger.info("got callback %s" % repr((args, kwargs)))

    # nothing is due yet
    assert q.pop() is None
    time.sleep(4)
    # both items are due now and come out in due-time order
    assert q.pop() == task
    assert q.pop() == 'bar'
    assert q.pop() is None

    q.add('1', 1)
    q.add('2', 2)
    q.add('3', 3)
    q.add('4', 4)
    listener = q.start_listening(callback)
    try:
        time.sleep(5)
    finally:
        # The listener process loops forever and is not a daemon. Stop it,
        # or the script never exits.
        listener.terminate()
        listener.join()
if __name__ == "__main__":
    # Show the queue's debug logging while the manual test runs.
    logging.basicConfig(level=logging.DEBUG)
    test()
# (End of the GitHub gist page: sign up for free or sign in to comment on this gist.)