Last active
April 19, 2024 04:48
-
-
Save thegamecracks/831603c4adac9cab94b4274846b59e0f to your computer and use it in GitHub Desktop.
A thread-safe event compatible with asyncio
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Requires Python >= 3.11
import asyncio | |
import contextlib | |
import threading | |
import unittest | |
class Event(threading.Event):  # TODO: any better name?
    """A thread-safe event compatible with asyncio.

    Blocking callers use the inherited :class:`threading.Event` API.
    Coroutines use :meth:`wait_async`, which parks a future created on the
    caller's running event loop.  :meth:`set` resolves each parked future on
    the loop that created it, whether or not the calling thread has a
    running loop of its own.
    """

    # Both attributes are created by threading.Event.__init__ and are
    # implementation details of CPython; __init__ checks that they exist.
    _cond: threading.Condition
    _flag: bool
    # Futures of coroutines currently parked in wait_async().
    _waiters: list[asyncio.Future[bool]]

    def __init__(self) -> None:
        super().__init__()
        self._waiters = []
        # FIXME: might be safer to vendor threading.Event instead
        assert hasattr(self, "_cond"), "_cond missing from threading.Event"
        assert hasattr(self, "_flag"), "_flag missing from threading.Event"

    def set(self) -> None:
        """Set the flag and wake all blocking and asynchronous waiters."""
        with self._cond:
            self._flag = True
            self._cond.notify_all()
            self._notify_waiters()

    async def wait_async(self) -> bool:
        """Asynchronously wait until the event is set.

        If a timeout is desired, this coroutine can be wrapped
        with :func:`asyncio.wait_for()`.

        Returns:
            ``True`` once the event has been set.
        """
        with self._cond:
            if self._flag:
                return True
            fut = asyncio.get_running_loop().create_future()
            self._waiters.append(fut)
            # Runs when the future is resolved or cancelled, so the waiter
            # list does not accumulate finished futures.
            fut.add_done_callback(self._discard_waiter)
        # Await outside the lock: set() needs to acquire it to wake us.
        return await fut

    def _notify_waiters(self) -> None:
        """Resolve every parked future; called by set() with the lock held."""
        try:
            current_loop = asyncio.get_running_loop()
        except RuntimeError:
            # The calling thread has no running event loop.
            current_loop = None

        # Iterate over a copy because stale waiters may be removed below.
        for fut in self._waiters.copy():
            fut_loop = fut.get_loop()
            if fut_loop is current_loop:
                # Same loop and same thread: the future can be touched directly.
                self._maybe_set_result(fut)
                continue
            try:
                fut_loop.call_soon_threadsafe(self._maybe_set_result, fut)
            except RuntimeError:
                if not fut_loop.is_closed():
                    raise
                # A closed loop can neither resolve the future nor run its
                # done callback.  Drop the stale waiter so that this and
                # later set() calls still reach the remaining waiters.
                self._discard_waiter(fut)

    def _discard_waiter(self, fut: asyncio.Future) -> None:
        """Remove *fut* from the waiter list if it is still present."""
        with contextlib.suppress(ValueError):
            self._waiters.remove(fut)

    def _maybe_set_result(self, fut: asyncio.Future) -> None:
        """Resolve *fut* with ``True`` unless it is already done or cancelled."""
        with contextlib.suppress(asyncio.InvalidStateError):
            fut.set_result(True)
def maybe_get_running_loop() -> asyncio.AbstractEventLoop | None: | |
try: | |
return asyncio.get_running_loop() | |
except RuntimeError: | |
pass | |
# TODO: write tests for thread interoperability
# Copied tests from test_asyncio/test_locks.py
class AsyncEventTests(unittest.IsolatedAsyncioTestCase):
    """Exercise the coroutine-facing side of Event from inside a running loop."""

    async def test_wait(self):
        """Parked waiters resume only after set(), and all of them finish."""
        event = Event()
        self.assertFalse(event.is_set())
        order = []

        async def waiter(tag):
            if await event.wait_async():
                order.append(tag)

        first = asyncio.create_task(waiter(1))
        second = asyncio.create_task(waiter(2))

        # Yield once so both tasks run up to their await on the event.
        await asyncio.sleep(0)
        self.assertEqual([], order)

        # This task is created before set() but has not run yet when set()
        # is called; the expected order below has it finishing first.
        third = asyncio.create_task(waiter(3))

        event.set()
        await asyncio.sleep(0)
        self.assertEqual([3, 1, 2], order)

        for task in (first, second, third):
            self.assertTrue(task.done())
            self.assertIsNone(task.result())

    async def test_wait_on_set(self):
        """Waiting on an event that is already set returns True."""
        event = Event()
        event.set()

        self.assertTrue(await event.wait_async())

    async def test_wait_cancel(self):
        """Cancelling a parked waiter removes it from the waiter list."""
        event = Event()
        loop = asyncio.get_running_loop()

        pending = asyncio.create_task(event.wait_async())
        loop.call_soon(pending.cancel)

        with self.assertRaises(asyncio.CancelledError):
            await pending
        self.assertFalse(event._waiters)

    async def test_clear(self):
        """set() and clear() toggle is_set()."""
        event = Event()
        self.assertFalse(event.is_set())

        event.set()
        self.assertTrue(event.is_set())

        event.clear()
        self.assertFalse(event.is_set())

    async def test_clear_with_waiters(self):
        """A waiter woken by set() still completes after clear() and repeated set()."""
        event = Event()
        seen = []

        async def waiter(sink):
            if await event.wait_async():
                sink.append(1)
            return True

        task = asyncio.create_task(waiter(seen))
        await asyncio.sleep(0)
        self.assertEqual([], seen)

        event.set()
        event.clear()
        self.assertFalse(event.is_set())

        event.set()
        event.set()
        # The waiter is only dropped once the loop runs its done callback.
        self.assertEqual(1, len(event._waiters))

        await asyncio.sleep(0)
        self.assertEqual([1], seen)
        self.assertEqual(0, len(event._waiters))

        self.assertTrue(task.done())
        self.assertTrue(task.result())
# Copied tests from test_threading.py
# The shared lock tests live in the "test" package, which may be missing.
try:
    import test.lock_tests as lock_tests
except ImportError:
    # Without the package there is nothing to inherit; define no test class.
    pass
else:

    class EventTests(lock_tests.EventTests):
        """Run the inherited lock_tests.EventTests cases against Event."""

        # staticmethod keeps the class from being bound as a method when
        # it is accessed through a test instance.
        eventtype = staticmethod(Event)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment