Skip to content

Instantly share code, notes, and snippets.

@sunilmallya
Created March 5, 2015 23:20
Show Gist options
  • Save sunilmallya/6cb40dc4b9761a800057 to your computer and use it in GitHub Desktop.
Save sunilmallya/6cb40dc4b9761a800057 to your computer and use it in GitHub Desktop.
Wrapper class that supports both synchronous and asynchronous redis client
'''
A useful DB connection class with both sync and async redis clients.
It uses a thread pool to make the redis library asynchronous so that it
works seamlessly with tornado. There is also a built-in retry wrapper that
retries when calls to the redis server fail.
Tornado 4.0 is required; for the rest of the requirements, check the imports.
'''
import concurrent.futures
import logging
import multiprocessing
import os
import redis as blockingRedis
import time
import tornado.ioloop
import tornado.gen
import tornado.httpclient
import tornado.web
from tornado.httpclient import HTTPResponse, HTTPRequest
from tornado.options import define, options
import threading
# testing
from mock import patch, MagicMock
import tornado.testing
import unittest
# Module-level logger, named after this module; used by the retry wrappers
# to report failed redis calls.
_log = logging.getLogger(__name__)
# Tornado options consumed elsewhere in this module:
# redisDB / dbPort are read by DBConnection.__init__ as the redis host/port.
define("redisDB", default="127.0.0.1", help="Main DB")
define("dbPort", default=6379, help="Main DB")
# Retry policy read by RedisRetryWrapper and RedisAsyncWrapper:
# maxRedisRetries is the attempt limit, baseRedisRetryWait is the base of
# the exponential backoff delay (the wrappers treat it as seconds).
define("maxRedisRetries", default=3, help="")
define("baseRedisRetryWait", default=5, help="")
class DBStateError(ValueError):
    '''ValueError subclass reserved for database-state problems.

    NOTE(review): nothing in the visible part of this module raises it.
    '''
class DBConnection(object):
    '''Connection to the database.

    There is one connection for each object type, cached in a class-level
    registry keyed by class name. To get a connection, please use the
    get() classmethod and don't create instances directly.

        db_conn = DBConnection.get(self) OR
        db_conn = DBConnection.get(cls)
    '''

    # Note: one lock for the whole registry, so creating the connection for
    # any class name currently blocks creation for every other name.
    __singleton_lock = threading.Lock()
    # Registry: class name (None when no object type was given) -> instance.
    _singleton_instance = {}
    # Equivalent of Python 2's basestring that also works on Python 3:
    # (str, unicode) on 2.x and (str, str) on 3.x.
    _STRING_TYPES = (str, type(u''))

    def __init__(self, class_name):
        '''Init function.

        DO NOT CALL THIS DIRECTLY. Use the get() function instead.

        class_name - Registry key this connection is created for. It is not
                     used to pick a server here; host and port always come
                     from the redisDB and dbPort options.
        '''
        host = options.redisDB
        port = options.dbPort
        # NOTE: You can add conditionals here to connect to different
        # redis servers if its shared on class names
        self.conn, self.blocking_conn = RedisClient.get_client(host, port)

    def fetch_keys_from_db(self, key_prefix, callback=None):
        '''Fetch keys that match a prefix.

        key_prefix - Passed unchanged as the pattern argument of the
                     client's keys() call.
        callback   - If given, the lookup goes through the async client,
                     the callback is handed to it and None is returned.
                     Otherwise the blocking client is used and its result
                     is returned.
        '''
        if callback:
            self.conn.keys(key_prefix, callback)
        else:
            return self.blocking_conn.keys(key_prefix)

    def clear_db(self):
        '''Erases all the keys in the database.

        This should really only be used in test scenarios.
        '''
        self.blocking_conn.flushdb()

    @classmethod
    def update_instance(cls, cname):
        '''Rebuild the cached connection for cname after a db config update.

        Does nothing when no connection is cached under cname.
        '''
        # Cheap unlocked check first, then re-check under the lock so the
        # entry is only replaced if it still exists.
        if cname in cls._singleton_instance:
            with cls.__singleton_lock:
                if cname in cls._singleton_instance:
                    cls._singleton_instance[cname] = cls(cname)

    @classmethod
    def get(cls, otype=None):
        '''Gets a DB connection for a given object type.

        otype - The object type to get the connection for.
                Can be a class object, an instance object or the class name
                as a string. A falsy value (the default None included) maps
                to the shared entry stored under the key None.

        Returns the cached DBConnection, creating it on first use.
        '''
        cname = None
        if otype:
            if isinstance(otype, cls._STRING_TYPES):
                cname = otype
            elif isinstance(otype, type):
                # A class was passed (the classmethod case). isinstance
                # also covers classes built by a custom metaclass, which
                # would otherwise be keyed under the metaclass's name.
                cname = otype.__name__
            else:
                cname = otype.__class__.__name__

        # Keep a local reference so the value returned cannot disappear if
        # the registry is cleared between the insert and the return.
        instance = cls._singleton_instance.get(cname)
        if instance is None:
            with cls.__singleton_lock:
                instance = cls._singleton_instance.get(cname)
                if instance is None:
                    instance = cls(cname)
                    cls._singleton_instance[cname] = instance
        return instance

    @classmethod
    def clear_singleton_instance(cls):
        '''
        Clear the singleton instance for each of the classes

        NOTE: To be only used by the test code
        '''
        cls._singleton_instance = {}
class RedisRetryWrapper(object):
    '''Wraps a redis client so that it retries with exponential backoff.

    You use this class exactly the same way that you would use the
    StrictRedis class: attribute lookups are forwarded to an internal
    StrictRedis instance, and callables come back wrapped in a retry loop.

    Calls on this object are blocking.
    '''

    def __init__(self, *args, **kwargs):
        '''Build the wrapped client; all arguments go to StrictRedis.

        The retry limit and the base backoff wait are read from the
        maxRedisRetries and baseRedisRetryWait options.
        '''
        self.client = blockingRedis.StrictRedis(*args, **kwargs)
        self.max_tries = options.maxRedisRetries
        self.base_wait = options.baseRedisRetryWait

    def _get_wrapped_retry_func(self, func):
        '''Returns a blocking retry function wrapped around the given func.

        The wrapper calls func until it succeeds or max_tries attempts have
        failed, sleeping (2 ** failures) * base_wait seconds in between.
        The last exception is re-raised once the limit is reached.

        NOTE: every Exception type triggers a retry, not only connection
        errors.
        '''
        def RetryWrapper(*args, **kwargs):
            cur_try = 0
            while True:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    _log.error('Error talking to redis on attempt %i: %s',
                               cur_try, e)
                    cur_try += 1
                    # >= rather than == so that a limit of zero or less
                    # fails fast instead of retrying forever.
                    if cur_try >= self.max_tries:
                        raise
                    # Do an exponential backoff
                    delay = (1 << cur_try) * self.base_wait  # in seconds
                    time.sleep(delay)
        return RetryWrapper

    def __getattr__(self, attr):
        '''Allows us to wrap all of the redis-py functions.

        Callable attributes of the client are returned wrapped with the
        retry logic; other attributes are returned as they are.
        Raises AttributeError when the client has no such attribute.
        '''
        # __getattr__ only runs when normal lookup failed, so reaching here
        # with 'client' means it was never set (e.g. __init__ raised).
        # Bail out instead of recursing through self.client forever.
        if attr == 'client':
            raise AttributeError(attr)
        try:
            target = getattr(self.client, attr)
        except AttributeError:
            raise AttributeError(attr)
        if callable(target):
            return self._get_wrapped_retry_func(target)
        return target
class RedisAsyncWrapper(object):
    '''
    Replacement class for tornado-redis

    This is a wrapper class which does redis operation
    in a background thread and on completion transfers control
    back to the tornado ioloop. If you wrap this around gen/Task,
    you can write db operations as if they were synchronous.

    usage:
    value = yield tornado.gen.Task(RedisAsyncWrapper().get, key)

    #TODO: see if we can completely wrap redis-py calls, helpful if
    you can get the callback attribute as well when call is made
    '''

    # One executor per process id, created lazily by _get_thread_pool().
    _thread_pools = {}
    _pool_lock = multiprocessing.RLock()
    # Worker count passed to each ThreadPoolExecutor.
    _async_pool_size = 10

    def __init__(self, host='127.0.0.1', port=6379):
        '''Create the underlying StrictRedis client (10 s socket timeout).

        The retry limit and the base backoff wait are read from the
        maxRedisRetries and baseRedisRetryWait options.
        '''
        self.client = blockingRedis.StrictRedis(host, port, socket_timeout=10)
        self.max_tries = options.maxRedisRetries
        self.base_wait = options.baseRedisRetryWait

    @classmethod
    def _get_thread_pool(cls):
        '''Get the thread pool for this process, creating it on first use.'''
        pid = os.getpid()
        with cls._pool_lock:
            try:
                return cls._thread_pools[pid]
            except KeyError:
                pool = concurrent.futures.ThreadPoolExecutor(
                    cls._async_pool_size)
                cls._thread_pools[pid] = pool
                return pool

    def _get_wrapped_async_func(self, func):
        '''Returns an asynchronous function wrapped around the given func.

        The asynchronous call takes a callback, either as the `callback`
        keyword or as a callable last positional argument; it is invoked
        with func's result. AttributeError is raised when none is given.

        Failed calls are resubmitted with exponential backoff until
        max_tries attempts have failed. At that point the exception is
        raised from inside the IOLoop callback and `callback` is not
        called.
        '''
        def AsyncWrapper(*args, **kwargs):
            # Find the callback argument
            if 'callback' in kwargs:
                callback = kwargs.pop('callback')
            elif args and callable(args[-1]):
                callback = args[-1]
                args = args[:-1]
            else:
                raise AttributeError('A callback is necessary')

            io_loop = tornado.ioloop.IOLoop.current()

            def _submit():
                # Run the blocking redis call on the per-process pool.
                return RedisAsyncWrapper._get_thread_pool().submit(
                    func, *args, **kwargs)

            def _cb(future, cur_try=0):
                error = future.exception()
                if error is None:
                    callback(future.result())
                    return
                _log.error('Error talking to redis on attempt %i: %s',
                           cur_try, error)
                cur_try += 1
                # >= rather than == so that a limit of zero or less
                # stops instead of rescheduling forever.
                if cur_try >= self.max_tries:
                    raise error
                delay = (1 << cur_try) * self.base_wait  # in seconds
                # Deadlines must be on the IOLoop's own clock, which is
                # not necessarily time.time().
                io_loop.add_timeout(
                    io_loop.time() + delay,
                    lambda: io_loop.add_future(
                        _submit(), lambda fut: _cb(fut, cur_try)))

            io_loop.add_future(_submit(), _cb)
        return AsyncWrapper

    def __getattr__(self, attr):
        '''Allows us to wrap all of the redis-py functions.

        Callable attributes of the client are returned wrapped as
        asynchronous functions; other attributes are returned as they are.
        Raises AttributeError when the client has no such attribute.
        '''
        # __getattr__ only runs when normal lookup failed, so reaching here
        # with 'client' means it was never set (e.g. __init__ raised).
        # Bail out instead of recursing through self.client forever.
        if attr == 'client':
            raise AttributeError(attr)
        try:
            target = getattr(self.client, attr)
        except AttributeError:
            raise AttributeError(attr)
        if callable(target):
            return self._get_wrapped_async_func(target)
        return target
class RedisClient(object):
    '''
    Static class for REDIS configuration
    '''

    # static variables
    # Defaults used by get_client() when host/port are not supplied.
    host = '127.0.0.1'
    port = 6379
    # Most recent clients handed out by get_client().
    client = None
    blocking_client = None

    def __init__(self, host='127.0.0.1', port=6379):
        '''Create an async and a blocking client for host:port.'''
        self.client = RedisAsyncWrapper(host, port)
        # Same socket timeout as get_client() and RedisAsyncWrapper use,
        # so a blocking call cannot wait indefinitely on the socket.
        self.blocking_client = RedisRetryWrapper(
            host, port, socket_timeout=10)

    @staticmethod
    def get_client(host=None, port=None):
        '''
        return connection objects (blocking and non blocking)

        host, port - Server to connect to; the class-level defaults are
                     used for whichever is None.

        Returns a (async_client, blocking_client) tuple. New clients are
        created on every call.
        '''
        if host is None:
            host = RedisClient.host
        if port is None:
            port = RedisClient.port

        async_client = RedisAsyncWrapper(host, port)
        blocking_client = RedisRetryWrapper(host, port, socket_timeout=10)

        # Populate the declared static attributes; `c` / `bc` are the
        # names this method has always set and are kept as aliases.
        RedisClient.client = RedisClient.c = async_client
        RedisClient.blocking_client = RedisClient.bc = blocking_client
        return async_client, blocking_client
class DBObject(object):
    '''
    Abstract base for objects that are stored in the database under a key.

    Subclasses supply the stored value through get_value(). You can make
    this more abstract by adding your own serializers and deserializers to
    put python objects in the DB and rebuild them on retrieval.

    NOTE: Only the get & save methods are implemented here.
    '''

    def __init__(self, key):
        '''Remember the database key, coerced with str().'''
        self.key = str(key)

    def save(self, callback=None):
        '''
        Store get_value() under self.key.

        Without a callback the blocking client is used and its result is
        returned. With a callback the async client is used, the callback
        is handed to it and None is returned.
        '''
        connection = DBConnection.get(self)
        payload = self.get_value()
        if not callback:
            return connection.blocking_conn.set(self.key, payload)
        connection.conn.set(self.key, payload, callback)

    @classmethod
    def get(cls, key, callback=None):
        '''
        Fetch the raw value stored under key.

        Without a callback the blocking client's result is returned.
        With a callback the async client is used, the callback is handed
        to it and None is returned.
        '''
        connection = DBConnection.get(cls)
        if not callback:
            return connection.blocking_conn.get(key)
        connection.conn.get(key, callback)

    def get_value(self):
        '''Return the value to store; subclasses must override this.'''
        raise NotImplementedError()
################# TESTS ##########################
# Lil test class
class Foo(DBObject):
    '''
    Minimal DBObject subclass used by the tests: it stores the value it
    was constructed with.
    '''

    def __init__(self, key, val):
        '''Set up the key through DBObject and keep val as the value.'''
        super(Foo, self).__init__(key)
        self.value = val

    def get_value(self):
        '''Return the value given at construction time.'''
        return self.value
################ Sanity tester ######################
def test_async():
    '''
    Async sanity test code

    Requests the key 'test' through the async client, prints the value
    when it arrives, and starts the IOLoop.
    '''
    def callback(val):
        # Same output as the Python 2 statement `print "async ", val`,
        # written so that it also parses and runs on Python 3.
        print("async  %s" % (val,))

    Foo.get('test', callback)
    tornado.ioloop.IOLoop.current().start()
    # It'll block here, you'll need to ^C or SIGTERM to exit
################## Sample unit Test ####################
class TestHandler(tornado.web.RequestHandler):
    '''Empty request handler mounted at '/' by the test application.'''

    def initialize(self):
        '''No per-request setup is needed.'''
class TestAsyncDBConnection(tornado.testing.AsyncHTTPTestCase,
                            unittest.TestCase):
    '''
    Sample unit test for the async DB access path.

    more info: http://tornado.readthedocs.org/en/latest/testing.html
    '''

    def setUp(self):
        '''Defer to the tornado test case set-up.'''
        super(TestAsyncDBConnection, self).setUp()

    def tearDown(self):
        '''Drop the cached DB connections, then run the base tear-down.'''
        DBConnection.clear_singleton_instance()
        super(TestAsyncDBConnection, self).tearDown()

    def get_app(self):
        '''Build the application under test with a single '/' route.'''
        routes = [(r'/', TestHandler)]
        return tornado.web.Application(routes)

    @tornado.testing.gen_test
    def test_async_get(self):
        '''Save a value with the blocking client, read it back async.'''
        Foo('unittestfoo', 'bar').save()
        fetched = yield tornado.gen.Task(Foo.get, "unittestfoo")
        self.assertEqual(fetched, 'bar')
    # .... write more tests as you please
if __name__ == "__main__":
    # Assumes redis is running on localhost on 6379
    # Blocking round trip: write a value, then read it straight back.
    Foo('test', 'testval').save()
    assert Foo.get('test') == 'testval'
    # Feel free to write more test code and play with this gist as necessary
    # The main motive here is to provide a framework to work with
    # sanity test (async)
    #test_async()
    # Unit test
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment