Skip to content

Instantly share code, notes, and snippets.

@junkafarian
Created April 10, 2012 11:52
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save junkafarian/2350836 to your computer and use it in GitHub Desktop.
Save junkafarian/2350836 to your computer and use it in GitHub Desktop.
Redis session management using a dictionary-like API for session objects
import redis
from zope.interface import Interface, implementer, implements
try: #pragma NO COVERAGE
    import simplejson as json
except ImportError: #pragma NO COVERAGE
    import json
class RedisSession(dict):
    """ Dictionary-like view of a session stored in redis as a JSON object.

    The local dict is only a cache of the redis value.  Every item read
    (``session[key]`` / ``get``) and every write (``session[key] = v`` /
    ``del session[key]``) first reloads the whole session from redis, and
    writes are immediately flushed back as JSON.  When ``ttl`` is set, each
    of those operations also resets the key's expiry to ``ttl`` seconds, so
    the session stays alive while it is being used.

    Only the methods overridden here synchronise with redis; other dict
    methods (``in``, ``keys``, ``update``, ``pop``, ...) act on the local
    cache alone and persist nothing.
    """

    def __init__(self, r, sessionkey, ttl=None):
        """
        :param r: redis client; only its ``get``, ``set`` and ``expire``
            methods are used.
        :param sessionkey: redis key holding this session's JSON payload.
        :param ttl: expiry in seconds re-applied on every access, or ``None``
            to leave the key's expiry untouched.
        """
        self._r = r
        self.sessionkey = sessionkey
        self.ttl = ttl
        self._populate()

    def _touch(self):
        """ Reset the session key's expiry to ``ttl`` seconds, if a ttl is set. """
        if self.ttl is not None:
            # keep the session open with all activity
            self._r.expire(self.sessionkey, int(self.ttl))

    def __getitem__(self, key):
        self._populate()
        self._touch()
        return super(RedisSession, self).__getitem__(key)

    def __setitem__(self, key, val):
        self._populate()
        super(RedisSession, self).__setitem__(key, val)
        self._flush()
        self._touch()

    def __delitem__(self, key):
        self._populate()
        # raises KeyError (before anything is flushed) if key is absent
        super(RedisSession, self).__delitem__(key)
        self._flush()
        self._touch()

    def get(self, key, default=None):
        self._populate()
        self._touch()
        return super(RedisSession, self).get(key, default)

    def _flush(self):
        """ Serialise the cached contents to JSON and store them under ``sessionkey``. """
        self._r.set(self.sessionkey, json.dumps(self))

    def _populate(self):
        """ Replace the cached contents with the session currently in redis.

        A missing key, undecodable data, or JSON that is not an object all
        yield an empty session.  The cache is cleared first so keys removed
        in redis (deleted by another client, or gone because the session
        expired) do not linger locally and get written back by the next
        ``_flush``.
        """
        raw = self._r.get(self.sessionkey)
        data = {}
        if raw is not None:
            try:
                decoded = json.loads(raw)
            except ValueError:
                # The stored value wasn't able to be decoded
                decoded = None
            # update() would raise on lists/numbers/strings; treat them as empty
            if isinstance(decoded, dict):
                data = decoded
        self.clear()
        self.update(data)
class ISessionManager(Interface):
    """ Provides access to session objects by their session key. """

    def get(sessionkey):
        """ Return the dictionary-like session object stored under
        ``sessionkey``.
        """
@implementer(ISessionManager)
class RedisSessionManager(object):
    """ Hands out RedisSession objects that share one redis client and ttl.

    Declared with the ``@implementer`` class decorator: the older
    ``implements()`` class advice raises TypeError on Python 3.

    :param host: redis server host name.
    :param port: redis server port.
    :param ttl: expiry (seconds) given to every session; ``None`` disables
        the expiry refresh.
    :param constructor: client factory, called as
        ``constructor(host=..., port=..., db=...)``; overridable so tests
        can inject a fake.
    :param db: redis database number (0 by default, matching the previous
        hard-coded value).
    """

    def __init__(self, host='localhost', port=6379, ttl=3600,
                 constructor=redis.StrictRedis, db=0):
        self._r = constructor(host=host, port=port, db=db)
        self.ttl = ttl

    def get(self, sessionkey):
        """ Return a RedisSession for ``sessionkey`` using this manager's
        client and ttl.
        """
        return RedisSession(self._r, sessionkey, ttl=self.ttl)
import unittest
from repoze.bfg import testing
class DummyRedisSessionManager(object):
    """ In-memory stand-in for the redis client used by RedisSession.

    It keeps only the last value passed to ``set`` (keys are ignored), hands
    that value back from ``get``, and records each ``expire`` call as a
    ``(key, ttl)`` tuple in ``expires``.
    """

    def __init__(self):
        self.expires = []
        self.res = None

    def get(self, key):
        stored = self.res
        return stored

    def set(self, key, value):
        self.res = value

    def expire(self, key, ttl):
        call = (key, ttl)
        self.expires.append(call)
class TestRedisSession(unittest.TestCase):
    """ Unit tests for RedisSession, driven through DummyRedisSessionManager. """

    def setUp(self):
        testing.cleanUp()

    def tearDown(self):
        testing.cleanUp()

    def _makeOne(self, redis, sessionkey, ttl=None):
        from redis_sessions import RedisSession
        return RedisSession(redis, sessionkey, ttl)

    def test_flush(self):
        fake = DummyRedisSessionManager()
        session = self._makeOne(fake, '1')
        self.assertEqual(fake.res, None)

        session._flush()
        self.assertEqual(dict(session), {})
        self.assertEqual(fake.res, '{}')

        session['key'] = 'value'
        session._flush()
        self.assertEqual(dict(session), {'key': 'value'})
        self.assertEqual(fake.res, '{"key": "value"}')

    def test_populate(self):
        fake = DummyRedisSessionManager()
        session = self._makeOne(fake, '1')
        self.assertEqual(fake.res, None)

        session._populate()
        self.assertEqual(dict(session), {})

        fake.res = '{"key": "value"}'
        session._populate()
        self.assertEqual(dict(session), {'key': 'value'})

    def test_getitem(self):
        fake = DummyRedisSessionManager()
        session = self._makeOne(fake, '1')
        self.assertRaises(KeyError, session.__getitem__, 'key')

        fake.res = '{"key": "value"}'
        self.assertEqual(session['key'], 'value')

    def test_get(self):
        fake = DummyRedisSessionManager()
        session = self._makeOne(fake, '1')
        self.assertEqual(session.get('key'), None)

        fake.res = '{"key": "value"}'
        self.assertEqual(session.get('key'), 'value')

    def test_setitem(self):
        fake = DummyRedisSessionManager()
        session = self._makeOne(fake, '1')
        self.assertEqual(session.get('key'), None)
        self.assertEqual(fake.res, None)

        session['key'] = 'value'
        self.assertEqual(session.get('key'), 'value')
        self.assertEqual(fake.res, '{"key": "value"}')

    def test_delitem(self):
        fake = DummyRedisSessionManager()
        session = self._makeOne(fake, '1')
        fake.res = '{"key": "value"}'
        self.assertEqual(session.get('key'), 'value')

        del session['key']
        self.assertEqual(session.get('key'), None)
        self.assertEqual(fake.res, '{}')

    def test_malformed_data(self):
        fake = DummyRedisSessionManager()
        session = self._makeOne(fake, '1', ttl=3600)
        fake.res = 'foobar'
        self.assertEqual(session.get('key'), None)
        self.assertEqual(dict(session), {})

    def test_expires(self):
        fake = DummyRedisSessionManager()
        session = self._makeOne(fake, '1', ttl=3600)

        def assert_single_refresh():
            # each operation must refresh the expiry exactly once
            self.assertEqual(len(session._r.expires), 1)
            self.assertEqual(session._r.expires.pop(), ('1', 3600))

        # __setitem__
        session['key'] = 'value'
        assert_single_refresh()
        # __getitem__
        self.assertEqual(session['key'], 'value')
        assert_single_refresh()
        # get
        self.assertEqual(session.get('key'), 'value')
        assert_single_refresh()
        # __delitem__
        del session['key']
        assert_single_refresh()
class TestRedisSessionManager(unittest.TestCase):
    """ Unit tests for RedisSessionManager, using testing.DummyModel as the
    client constructor so no real redis server is contacted.
    """

    def setUp(self):
        testing.cleanUp()

    def tearDown(self):
        testing.cleanUp()

    def _makeOne(self, host='localhost', port=6379, ttl=3600,
                 constructor=testing.DummyModel):
        from redis_sessions import RedisSessionManager
        return RedisSessionManager(host, port, ttl, constructor)

    def test_defaults(self):
        manager = self._makeOne(host='localhost', port=6379, ttl=3600)
        client = manager._r
        self.assertTrue(isinstance(client, testing.DummyModel))
        self.assertEqual(client.host, 'localhost')
        self.assertEqual(client.port, 6379)

    def test_get(self):
        from redis_sessions import RedisSession
        session = self._makeOne().get('session1')
        self.assertTrue(isinstance(session, RedisSession))
@mendelgusmao
Copy link

Maybe you could take advantage of Redis hashes and avoid the overhead of JSON serialization.

(I haven't tested this.)

    def _flush(self):
        """ Store the session in redis as a hash, replacing its previous contents.

        HMSET/HSET only add or overwrite fields, so a key deleted from the
        session would otherwise survive in redis and come back on the next
        ``_populate``. The hash is therefore deleted and rewritten in one
        pipeline (a MULTI/EXEC transaction by default in redis-py). HSET
        rejects an empty mapping, so an empty session is simply left deleted.
        Hash values are flat redis strings: nested or non-string values do
        not round-trip the way they do with JSON.
        """
        pipe = self._r.pipeline()
        pipe.delete(self.sessionkey)
        if self:
            # hset(mapping=...) replaces the deprecated hmset()
            pipe.hset(self.sessionkey, mapping=dict(self))
        pipe.execute()

    def _populate(self):
        """ Replace the cached contents with the hash currently stored in redis.

        The cache is cleared first so fields removed in redis (or an expired
        session) do not linger locally. NOTE(review): redis-py returns bytes
        keys/values here unless the client uses ``decode_responses=True``.
        """
        session_data = self._r.hgetall(self.sessionkey)
        self.clear()
        self.update(session_data)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment