-
-
Save mattbennett/4587094f10694c028d29911932a24130 to your computer and use it in GitHub Desktop.
Thread-safe serializer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import uuid | |
from contextlib import contextmanager | |
from collections import defaultdict | |
from threading import local, Thread, Lock, current_thread | |
import wrapt | |
class MethodProxy(wrapt.ObjectProxy):
    """Callable proxy that holds ``lock`` for the duration of every call.

    Everything except calling is delegated to the wrapped object by
    ``wrapt.ObjectProxy``.
    """

    def __init__(self, wrapped, lock):
        # Per wrapt's convention, a ``_self_``-prefixed attribute is stored on
        # the proxy itself rather than forwarded to the wrapped object.
        self._self_lock = lock
        super(MethodProxy, self).__init__(wrapped)

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped callable with the lock held; return its result."""
        with self._self_lock:
            result = self.__wrapped__(*args, **kwargs)
        return result
class ThreadSafeWrapper(wrapt.ObjectProxy):
    """Proxy that serializes calls made through it to the wrapped object.

    Every callable attribute fetched through the proxy is returned wrapped in
    a ``MethodProxy`` that holds this wrapper's private lock while it runs, so
    calls through the *same* wrapper never overlap. This only makes each
    individual call atomic. A sequence of calls needs ``serialize``, or must
    be wrapped in a single method.

    Attribute assignment is forwarded to the wrapped object by
    ``wrapt.ObjectProxy``.
    """

    # Named locks shared by all wrappers, used by ``serialize``. It is kept as
    # a defaultdict for compatibility with code that indexes it directly.
    locks = defaultdict(Lock)

    def __init__(self, wrapped):
        # The per-wrapper lock lives on the instance (``_self_`` prefix keeps
        # it on the proxy per wrapt's convention). It is therefore released
        # with the wrapper instead of accumulating forever in ``locks`` under
        # a unique key.
        self._self_lock = Lock()
        super(ThreadSafeWrapper, self).__init__(wrapped)

    @contextmanager
    def serialize(self, ident):
        """Hold the shared lock named *ident* for the body of a ``with``.

        Blocks serialized under the same *ident* run exclusively, across all
        wrappers and threads. Yields this wrapper.

        The named lock is a plain (non-reentrant) ``Lock``. It is distinct
        from the per-wrapper lock taken by method calls.
        """
        # setdefault is one dict operation. Unlike defaultdict's
        # get-then-insert ``__missing__`` path, two threads requesting a new
        # ident at once cannot end up holding different locks.
        lock = ThreadSafeWrapper.locks.setdefault(ident, Lock())
        with lock:
            yield self

    def __getattr__(self, name):
        """Fetch *name* from the wrapped object, locking it if callable.

        Only invoked when normal lookup on the proxy fails. Non-callable
        attributes are returned unchanged. Proxying them would add no locking
        (the lock is only taken on call) and would break identity checks such
        as ``safe.attr is None``.
        """
        attr = getattr(self.__wrapped__, name)
        if callable(attr):
            return MethodProxy(attr, self._self_lock)
        return attr
# --- | |
import random | |
import pytest | |
class TestThreadSerializer(object):
    """Concurrency tests for ``ThreadSafeWrapper``.

    Each test runs the same worker in a "foo" thread and a "bar" thread, then
    inspects ``targets``, which records the state observed for each value.
    Tests named ``*not_thread_safe*`` demonstrate that a race exists without
    adequate serialization. They depend on random sleeps to provoke
    interleaving, so they are probabilistic rather than guaranteed.
    """

    # Number of operations each worker thread performs.
    iterations = 10

    @staticmethod
    def _run_in_threads(worker, *values):
        """Run ``worker(value)`` in one thread per value and wait for all.

        An exception raised inside a worker is captured and re-raised here,
        in the test's own thread. Without this, a failing ``assert`` inside a
        worker is only reported by ``threading`` and does not fail the test.

        Returns the finished ``Thread`` objects in the order of ``values``.
        """
        errors = []

        def run(value):
            try:
                worker(value)
            except BaseException as exc:  # re-raised below in the test thread
                errors.append(exc)

        threads = [Thread(target=run, args=(value,)) for value in values]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        if errors:
            raise errors[0]
        return threads

    def _assert_serialized(self, targets):
        """Assert each value recorded only itself, once per iteration."""
        assert targets["foo"] == ["foo"] * self.iterations
        assert targets["bar"] == ["bar"] * self.iterations

    def _assert_raced(self, targets):
        """Assert neither value's record matches an uninterrupted run."""
        assert targets["foo"] != ["foo"] * self.iterations
        assert targets["bar"] != ["bar"] * self.iterations

    @pytest.fixture
    def targets(self):
        """Map each value to the list of states recorded while handling it."""
        return defaultdict(list)

    @pytest.fixture
    def unsafe(self, targets):
        """Object whose single ``action`` races on the shared ``state``."""

        class NotThreadSafe(object):
            def __init__(self):
                self.passthrough = None
                self.state = None

            def action(self, value):
                target = targets[value]
                self.state = value
                # Widen the window in which another thread can clobber state.
                time.sleep(random.random() / 100)
                target.append(self.state)
                return self.state

        return NotThreadSafe()

    @pytest.fixture
    def unsafe_non_atomic_methods(self, targets):
        """Object that is only correct if ``step1`` + ``step2`` run atomically."""

        class NotThreadSafe(object):
            def __init__(self):
                self.state = None

            def step1(self, value):
                self.state = value
                return self.state

            def step2(self, value):
                target = targets[value]
                target.append(self.state)
                return self.state

        return NotThreadSafe()

    @pytest.fixture
    def unsafe_non_atomic_functions(self, targets):
        """``(step1, step2)`` free functions sharing closure state.

        They are only correct if the pair runs atomically.
        """
        state = [None]

        def step1(value):
            state[0] = value
            return state[0]

        def step2(value):
            target = targets[value]
            target.append(state[0])
            return state[0]

        return step1, step2

    def test_not_thread_safe(self, unsafe, targets):
        def insert(value):
            for _ in range(self.iterations):
                unsafe.action(value)

        self._run_in_threads(insert, "foo", "bar")
        self._assert_raced(targets)

    def test_methods_serialized(self, unsafe, targets):
        safe = ThreadSafeWrapper(unsafe)

        def insert(value):
            for _ in range(self.iterations):
                # Propagated to the test by _run_in_threads if it fails.
                assert safe.action(value) == value

        self._run_in_threads(insert, "foo", "bar")
        self._assert_serialized(targets)

    def test_attr_access_passthrough(self, unsafe):
        safe = ThreadSafeWrapper(unsafe)
        safe.passthrough = "value"
        assert unsafe.passthrough == "value"

    def test_not_thread_safe_nonatomic_methods(
        self, unsafe_non_atomic_methods, targets
    ):
        unsafe = unsafe_non_atomic_methods

        def insert(value):
            for _ in range(self.iterations):
                unsafe.step1(value)
                time.sleep(random.random() / 100)
                unsafe.step2(value)

        self._run_in_threads(insert, "foo", "bar")
        self._assert_raced(targets)

    def test_still_not_thread_safe_nonatomic_methods_serialized(
        self, unsafe_non_atomic_methods, targets
    ):
        # Each call is serialized, but the other thread can still run
        # between step1 and step2.
        safe = ThreadSafeWrapper(unsafe_non_atomic_methods)

        def insert(value):
            for _ in range(self.iterations):
                safe.step1(value)
                time.sleep(random.random() / 100)
                safe.step2(value)

        self._run_in_threads(insert, "foo", "bar")
        self._assert_raced(targets)

    def test_nonatomic_methods_serialized(
        self, unsafe_non_atomic_methods, targets
    ):
        unsafe = unsafe_non_atomic_methods

        def insert(value):
            for _ in range(self.iterations):
                # The named lock, not the wrapper instance, is what is shared.
                with ThreadSafeWrapper(unsafe).serialize('lock-ident') as safe:
                    safe.step1(value)
                    time.sleep(random.random() / 100)
                    safe.step2(value)

        self._run_in_threads(insert, "foo", "bar")
        self._assert_serialized(targets)

    def test_nonatomic_methods_serialized_alt(
        self, unsafe_non_atomic_methods, targets
    ):
        unsafe = unsafe_non_atomic_methods

        # wrap non-atomic methods in another method and serialize that
        class Atomic(object):
            def atomic(self, value):
                unsafe.step1(value)
                time.sleep(random.random() / 100)
                unsafe.step2(value)

        safe = ThreadSafeWrapper(Atomic())

        def insert(value):
            for _ in range(self.iterations):
                safe.atomic(value)

        self._run_in_threads(insert, "foo", "bar")
        self._assert_serialized(targets)

    def test_not_thread_safe_nonatomic_functions(
        self, unsafe_non_atomic_functions, targets
    ):
        step1, step2 = unsafe_non_atomic_functions

        def insert(value):
            for _ in range(self.iterations):
                step1(value)
                time.sleep(random.random() / 100)
                step2(value)

        self._run_in_threads(insert, "foo", "bar")
        self._assert_raced(targets)

    def test_nonatomic_functions_serialized(
        self, unsafe_non_atomic_functions, targets
    ):
        step1, step2 = unsafe_non_atomic_functions

        def insert(value):
            for _ in range(self.iterations):
                # nb. it's important that 'lock-ident' is common across all
                # things that you need to synchronize
                with ThreadSafeWrapper(None).serialize('lock-ident'):
                    step1(value)
                    time.sleep(random.random() / 100)
                    step2(value)

        self._run_in_threads(insert, "foo", "bar")
        self._assert_serialized(targets)

    def test_nonatomic_functions_serialized_alt(
        self, unsafe_non_atomic_functions, targets
    ):
        step1, step2 = unsafe_non_atomic_functions

        # wrap non-atomic functions in another function and serialize that
        class Atomic(object):
            def atomic(self, value):
                step1(value)
                time.sleep(random.random() / 100)
                step2(value)

        safe = ThreadSafeWrapper(Atomic())

        def insert(value):
            for _ in range(self.iterations):
                safe.atomic(value)

        self._run_in_threads(insert, "foo", "bar")
        self._assert_serialized(targets)

    def test_thread_locals(self, targets):
        # Calls through the wrapper execute in the calling thread, so
        # thread-local state stays per-thread.
        class UsesThreadLocals(object):
            def __init__(self):
                self.data = local()

            def action(self, value):
                target = targets[value]
                self.data.state = current_thread().ident
                time.sleep(random.random() / 100)
                target.append(self.data.state)
                return self.data.state

        safe = ThreadSafeWrapper(UsesThreadLocals())

        def insert(value):
            for _ in range(self.iterations):
                safe.action(value)

        foo_thread, bar_thread = self._run_in_threads(insert, "foo", "bar")

        assert targets["foo"] == [foo_thread.ident] * self.iterations
        assert targets["bar"] == [bar_thread.ident] * self.iterations
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment