Skip to content

Instantly share code, notes, and snippets.

@cooldaemon
Created March 16, 2014 03:18
Show Gist options
  • Save cooldaemon/9578081 to your computer and use it in GitHub Desktop.
Save cooldaemon/9578081 to your computer and use it in GitHub Desktop.
Redis Mutex を Python で実装する ref: http://qiita.com/cooldaemon/items/a192c608a8ead1577881
class MutexError(Exception):
    """Base class for every error raised by the Redis Mutex."""
class DuplicateLockError(MutexError):
    """Raised by lock() on a Mutex object that already holds the lock.

    Call unlock() on that object first, or use a separate Mutex object.
    """
class HasNotLockError(MutexError):
    """Raised by unlock() on a Mutex object that has not called lock().

    unlock() may only be called after a successful lock().
    """
class ExpiredLockError(MutexError):
    """Raised by unlock() when the lock's expire time passed after lock().

    Once expired, the lock may already have been taken over by another
    client, so the lock key is not deleted in this case. The exception's
    args are (key, elapsed_seconds_since_lock).
    """
class SetnxError(MutexError):
    """Raised when every SETNX/GET probe of the lock key failed.

    Each probe fails when SETNX is refused but the following GET finds no
    value, i.e. the key vanished between the two commands.
    """
class LockError(MutexError):
    """Raised by lock() when the lock could not be acquired within the
    configured number of retries."""
from datetime import datetime
import time
from functools import wraps
from .exception import (DuplicateLockError,
HasNotLockError,
ExpiredLockError,
SetnxError,
LockError)
class Mutex(object):
    """Distributed mutex backed by a single Redis key.

    The key's value is the holder's absolute expiry time (float seconds
    since the Unix epoch). Acquisition uses the SETNX / GETSET pattern.

    * SETNX claims the key when it is free.
    * If the key is held but its stored expiry is already in the past,
      GETSET overwrites it. The caller owns the lock only if the value it
      replaced was also expired, i.e. no other client took the lock over
      first.

    A Mutex can be used through explicit ``lock()`` / ``unlock()`` calls, as
    a context manager (``with mutex: ...``), or as a decorator
    (``@mutex``).

    :param client: Redis client providing ``setnx``, ``get``, ``getset`` and
        ``delete`` (e.g. ``redis.StrictRedis``).
    :param key: Redis key used as the lock.
    :param expire: Lock lifetime in seconds. After it elapses, other clients
        may take the lock over.
    :param retry_count: Maximum acquisition attempts per ``lock()`` call.
        ``retry_count * retry_sleep_sec`` approximates the maximum wait.
    :param retry_setnx_count: Maximum SETNX/GET probes when the key keeps
        vanishing between the two commands.
    :param retry_sleep_sec: Seconds slept between acquisition attempts.
    """

    def __init__(self, client, key,
                 expire=10,
                 retry_count=6,  # retry_count * retry_sleep_sec = max wait
                 retry_setnx_count=100,
                 retry_sleep_sec=0.25):
        # Time (epoch seconds) at which this object acquired the lock, or
        # None while it does not hold it.
        self._lock = None
        self._r = client
        self._key = key
        self._expire = expire
        self._retry_count = retry_count
        self._retry_setnx_count = retry_setnx_count
        self._retry_sleep_sec = retry_sleep_sec

    def _get_now(self):
        """Return the current time as float seconds since the Unix epoch.

        ``time.time()`` replaces ``datetime.now().strftime('%s.%f')``. The
        ``%s`` directive is a non-portable platform extension, and it
        converts naive local time back to an epoch value, which is ambiguous
        around DST transitions.
        """
        return time.time()

    def lock(self):
        """Acquire the lock, retrying up to ``retry_count`` times.

        :raises DuplicateLockError: this object already holds the lock.
        :raises LockError: the lock was not acquired within the retries.
        :raises SetnxError: the SETNX/GET probe kept failing.
        """
        if self._lock is not None:
            raise DuplicateLockError(self._key)
        self._do_lock()

    def _do_lock(self):
        """Run the SETNX / GETSET acquisition loop (see class docstring)."""
        for _ in range(self._retry_count):
            is_set, old_expire = self._setnx()
            if is_set:
                self._lock = self._get_now()
                return
            # The key is held. If the holder's expiry is still in the future,
            # sleep and try again.
            if self._need_retry(old_expire):
                continue
            # The holder's lock has expired, so overwrite it. We only win if
            # the value we replaced was also expired; otherwise another
            # client got there first, and _need_retry sleeps before the next
            # attempt.
            if not self._need_retry(self._getset()):
                self._lock = self._get_now()
                return
        raise LockError(self._key)

    def _setnx(self):
        """Try to claim the key with SETNX.

        :returns: ``(True, 0)`` when the key was claimed, otherwise
            ``(False, expiry)`` with the current holder's stored expiry.
        :raises SetnxError: the key vanished between SETNX and GET on every
            one of ``retry_setnx_count`` attempts.
        """
        for _ in range(self._retry_setnx_count):
            if self._r.setnx(self._key, self._get_now() + self._expire):
                return True, 0
            old_expire = self._r.get(self._key)
            # None means the key was deleted right after our SETNX; retry.
            if old_expire is not None:
                return False, float(old_expire)
        raise SetnxError(self._key)

    def _need_retry(self, expire):
        """Return False if ``expire`` is already in the past.

        Otherwise sleep ``retry_sleep_sec`` seconds and return True.
        """
        if expire < self._get_now():
            return False
        time.sleep(self._retry_sleep_sec)
        return True

    def _getset(self):
        """Store a fresh expiry with GETSET.

        :returns: the previous value as a float, or 0 if the key did not
            exist.
        """
        old_expire = self._r.getset(self._key, self._get_now() + self._expire)
        if old_expire is None:
            return 0
        return float(old_expire)

    def unlock(self):
        """Release the lock by deleting the key.

        :raises HasNotLockError: this object does not hold the lock.
        :raises ExpiredLockError: the lock was held for ``expire`` seconds or
            longer. Another client may own the key by now, so it is not
            deleted. This object's lock state is still cleared, so
            ``lock()`` can be called again.
        """
        if self._lock is None:
            raise HasNotLockError(self._key)
        elapsed_time = self._get_now() - self._lock
        if self._expire <= elapsed_time:
            # Leave the key alone; it may belong to whoever took it over.
            # Forgetting our stale ownership prevents this object from being
            # stuck in a state where every lock() raises DuplicateLockError.
            self._lock = None
            raise ExpiredLockError(self._key, elapsed_time)
        self._r.delete(self._key)
        self._lock = None

    def __enter__(self):
        """Acquire the lock on entering a ``with`` block."""
        self.lock()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release the lock if still held. Never suppresses exceptions."""
        if self._lock is not None:
            self.unlock()
        return False

    def __call__(self, func):
        """Use the mutex as a decorator.

        Each call of ``func`` runs while holding the lock. The wrapped
        function shares this Mutex object, so a nested call made while it
        is already running raises DuplicateLockError.
        """
        @wraps(func)
        def inner(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return inner
import unittest
import redis
import time
from multiprocessing import Process
from .mutex import Mutex
from .exception import (DuplicateLockError,
HasNotLockError,
ExpiredLockError,
LockError)
def _incr_worker(key, counter):
    """Increment the shared 'ham' counter ``counter`` times under the mutex.

    Defined at module level, not nested in the test, so multiprocessing can
    pickle it as a Process target under the 'spawn' start method.
    """
    mutex = Mutex(redis.StrictRedis(), key, retry_count=100)
    for _ in range(counter):
        mutex.lock()
        ham = mutex._r.get('ham') or 0
        mutex._r.set('ham', int(ham) + 1)
        mutex.unlock()


class TestMutex(unittest.TestCase):
    """Integration tests for Mutex.

    NOTE(review): assumes a Redis server reachable with the default
    redis.StrictRedis() connection settings.
    """

    def setUp(self):
        self.key = 'spam'
        self.r = redis.StrictRedis()
        # Start clean in case an earlier (failed) run left keys behind.
        self.r.delete(self.key)
        self.r.delete('ham')
        self.mutex = Mutex(self.r, self.key)

    def tearDown(self):
        mutex = self.mutex
        if mutex._lock:
            mutex.unlock()
        mutex._r.delete('ham')

    def test_lock(self):
        mutex = self.mutex
        mutex.lock()
        self.assertIsNotNone(mutex._r.get(mutex._key))
        with self.assertRaises(DuplicateLockError):
            mutex.lock()

    def test_unlock(self):
        self.test_lock()
        mutex = self.mutex
        mutex.unlock()
        self.assertIsNone(mutex._r.get(mutex._key))
        with self.assertRaises(HasNotLockError):
            mutex.unlock()
        # A 1-second lock exercises the expiry path without sleeping past
        # the default 10-second expire.
        short_mutex = Mutex(self.r, self.key, expire=1)
        short_mutex.lock()
        time.sleep(1.5)
        with self.assertRaises(ExpiredLockError):
            short_mutex.unlock()

    def test_expire(self):
        mutex1 = self.mutex
        mutex2 = Mutex(self.r, self.key, expire=2)
        mutex2.lock()  # holds the lock for 2 seconds
        with self.assertRaises(LockError):
            mutex1.lock()  # 6 retries * 0.25 s sleep = 1.5 s
        time.sleep(0.6)  # extra margin so mutex2's lock has expired
        mutex1.lock()
        self.assertIsNotNone(mutex1._r.get(mutex1._key))

    def test_with(self):
        mutex1 = self.mutex
        with mutex1:
            self.assertIsNotNone(mutex1._r.get(mutex1._key))
        self.assertIsNone(mutex1._r.get(mutex1._key))
        mutex2 = Mutex(self.r, self.key, expire=2)
        mutex2.lock()  # holds the lock for 2 seconds
        with self.assertRaises(LockError):
            with mutex1:  # 6 retries * 0.25 s sleep = 1.5 s
                pass
        mutex2.unlock()
        with mutex1:
            with self.assertRaises(DuplicateLockError):
                with mutex1:
                    pass

    def test_decorator(self):
        mutex = self.mutex

        @mutex
        def egg():
            self.assertIsNotNone(mutex._r.get(mutex._key))

        egg()
        self.assertIsNone(mutex._r.get(mutex._key))

    def test_multi_process(self):
        procs = 20
        counter = 100
        ps = [Process(target=_incr_worker, args=(self.key, counter))
              for _ in range(procs)]
        for p in ps:
            p.start()
        for p in ps:
            p.join()
        self.assertEqual(int(self.mutex._r.get('ham')), counter * procs)
>>> import redis
>>> from mutex import Mutex
>>> client = redis.StrictRedis()
>>> with Mutex(client, ':'.join(['EmitAccessToken', user_id])):
...     # do something ...
...     pass
>>> @Mutex(client, ':'.join(['EmitAccessToken', user_id]))
... def emit_access_token():
...     # do something ...
...     pass
>>> mutex = Mutex(client, ':'.join(['EmitAccessToken', user_id]))
>>> mutex.lock()
>>> # do something ...
>>> mutex.unlock()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment