Skip to content

Instantly share code, notes, and snippets.

@thegamecracks
Last active April 19, 2024 04:48
Show Gist options
  • Save thegamecracks/831603c4adac9cab94b4274846b59e0f to your computer and use it in GitHub Desktop.
Save thegamecracks/831603c4adac9cab94b4274846b59e0f to your computer and use it in GitHub Desktop.
A thread-safe event compatible with asyncio
# Requires Python>=3.11
import asyncio
import contextlib
import threading
import unittest
class Event(threading.Event):  # TODO: any better name?
    """A thread-safe event compatible with asyncio.

    Blocking callers use the inherited :class:`threading.Event` API
    (``wait()``, ``is_set()``, ``clear()``), while coroutines use
    :meth:`wait_async`. Every coroutine suspended in :meth:`wait_async` is
    represented by a future created on the event loop it was called from;
    :meth:`set` resolves those futures, handing the work over to the owning
    loop when it is called from a different thread or loop.
    """

    # Private attributes of threading.Event that this subclass relies on.
    _cond: threading.Condition
    _flag: bool
    # Futures of the coroutines currently suspended in wait_async().
    _waiters: list[asyncio.Future[bool]]

    def __init__(self) -> None:
        super().__init__()
        self._waiters = []

        # FIXME: might be safer to vendor threading.Event instead
        assert hasattr(self, "_cond"), "_cond missing from threading.Event"
        assert hasattr(self, "_flag"), "_flag missing from threading.Event"

    def set(self) -> None:
        """Set the flag and wake every blocking and asynchronous waiter."""
        with self._cond:
            self._flag = True
            self._cond.notify_all()
            # Notify while still holding the lock so that a concurrent
            # wait_async() either observes the flag or is already registered.
            self._notify_waiters()

    async def wait_async(self) -> bool:
        """Asynchronously wait until the event is set.

        If a timeout is desired, this coroutine can be wrapped
        with :func:`asyncio.wait_for()`.

        Returns:
            ``True`` once the event is set (immediately if it already is).
        """
        with self._cond:
            if self._flag:
                return True

            fut = asyncio.get_running_loop().create_future()
            self._waiters.append(fut)
            # Unregister on completion *or* cancellation of the future.
            fut.add_done_callback(self._discard_waiter)

        # Await outside the lock: holding it here would block set().
        return await fut

    def _notify_waiters(self) -> None:
        """Resolve every registered future on the loop that owns it."""
        current_loop = maybe_get_running_loop()
        for fut in self._waiters.copy():
            fut_loop = fut.get_loop()
            if fut_loop is current_loop:
                # Already on the owning loop: resolve directly.
                self._maybe_set_result(fut)
                continue

            try:
                fut_loop.call_soon_threadsafe(self._maybe_set_result, fut)
            except RuntimeError:
                # The owning loop has been closed, so this future can never
                # be resolved and its done-callback will never run. Drop it
                # rather than letting one dead loop abort set() and starve
                # the remaining waiters.
                self._discard_waiter(fut)

    def _maybe_set_result(self, fut: asyncio.Future) -> None:
        """Resolve *fut* with ``True`` unless it is already done/cancelled."""
        with contextlib.suppress(asyncio.InvalidStateError):
            fut.set_result(True)

    def _discard_waiter(self, fut: asyncio.Future) -> None:
        """Remove *fut* from the waiter list; a no-op if it is not present."""
        with contextlib.suppress(ValueError):
            self._waiters.remove(fut)
def maybe_get_running_loop() -> asyncio.AbstractEventLoop | None:
try:
return asyncio.get_running_loop()
except RuntimeError:
pass
# TODO: write tests for thread inter-operability
# Copied tests from test_asyncio/test_locks.py
class AsyncEventTests(unittest.IsolatedAsyncioTestCase):
    """Checks of ``Event.wait_async()`` from inside a running event loop."""

    async def test_wait(self):
        """set() wakes pending waiters; a late waiter returns immediately."""
        event = Event()
        self.assertFalse(event.is_set())

        order = []

        async def record(tag):
            # Append the tag once the event fires; the task result stays None.
            if await event.wait_async():
                order.append(tag)

        first = asyncio.create_task(record(1))
        second = asyncio.create_task(record(2))

        await asyncio.sleep(0)
        self.assertEqual([], order)

        third = asyncio.create_task(record(3))

        event.set()
        await asyncio.sleep(0)
        self.assertEqual([3, 1, 2], order)

        for task in (first, second, third):
            self.assertTrue(task.done())
            self.assertIsNone(task.result())

    async def test_wait_on_set(self):
        """Waiting on an already-set event yields True."""
        event = Event()
        event.set()

        outcome = await event.wait_async()
        self.assertTrue(outcome)

    async def test_wait_cancel(self):
        """Cancelling a waiter propagates and leaves no registered waiter."""
        event = Event()

        waiter = asyncio.create_task(event.wait_async())
        loop = asyncio.get_running_loop()
        loop.call_soon(waiter.cancel)

        with self.assertRaises(asyncio.CancelledError):
            await waiter

        self.assertFalse(event._waiters)

    async def test_clear(self):
        """set() and clear() toggle is_set()."""
        event = Event()
        self.assertFalse(event.is_set())

        event.set()
        self.assertTrue(event.is_set())

        event.clear()
        self.assertFalse(event.is_set())

    async def test_clear_with_waiters(self):
        """A waiter woken by set() still completes after clear() and re-set."""
        event = Event()
        seen = []

        async def waiter():
            if await event.wait_async():
                seen.append(1)
            return True

        task = asyncio.create_task(waiter())
        await asyncio.sleep(0)
        self.assertEqual([], seen)

        event.set()
        event.clear()
        self.assertFalse(event.is_set())

        event.set()
        event.set()
        # The waiter stays registered until the loop runs again.
        self.assertEqual(1, len(event._waiters))

        await asyncio.sleep(0)
        self.assertEqual([1], seen)
        self.assertEqual(0, len(event._waiters))

        self.assertTrue(task.done())
        self.assertTrue(task.result())
# Copied tests from test_threading.py
# Reuse the threading.Event test-suite from the ``test`` package when it can
# be imported; if the import fails, the EventTests class is simply not
# defined and these tests are skipped.
try:
    from test import lock_tests
except ImportError:
    pass
else:
    class EventTests(lock_tests.EventTests):
        # Run the inherited tests against Event instead of threading.Event.
        # staticmethod() keeps the class from being bound as a method when it
        # is looked up on a test instance (presumably via ``self.eventtype()``
        # in lock_tests -- confirm there).
        eventtype = staticmethod(Event)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment