Skip to content

Instantly share code, notes, and snippets.

@mattbennett
Last active February 25, 2018 07:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mattbennett/4587094f10694c028d29911932a24130 to your computer and use it in GitHub Desktop.
Save mattbennett/4587094f10694c028d29911932a24130 to your computer and use it in GitHub Desktop.
Thread-safe serializer
import time
import uuid
from collections import defaultdict
from contextlib import contextmanager
from threading import local, Thread, Lock, RLock, current_thread

import wrapt
class MethodProxy(wrapt.ObjectProxy):
    """Transparent proxy around a callable that runs each call under a lock.

    Everything except ``__call__`` (attribute access, comparisons, ...) is
    forwarded to the wrapped callable by ``wrapt.ObjectProxy``.
    """

    def __init__(self, wrapped, lock):
        # wrapt keeps attributes named ``_self_*`` on the proxy itself
        # rather than forwarding them to ``wrapped``.
        self._self_lock = lock
        super(MethodProxy, self).__init__(wrapped)

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped callable while holding ``lock``."""
        with self._self_lock:
            result = self.__wrapped__(*args, **kwargs)
        return result
class ThreadSafeWrapper(wrapt.ObjectProxy):
    """Proxy an object so that calls to its methods are serialized.

    Every callable attribute fetched through the wrapper is returned as a
    ``MethodProxy`` that holds this wrapper's lock for the duration of the
    call, so at most one thread at a time runs a method through the same
    wrapper instance.  Non-callable attributes are returned unwrapped.
    Attribute assignment is forwarded to the wrapped object by wrapt.

    Sequences of calls that must be atomic as a group can be protected with
    ``serialize(ident)``, which holds a named lock shared by every wrapper.
    """

    # Registry of locks keyed by identifier: one entry per wrapper instance
    # (keyed by its ``_self_ident``) plus one per name passed to
    # ``serialize``.  Re-entrant so a thread that already holds a lock (a
    # nested ``serialize``, or a wrapped method calling back through the
    # wrapper) can take it again instead of deadlocking.
    # NOTE(review): entries are never removed, so creating many short-lived
    # wrappers grows this dict without bound.
    locks = defaultdict(RLock)

    # Serializes creation of entries in ``locks`` so two threads racing on
    # the first lookup of an identifier cannot end up with different locks
    # (defaultdict's insert-on-miss is only atomic as a CPython detail).
    _locks_guard = Lock()

    def __init__(self, wrapped):
        # ``_self_``-prefixed attributes live on the proxy, not the wrapped
        # object; this ident keys the per-wrapper lock in ``locks``.
        self._self_ident = uuid.uuid4()
        super(ThreadSafeWrapper, self).__init__(wrapped)

    @staticmethod
    def _lock_for(ident):
        """Return the shared lock registered under ``ident``, creating it once."""
        with ThreadSafeWrapper._locks_guard:
            return ThreadSafeWrapper.locks[ident]

    @contextmanager
    def serialize(self, ident):
        """Hold the lock named ``ident`` for the duration of a ``with`` block.

        Yields this wrapper.  Callers that must exclude each other have to
        pass the same ``ident``; the lock is shared across all wrappers, so
        ``ThreadSafeWrapper(None).serialize(ident)`` works as a plain named
        lock.
        """
        with ThreadSafeWrapper._lock_for(ident):
            yield self

    def __getattr__(self, name):
        """Fetch ``name`` from the wrapped object, lock-wrapping callables.

        Raises ``AttributeError`` if the wrapped object has no such attribute.
        """
        attr = getattr(self.__wrapped__, name)
        if not callable(attr):
            # Data attributes are returned as-is: wrapping them in a
            # MethodProxy would break ``is`` / ``type()`` checks, and the
            # lock only ever guards calls.
            return attr
        return MethodProxy(attr, ThreadSafeWrapper._lock_for(self._self_ident))
# ---
import random
import pytest
class TestThreadSerializer(object):
    """Behavioural tests for ThreadSafeWrapper.

    Every test drives two threads, one working with ``"foo"`` and one with
    ``"bar"``, against a shared object or shared state, then inspects what
    each thread recorded in ``targets``.

    NOTE(review): the ``*not_thread_safe*`` tests demonstrate a race by
    asserting that the two threads' values got mixed.  They depend on the
    random sleeps to provoke an interleaving, so they are probabilistic and
    can (rarely) fail even though nothing is broken.
    """

    # Number of operations each worker thread performs per test.
    ITERATIONS = 10

    @staticmethod
    def _run_threads(worker, values=("foo", "bar")):
        """Run ``worker(value)`` in one thread per value and wait for all.

        Exceptions raised inside a worker thread (including failed
        ``assert`` statements) do not fail a test on their own, so they are
        collected here and the first one is re-raised in the calling thread.

        Returns a dict mapping each value to the finished thread that
        processed it.
        """
        errors = []

        def guarded(value):
            try:
                worker(value)
            except Exception as exc:  # re-raised below in the test thread
                errors.append(exc)

        # A list (not a dict) keeps start order deterministic: foo, then bar.
        threads = [
            (value, Thread(target=guarded, args=(value,))) for value in values
        ]
        for _, thread in threads:
            thread.start()
        for _, thread in threads:
            thread.join()
        if errors:
            raise errors[0]
        return dict(threads)

    @staticmethod
    def _make_steps(targets):
        """Return two free functions that race on a shared one-slot state.

        ``step1(value)`` stores ``value``; ``step2(value)`` appends whatever
        is currently stored to ``targets[value]``.  Unless the pair runs
        atomically, another thread can overwrite the state in between.
        """
        state = [None]

        def step1(value):
            state[0] = value
            return state[0]

        def step2(value):
            target = targets[value]
            target.append(state[0])
            return state[0]

        return step1, step2

    @pytest.fixture
    def targets(self):
        """Per-value lists recording the state each thread observed."""
        return defaultdict(list)

    @pytest.fixture
    def unsafe(self, targets):
        """An object whose single ``action`` method races on ``self.state``."""

        class NotThreadSafe(object):
            def __init__(self):
                self.passthrough = None
                self.state = None

            def action(self, value):
                target = targets[value]
                self.state = value
                # Widen the window between write and read so another thread
                # can clobber ``state``.
                time.sleep(random.random() / 100)
                target.append(self.state)
                return self.state

        return NotThreadSafe()

    @pytest.fixture
    def unsafe_non_atomic_methods(self, targets):
        """An object whose write (step1) and read (step2) are separate calls."""

        class NotThreadSafe(object):
            def __init__(self):
                self.state = None

            def step1(self, value):
                self.state = value
                return self.state

            def step2(self, value):
                target = targets[value]
                target.append(self.state)
                return self.state

        return NotThreadSafe()

    def test_not_thread_safe(self, unsafe, targets):
        """Without a wrapper, concurrent ``action`` calls interleave."""

        def insert(value):
            for _ in range(self.ITERATIONS):
                unsafe.action(value)

        self._run_threads(insert)

        assert targets["foo"] != ["foo"] * self.ITERATIONS
        assert targets["bar"] != ["bar"] * self.ITERATIONS

    def test_methods_serialized(self, unsafe, targets):
        """Calls through the wrapper never observe the other thread's state."""
        safe = ThreadSafeWrapper(unsafe)

        def insert(value):
            for _ in range(self.ITERATIONS):
                # Propagated to the test by _run_threads if it fails.
                assert safe.action(value) == value

        self._run_threads(insert)

        assert targets["foo"] == ["foo"] * self.ITERATIONS
        assert targets["bar"] == ["bar"] * self.ITERATIONS

    def test_attr_access_passthrough(self, unsafe):
        """Attribute assignment on the wrapper reaches the wrapped object."""
        safe = ThreadSafeWrapper(unsafe)
        safe.passthrough = "value"
        assert unsafe.passthrough == "value"

    def test_not_thread_safe_nonatomic_methods(
        self, unsafe_non_atomic_methods, targets
    ):
        """Unwrapped step1/step2 pairs interleave across threads."""
        unsafe = unsafe_non_atomic_methods

        def insert(value):
            for _ in range(self.ITERATIONS):
                unsafe.step1(value)
                time.sleep(random.random() / 100)
                unsafe.step2(value)

        self._run_threads(insert)

        assert targets["foo"] != ["foo"] * self.ITERATIONS
        assert targets["bar"] != ["bar"] * self.ITERATIONS

    def test_still_not_thread_safe_nonatomic_methods_serialized(
        self, unsafe_non_atomic_methods, targets
    ):
        """Serializing each method alone does not make the pair atomic."""
        safe = ThreadSafeWrapper(unsafe_non_atomic_methods)

        def insert(value):
            for _ in range(self.ITERATIONS):
                safe.step1(value)
                time.sleep(random.random() / 100)
                safe.step2(value)

        self._run_threads(insert)

        assert targets["foo"] != ["foo"] * self.ITERATIONS
        assert targets["bar"] != ["bar"] * self.ITERATIONS

    def test_nonatomic_methods_serialized(
        self, unsafe_non_atomic_methods, targets
    ):
        """``serialize`` with a shared ident makes the method pair atomic."""
        unsafe = unsafe_non_atomic_methods

        def insert(value):
            for _ in range(self.ITERATIONS):
                with ThreadSafeWrapper(unsafe).serialize('lock-ident') as safe:
                    safe.step1(value)
                    time.sleep(random.random() / 100)
                    safe.step2(value)

        self._run_threads(insert)

        assert targets["foo"] == ["foo"] * self.ITERATIONS
        assert targets["bar"] == ["bar"] * self.ITERATIONS

    def test_nonatomic_methods_serialized_alt(
        self, unsafe_non_atomic_methods, targets
    ):
        """Wrapping the method pair in one method and serializing that works."""
        unsafe = unsafe_non_atomic_methods

        class Atomic(object):
            def atomic(self, value):
                unsafe.step1(value)
                time.sleep(random.random() / 100)
                unsafe.step2(value)

        safe = ThreadSafeWrapper(Atomic())

        def insert(value):
            for _ in range(self.ITERATIONS):
                safe.atomic(value)

        self._run_threads(insert)

        assert targets["foo"] == ["foo"] * self.ITERATIONS
        assert targets["bar"] == ["bar"] * self.ITERATIONS

    def test_not_thread_safe_nonatomic_functions(self, targets):
        """Free functions sharing state interleave without serialization."""
        step1, step2 = self._make_steps(targets)

        def insert(value):
            for _ in range(self.ITERATIONS):
                step1(value)
                time.sleep(random.random() / 100)
                step2(value)

        self._run_threads(insert)

        assert targets["foo"] != ["foo"] * self.ITERATIONS
        assert targets["bar"] != ["bar"] * self.ITERATIONS

    def test_nonatomic_functions_serialized(self, targets):
        """``serialize`` can act as a plain named lock around free functions."""
        step1, step2 = self._make_steps(targets)

        # nb. it's important that 'lock-ident' is common across all things
        # that you need to synchronize
        def insert(value):
            for _ in range(self.ITERATIONS):
                with ThreadSafeWrapper(None).serialize('lock-ident'):
                    step1(value)
                    time.sleep(random.random() / 100)
                    step2(value)

        self._run_threads(insert)

        assert targets["foo"] == ["foo"] * self.ITERATIONS
        assert targets["bar"] == ["bar"] * self.ITERATIONS

    def test_nonatomic_functions_serialized_alt(self, targets):
        """Wrapping free functions in one serialized method makes them atomic."""
        step1, step2 = self._make_steps(targets)

        class Atomic(object):
            def atomic(self, value):
                step1(value)
                time.sleep(random.random() / 100)
                step2(value)

        safe = ThreadSafeWrapper(Atomic())

        def insert(value):
            for _ in range(self.ITERATIONS):
                safe.atomic(value)

        self._run_threads(insert)

        assert targets["foo"] == ["foo"] * self.ITERATIONS
        assert targets["bar"] == ["bar"] * self.ITERATIONS

    def test_thread_locals(self, targets):
        """Serialized calls still run in, and see, the caller's own thread."""

        class UsesThreadLocals(object):
            def __init__(self):
                self.data = local()

            def action(self, value):
                target = targets[value]
                self.data.state = current_thread().ident
                time.sleep(random.random() / 100)
                target.append(self.data.state)
                return self.data.state

        safe = ThreadSafeWrapper(UsesThreadLocals())

        def insert(value):
            for _ in range(self.ITERATIONS):
                safe.action(value)

        threads = self._run_threads(insert)

        assert targets["foo"] == [threads["foo"].ident] * self.ITERATIONS
        assert targets["bar"] == [threads["bar"].ident] * self.ITERATIONS
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment