Skip to content

Instantly share code, notes, and snippets.

@toaco
Created November 3, 2017 03:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save toaco/a5e0f7fb70ce7bb4f34f1ea27eeb6df4 to your computer and use it in GitHub Desktop.
Save toaco/a5e0f7fb70ce7bb4f34f1ea27eeb6df4 to your computer and use it in GitHub Desktop.
A stack of Redis hashes whose items can expire
import datetime
import uuid

import redis
class DictStack(object):
    """A stack of dict records stored in Redis whose records may expire.

    The stack is a Redis list named ``key_name``. Its elements are the names
    of Redis hashes (``<key_name>:<suffix>``) holding the records, and new
    records are appended to the right end (the top). A record expires when its
    hash key expires in Redis. Dangling list entries are then pruned lazily
    from the left (oldest) end by a server-side Lua script.

    Pruning stops at the first live entry. A record that expires while an
    older record is still alive therefore keeps its list slot, and ``top``
    returns it as a dict holding only ``'id'``, until every older record has
    expired too.
    """

    def __init__(self, key_name, redis_option):
        """Connect to Redis and register the pruning script.

        :param key_name: name of the Redis list backing the stack; record
            hashes are stored under ``<key_name>:<suffix>``.
        :param redis_option: keyword arguments for ``redis.StrictRedis``.
        """
        self._db = redis.StrictRedis(**redis_option)
        self._key_name = key_name
        self._script = self._register_script()

    def _register_script(self):
        """Register the Lua script that pops dead entries off the list head."""
        # Check for an empty list (LINDEX -> nil/false) before calling EXISTS:
        # redis.call rejects nil arguments, so EXISTS on the head of an empty
        # list would make the script fail.
        lua_script = """
while true do
    local head = redis.call('LINDEX', KEYS[1], 0)
    if not head then
        break
    end
    if redis.call('EXISTS', head) == 1 then
        break
    end
    redis.call('LPOP', KEYS[1])
end
return redis.status_reply('OK')
"""
        return self._db.register_script(lua_script)

    def _clear_expired_records(self):
        """Atomically drop list entries whose hash has expired, oldest first."""
        self._script(keys=[self._key_name])

    def top(self, limit=None, offset=0):
        """Return up to ``limit`` records, newest first, skipping ``offset``.

        Each returned dict is the record's hash plus an ``'id'`` key. The id
        is the record's 1-based position counted from the bottom of the
        stack, so the newest record has ``id == count()``.

        :param limit: maximum number of records, or ``None`` for all.
        :param offset: number of newest records to skip.
        :raises ValueError: if ``limit`` or ``offset`` is negative.
        """
        if limit is not None and limit < 0:
            raise ValueError(
                'limit must be None or >= 0, got {!r}'.format(limit))
        if offset < 0:
            raise ValueError('offset must be >= 0, got {!r}'.format(offset))
        self._clear_expired_records()
        if limit == 0:
            return []
        # Indices counted from the right: -1 is the newest record.
        end = -1 - offset
        start = 0 if limit is None else end - limit + 1
        # Read the length and the slice in one MULTI/EXEC so that the ids
        # computed below agree with the slice even under concurrent pushes.
        with self._db.pipeline() as pipe:
            pipe.llen(self._key_name)
            pipe.lrange(self._key_name, start, end)
            total, elements = pipe.execute()
        if not elements:
            return []
        # Fetch every hash in a single round trip, newest first.
        with self._db.pipeline(transaction=False) as pipe:
            for element in reversed(elements):
                pipe.hgetall(element)
            records = pipe.execute()
        newest_id = total - offset
        for i, record in enumerate(records):
            record['id'] = newest_id - i
        return records

    def count(self):
        """Return the number of list entries after pruning expired ones."""
        self._clear_expired_records()
        return self._db.llen(self._key_name)

    def push(self, record, expire=None):
        """Push ``record`` onto the top of the stack.

        :param record: mapping stored as a Redis hash.
        :param expire: falsy for no expiry; an ``int`` number of seconds or a
            ``datetime.timedelta`` for a relative TTL; or a
            ``datetime.datetime`` absolute deadline.
        :raises ValueError: if ``expire`` has an unsupported type. Nothing is
            written in that case.
        """
        if expire and not isinstance(
                expire, (int, datetime.timedelta, datetime.datetime)):
            raise ValueError(
                'expire must be an int, timedelta or datetime, got {!r}'
                .format(expire))
        # Deriving the suffix from LLEN would reuse a live record's key once
        # expired entries are pruned (the list shrinks), so use a random one.
        hash_key = '{}:{}'.format(self._key_name, uuid.uuid4().hex)
        # The default pipeline wraps the queued commands in MULTI/EXEC, so the
        # list entry, the hash and its TTL are written atomically.
        with self._db.pipeline() as pipe:
            pipe.rpush(self._key_name, hash_key)
            # NOTE: hmset is deprecated in newer redis-py in favour of
            # hset(mapping=...); kept for compatibility with older clients.
            pipe.hmset(hash_key, record)
            if expire:
                if isinstance(expire, datetime.datetime):
                    pipe.expireat(hash_key, expire)
                else:
                    pipe.expire(hash_key, expire)
            pipe.execute()

    def clear(self):
        """Delete the stack's list and every ``<key_name>:*`` record hash."""
        # Escape glob metacharacters so the key name matches literally, and
        # require the ':' separator so that stacks whose names merely share a
        # prefix (e.g. 'foo' and 'foobar') are left alone.
        literal = ''.join('\\' + ch if ch in '\\*?[]' else ch
                          for ch in self._key_name)
        keys = [self._key_name]
        # SCAN iterates incrementally instead of blocking the server like KEYS.
        keys.extend(self._db.scan_iter(match=literal + ':*'))
        self._db.delete(*keys)
import pytest
@pytest.fixture()
def stack():
    """Provide an empty DictStack on a local Redis, cleaning up afterwards."""
    # decode_responses=True makes hashes come back as str rather than bytes
    # on Python 3, which is what the equality assertions in the tests expect.
    stack = DictStack('test_dict_stack',
                      {'host': 'localhost', 'port': 6379, 'db': 0,
                       'decode_responses': True})
    stack.clear()
    yield stack
    # Teardown: do not leave test data behind in the shared database.
    stack.clear()
def test_empty_stack(stack):
    """An empty stack yields no records for any limit/offset combination."""
    calls = (
        ((), {}),
        ((1,), {}),
        ((), {'offset': 1}),
        ((1,), {'offset': 1}),
    )
    for args, kwargs in calls:
        assert stack.top(*args, **kwargs) == []
def test_stack_with_one_record(stack):
    """A single record is visible with offset 0 and hidden by offset 1."""
    stack.push({'123': 456})
    only_record = [{'123': '456', 'id': 1}]
    for limit in (None, 1, 2):
        assert stack.top(limit) == only_record
    for limit in (None, 1, 2):
        assert stack.top(limit, offset=1) == []
def test_stack_with_two_records(stack):
    """Two records come back newest first and respect limit/offset."""
    stack.push({'123': 456})
    stack.push({'234': 567})
    older = {'123': '456', 'id': 1}
    newer = {'234': '567', 'id': 2}
    cases = [
        # (limit, offset, expected)
        (None, 0, [newer, older]),
        (1, 0, [newer]),
        (2, 0, [newer, older]),
        (None, 1, [older]),
        (1, 1, [older]),
        (2, 1, [older]),
        (0, 2, []),
        (1, 2, []),
        (2, 2, []),
    ]
    for limit, offset, expected in cases:
        assert stack.top(limit, offset=offset) == expected
def test_record_expired_in_stack(stack):
    """Records drop out of the stack once their hash key expires.

    EXPIREAT works in whole-second Unix timestamps. A deadline of
    ``now + N s`` can therefore fire up to one second early, and the sleeps
    leave at least half a second of margin on each side of every deadline.
    """
    import time
    stack.push({'123': 456}, expire=1)
    stack.push({'234': 567},
               expire=datetime.datetime.now() + datetime.timedelta(seconds=3))
    assert stack.count() == 2
    # t ~= 1.5s: the 1s TTL has passed; the second deadline is >= 2s away
    # from the push even after truncation to whole seconds.
    time.sleep(1.5)
    assert stack.count() == 1
    # t ~= 3.5s: past the (at most 3s) deadline of the second record.
    time.sleep(2)
    assert stack.count() == 0
if __name__ == '__main__':
    # Propagate pytest's exit status so that test failures make the script
    # exit non-zero instead of always reporting success.
    raise SystemExit(pytest.main([__file__, '-v']))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment