Last active
June 14, 2023 08:45
-
-
Save Joshuaalbert/36dabe4f7f9648763520d19e57fcce22 to your computer and use it in GitHub Desktop.
Fair AsyncIO RLock implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from collections import deque | |
class FairAsyncRLock:
    """
    A fair reentrant lock for async programming.

    Fair means that it respects the order of acquisition: tasks that have to
    wait are granted the lock strictly in FIFO order. When the owner fully
    releases the lock, ownership is handed *directly* to the longest-waiting
    task before that task is woken. A third task calling ``acquire()`` in the
    gap therefore cannot slip in ahead of the queue, and cannot end up
    co-owning the lock.

    The owning task may re-acquire the lock; it must release it the same
    number of times. The class uses only asyncio primitives and performs no
    thread synchronisation.
    """

    def __init__(self):
        # Task currently holding the lock, or None when the lock is free.
        self._owner: asyncio.Task | None = None
        # Re-entrancy depth of the current owner; 0 when the lock is free.
        self._count = 0
        # FIFO of (task, event) pairs for tasks blocked in acquire().
        # Invariant: the queue is only non-empty while _owner is not None.
        self._queue: deque[tuple[asyncio.Task, asyncio.Event]] = deque()

    def is_owner(self, task=None):
        """Return True if *task* (default: the current task) holds the lock."""
        if task is None:
            task = asyncio.current_task()
        # Identity check; a missing task (None) never counts as the owner.
        return task is not None and self._owner is task

    async def acquire(self):
        """
        Acquire the lock, waiting in FIFO order if another task holds it.

        If the current task already owns the lock, its re-entrancy count is
        incremented and this returns immediately.

        Raises:
            asyncio.CancelledError: if the task is cancelled while waiting.
                If the lock had already been handed to this task, it is
                passed on to the next waiter (or freed) before re-raising,
                so ownership is never leaked.
        """
        me = asyncio.current_task()
        if self.is_owner(task=me):
            self._count += 1
            return
        if self._owner is None:
            # Free lock: by the invariant above nobody is queued, so take it now.
            self._owner = me
            self._count = 1
            return
        # Lock is held by someone else: queue up and wait to be handed the lock.
        event = asyncio.Event()
        waiter = (me, event)
        self._queue.append(waiter)
        try:
            await event.wait()
        except asyncio.CancelledError:
            if event.is_set():
                # release() already popped us and made us the owner, but we
                # were cancelled before resuming: pass the lock on.
                self._hand_off()
            else:
                # Still queued: just withdraw from the queue.
                self._queue.remove(waiter)
            raise
        # Nothing else to do: release() set _owner/_count for us before
        # waking us up.

    async def release(self):
        """
        Release one level of ownership held by the current task.

        When the re-entrancy count drops to zero, the lock is handed to the
        next queued task, or becomes free if nobody is waiting. This
        coroutine never suspends.

        Raises:
            RuntimeError: if the lock is not held, or is held by another task.
        """
        me = asyncio.current_task()
        if self._owner is None:
            raise RuntimeError("Cannot release un-acquired lock.")
        if not self.is_owner(task=me):
            raise RuntimeError("Cannot release foreign lock.")
        self._count -= 1
        if self._count == 0:
            self._hand_off()

    def _hand_off(self):
        """Give the lock to the next waiter, or mark it free if none is queued."""
        if self._queue:
            task, event = self._queue.popleft()
            # Assign ownership *before* the waiter resumes, so no other task
            # can grab the lock in between.
            self._owner = task
            self._count = 1
            event.set()
        else:
            self._owner = None
            self._count = 0

    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.release()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from time import monotonic_ns, perf_counter | |
import pytest | |
from fair_async_rlock import FairAsyncRLock | |
@pytest.mark.asyncio
async def test_reentrant():
    """The owning task can re-acquire the lock without blocking; depth is tracked per level."""
    lock = FairAsyncRLock()
    async with lock:
        async with lock:  # This should not block
            assert lock.is_owner()
            assert lock._count == 2
        # Leaving the inner block drops only one level of ownership.
        assert lock.is_owner()
        assert lock._count == 1
    assert lock._owner is None
    assert lock._count == 0
@pytest.mark.asyncio
async def test_exclusion():
    """Another task cannot enter while we hold the lock, and does enter once we release it."""
    lock = FairAsyncRLock()
    got_in = False

    async def inner():
        nonlocal got_in
        async with lock:
            got_in = True

    # Acquire the lock, then run the inner task. It shouldn't be able
    # to acquire the lock.
    async with lock:
        inner_task = asyncio.create_task(inner())
        await asyncio.sleep(0)  # Give the inner task a chance to run
        assert not got_in
    # Await the task so it is not left pending at teardown, and check that it
    # did get the lock after we released it.
    await inner_task
    assert got_in
@pytest.mark.asyncio
async def test_fairness():
    """Tasks that queue for a held lock are granted it in the order they queued."""
    lock = FairAsyncRLock()
    order = []

    async def worker(n):
        async with lock:
            order.append(n)

    # Hold the lock while the workers start, so every one of them has to queue.
    # (Without this each worker would acquire and release uncontended.)
    async with lock:
        tasks = [asyncio.create_task(worker(i)) for i in range(5)]
        await asyncio.sleep(0)  # Let each worker run up to its blocking acquire()
        assert order == []
    await asyncio.gather(*tasks)
    assert order == list(range(5))  # The tasks should have run in order
@pytest.mark.asyncio
async def test_unowned_release():
    """Releasing a lock nobody holds fails, both from this task and from another task."""
    expected = "Cannot release un-acquired lock."
    lock = FairAsyncRLock()

    async def release_from_other_task():
        with pytest.raises(RuntimeError, match=expected):
            await lock.release()

    with pytest.raises(RuntimeError, match=expected):
        await lock.release()
    await asyncio.gather(release_from_other_task())
@pytest.mark.asyncio
async def test_performance():
    """
    Time 1000 tasks that each take the lock once, and check they ran in creation order.

    Useful for gauging the locking overhead and whether the lock suits
    high-concurrency use.
    """
    num_tasks = 1000
    lock = FairAsyncRLock()
    completed = []

    async def record(idx):
        async with lock:
            completed.append(idx)

    pending = [asyncio.create_task(record(idx)) for idx in range(num_tasks)]
    t0 = monotonic_ns()
    await asyncio.gather(*pending)
    elapsed_ns = monotonic_ns() - t0
    print(f"Time to complete {num_tasks} tasks: {elapsed_ns} ns")
    assert completed == list(range(num_tasks))
@pytest.mark.asyncio
async def test_stress():
    """
    Many tasks repeatedly acquire and release the lock under real contention.

    Every entry into the critical section checks that no other task is inside,
    and the test finishes by checking that every entry happened and that the
    lock ends free.
    """
    lock = FairAsyncRLock()
    num_tasks = 100
    iterations = 1000
    inside = 0   # tasks currently in the critical section
    entries = 0  # total successful entries

    async def worker():
        nonlocal inside, entries
        for _ in range(iterations):
            async with lock:
                inside += 1
                entries += 1
                assert inside == 1, "two tasks held the lock at the same time"
                # Yield while holding the lock so the other workers queue up;
                # without this each worker finishes in one step and nothing contends.
                await asyncio.sleep(0)
                inside -= 1

    tasks = [asyncio.create_task(worker()) for _ in range(num_tasks)]
    await asyncio.gather(*tasks)
    assert entries == num_tasks * iterations
    assert lock._owner is None
    assert lock._count == 0
    assert len(lock._queue) == 0
@pytest.mark.asyncio
async def test_hard():
    """
    The lock stays consistent while tasks are cancelled both from inside and from outside.

    Every worker with n % 10 == 0 raises CancelledError while holding the lock.
    Every worker with n % 10 == 5 is cancelled externally right after start-up,
    while it is queued behind worker 0. After all tasks finish, the lock must be
    free with an empty queue.
    """
    lock = FairAsyncRLock()
    num_tasks = 100
    # Each iteration now yields while holding the lock (a full hand-off round),
    # so fewer iterations are needed to exercise many hand-offs.
    iterations = 100

    async def worker(n):
        for _ in range(iterations):
            async with lock:
                # Yield while holding the lock so the other workers really queue up.
                await asyncio.sleep(0)
                if n % 10 == 0:  # Cancel every 10th task from inside the critical section
                    raise asyncio.CancelledError

    tasks = [asyncio.create_task(worker(i)) for i in range(num_tasks)]
    await asyncio.sleep(0)  # Let worker 0 take the lock and the rest join the queue
    for task in tasks[5::10]:
        task.cancel()
    # return_exceptions=True so we wait for *every* task before inspecting the lock.
    results = await asyncio.gather(*tasks, return_exceptions=True)
    expected_cancelled = {i for i in range(num_tasks) if i % 10 in (0, 5)}
    actually_cancelled = {
        i for i, r in enumerate(results) if isinstance(r, asyncio.CancelledError)
    }
    assert actually_cancelled == expected_cancelled
    assert all(results[i] is None for i in range(num_tasks) if i not in expected_cancelled)
    assert lock._count == 0  # The lock should be released after all tasks are done
    assert lock._owner is None
    assert len(lock._queue) == 0
@pytest.mark.asyncio
async def test_lock_status_checks():
    """is_owner() reports ownership correctly for the current task and for other tasks."""
    lock = FairAsyncRLock()
    # The lock should not have an owner initially
    assert not lock.is_owner()
    # After acquiring the lock, it should be owned by the current task
    async with lock:
        assert lock.is_owner()
        assert lock.is_owner(task=asyncio.current_task())
        # A different task must not be reported as the owner.
        other = asyncio.create_task(asyncio.sleep(0))
        assert not lock.is_owner(task=other)
        await other
    # Ownership is dropped once the lock is released.
    assert not lock.is_owner()
@pytest.mark.asyncio
async def test_nested_lock_acquisition():
    """
    While a worker task holds one lock, this task can take a second lock without disturbing the first.

    The worker signals once it owns `first` and then yields while still holding
    it. We check that `first` is owned by the worker both before and while we
    hold the unrelated `second` lock, then wait for the worker to finish.
    """
    first = FairAsyncRLock()
    second = FairAsyncRLock()
    holding_first = asyncio.Event()

    async def hold_first():
        async with first:
            holding_first.set()  # Signal that `first` has been acquired
            await asyncio.sleep(0)  # Yield control while still owning `first`

    holder = asyncio.create_task(hold_first())
    await holding_first.wait()
    assert first.is_owner(task=holder)
    async with second:
        assert first.is_owner(task=holder)
    await holder
@pytest.mark.asyncio
async def test_starvation():
    """
    A task that queued first is served first, even when many later tasks queue behind it.

    The lock is held while all workers start, so they genuinely contend. The
    low-priority worker queues before the high-priority ones and must not be
    starved or overtaken by them.
    """
    lock = FairAsyncRLock()
    order = []

    async def worker(n):
        async with lock:
            order.append(n)

    async with lock:
        # Start a low-priority task; it queues behind us.
        low_priority_task = asyncio.create_task(worker(0))
        await asyncio.sleep(0)
        # Start several high-priority tasks; they queue behind it.
        high_priority_tasks = [asyncio.create_task(worker(i)) for i in range(1, 10)]
        await asyncio.sleep(0)
        assert order == []
    # Wait for all tasks to complete
    await low_priority_task
    await asyncio.gather(*high_priority_tasks)
    # The low-priority task got the lock, and got it first.
    assert order == list(range(10))
@pytest.mark.asyncio
async def test_concurrent_acquisition():
    """Concurrent acquire()/release() callers all get the lock, strictly one at a time and in FIFO order."""
    lock = FairAsyncRLock()
    result = []
    holders = 0  # tasks currently between acquire() and release()
    max_holders = 0

    async def worker(n):
        nonlocal holders, max_holders
        await lock.acquire()  # This will block until the lock can be acquired
        holders += 1
        max_holders = max(max_holders, holders)
        try:
            result.append(n)
            await asyncio.sleep(0)  # Yield control while holding the lock
        finally:
            holders -= 1
            await lock.release()

    # Start several tasks concurrently
    tasks = [asyncio.create_task(worker(i)) for i in range(5)]
    await asyncio.gather(*tasks)
    # All tasks should have been able to acquire the lock, but only one at a time
    assert len(result) == 5
    assert max_holders == 1
    assert result == list(range(5))
@pytest.mark.asyncio
async def test_performance_comparison():
    """
    Print wall-clock times for 1000 contended acquisitions of FairAsyncRLock versus asyncio.Lock.

    Informational only; no timing assertion is made. (The original author
    observed the two to be roughly on par.)
    """
    num_tasks = 1000

    async def worker(lock):
        async with lock:
            await asyncio.sleep(0)  # Simulate some work while holding the lock

    async def timed_run(lock):
        # Tasks are created before the clock starts; only their execution is timed.
        jobs = [asyncio.create_task(worker(lock)) for _ in range(num_tasks)]
        began = perf_counter()
        await asyncio.gather(*jobs)
        return perf_counter() - began

    duration_fair = await timed_run(FairAsyncRLock())
    duration_asyncio = await timed_run(asyncio.Lock())
    print(f"Time to complete {num_tasks} tasks with FairAsyncRLock: {duration_fair} seconds")
    print(f"Time to complete {num_tasks} tasks with asyncio.Lock: {duration_asyncio} seconds")
@pytest.mark.asyncio
async def test_lock_released_on_exception():
    """An exception escaping `async with` still releases the lock."""
    lock = FairAsyncRLock()
    # Use a specific exception type and message so that an unrelated error
    # (e.g. a RuntimeError from release()) cannot satisfy the check.
    with pytest.raises(ValueError, match="Test"):
        async with lock:
            raise ValueError("Test")
    assert lock._count == 0
    assert lock._owner is None
@pytest.mark.asyncio
async def test_release_foreign_lock():
    """A task that does not own the lock cannot release it, and the attempt leaves the lock intact."""
    lock = FairAsyncRLock()

    async def task1():
        async with lock:
            await asyncio.sleep(0.1)  # Hold the lock while task2 tries to release it

    async def task2():
        # task1 was scheduled first, so it already owns the lock here.
        assert lock.is_owner(task=task1_handle)
        with pytest.raises(RuntimeError, match="Cannot release foreign lock."):
            await lock.release()
        # The failed release must not have changed ownership.
        assert lock.is_owner(task=task1_handle)

    # Create the tasks and schedule them
    task1_handle = asyncio.create_task(task1())
    task2_handle = asyncio.create_task(task2())
    # Wait for both tasks to complete
    await asyncio.gather(task1_handle, task2_handle)
    assert lock._owner is None
    assert lock._count == 0
@pytest.mark.asyncio
async def test_lock_acquired_released_normally():
    """A single `async with` records the current task as owner with depth 1, and clears both on exit."""
    lock = FairAsyncRLock()
    this_task = asyncio.current_task()
    async with lock:
        assert lock._owner is not None
        assert lock._owner == this_task
        assert lock._count == 1
    assert lock._count == 0
    assert lock._owner is None
@pytest.mark.asyncio
async def test_async_release():
    """
    A second task gets the lock once the first has released it via the async release().

    Afterwards the lock must be unowned with an empty wait queue.
    """
    lock = FairAsyncRLock()

    async def hold_briefly():
        async with lock:
            await asyncio.sleep(0.1)

    async def enter_and_leave():
        async with lock:
            pass

    first = asyncio.create_task(hold_briefly())
    second = asyncio.create_task(enter_and_leave())
    await asyncio.gather(first, second)
    assert lock._owner is None
    assert len(lock._queue) == 0
@pytest.mark.asyncio
async def test_acquire_exception_handling():
    """
    An exception raised while holding the lock does not leave the lock in an inconsistent state.

    The failing task releases the lock in a `finally`. Another task must still
    be able to acquire and release it, and the lock must end free with an
    empty queue.
    """
    lock = FairAsyncRLock()

    async def failing_task():
        # Acquire outside the try: if acquire() itself failed we must not release.
        await lock.acquire()
        try:
            raise RuntimeError("Simulated exception while holding the lock")
        finally:
            await lock.release()

    async def succeeding_task():
        await lock.acquire()
        await lock.release()

    task1 = asyncio.create_task(failing_task())
    task2 = asyncio.create_task(succeeding_task())
    # return_exceptions=True guarantees both tasks have finished before we
    # inspect the lock.
    results = await asyncio.gather(task1, task2, return_exceptions=True)
    assert isinstance(results[0], RuntimeError)
    assert str(results[0]) == "Simulated exception while holding the lock"
    assert results[1] is None
    # Ensure that lock is not owned and queue is empty after exception
    assert lock._owner is None
    assert len(lock._queue) == 0
@pytest.mark.asyncio
async def test_task_cancellation():
    """A task cancelled while waiting for the lock is removed from the wait queue immediately."""
    lock = FairAsyncRLock()

    async def holder():
        await lock.acquire()
        await asyncio.sleep(0.1)  # Hold the lock long enough for the waiter to queue
        await lock.release()

    async def waiter():
        await lock.acquire()

    holder_task = asyncio.create_task(holder())
    waiter_task = asyncio.create_task(waiter())
    await asyncio.sleep(0)  # Yield so both tasks start; the waiter blocks in acquire()
    assert len(lock._queue) == 1
    waiter_task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await waiter_task
    # The cancelled waiter left the queue while the holder still owns the lock.
    assert len(lock._queue) == 0
    assert lock.is_owner(task=holder_task)
    await holder_task  # Let the holder release the lock
    # Ensure that lock is not owned and queue is empty after cancellation
    assert lock._owner is None
    assert len(lock._queue) == 0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment