@crosstyan
Last active June 12, 2024 19:28
a stupid anyio with multiprocess

import random
import time
from multiprocessing import cpu_count
from typing import (
    Any,
    Awaitable,
    Callable,
    Final,
    Generic,
    Iterable,
    Optional,
    Protocol,
    Tuple,
    TypedDict,
    TypeVar,
    TypeVarTuple,
    Union,
    Unpack,
    cast,
)

import anyio
import multiprocess as mp
import numpy as np
import signal
from loguru import logger
from multiprocess.context import BaseContext, DefaultContext, Process, assert_spawning
from multiprocess.managers import BaseManager, SharedMemoryManager, SyncManager
from multiprocess.pool import ApplyResult, Pool
from multiprocess.process import BaseProcess
from multiprocess.queues import Empty, Full, Queue
from multiprocess.shared_memory import ShareableList, SharedMemory
from multiprocess.synchronize import Condition

BUF_SIZE: Final = 16

T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")


def create_process(
    target: Callable[[Unpack[PosArgsT]], T_Retval],
    args: Tuple[Unpack[PosArgsT]],
    name: str | None = None,
    daemon: bool | None = None,
):
    return Process(target=target, args=args, name=name, daemon=daemon)


_exit_handler: Optional[Callable[[int, Any], None]] = None


# https://superfastpython.com/multiprocessing-condition-variable-in-python/
def task(id: int, cv: Condition, sm: SharedMemory, sq: Queue, oq: Queue):
    global _exit_handler
    logger.info("Starting task {}", id)
    assert _exit_handler is None

    def exit_handler(_sig_num: int = 0, _frame: Any = None):
        logger.info("Task {} done", id)
        sm.close()

    _exit_handler = exit_handler
    signal.signal(signal.SIGTERM, _exit_handler)
    _buf = np.ndarray((BUF_SIZE,), dtype=np.uint8)
    # sync queue to indicate that the process is ready
    sq.put({"id": id})
    try:
        while True:
            with cv:
                cv.wait()
                assert sm.buf is not None
                temp = np.frombuffer(sm.buf, dtype=np.uint8, count=BUF_SIZE, offset=0)
                # copy the shared memory to the local buffer
                _buf[:] = temp[:]
            s = np.sum(_buf)
            oq.put({"id": id, "sum": s})
    finally:
        exit_handler()


def main():
    sync_man = SyncManager()
    mem_man = SharedMemoryManager()
    sync_man.start()
    mem_man.start()
    ctx = cast(BaseContext, sync_man._ctx)  # pylint: disable=protected-access
    oq = cast(Queue, sync_man.Queue())  # type: ignore # output queue
    sm = mem_man.SharedMemory(BUF_SIZE)
    cv = Condition(ctx=ctx)
    count: int = cpu_count()
    sq = cast(Queue, sync_man.Queue(count))  # type: ignore # sync queue
    ps = [create_process(target=task, args=(i, cv, sm, sq, oq)) for i in range(count)]
    for p in ps:
        p.start()
    # wait until every worker has reported in on the sync queue
    for _ in range(len(ps)):
        init = sq.get()
        logger.info("Process {} started", init["id"])
    for _ in range(24):
        assert sm.buf is not None
        buf = np.frombuffer(sm.buf, dtype=np.uint8, count=BUF_SIZE, offset=0)
        buf[:] = np.random.randint(0, 255, BUF_SIZE)
        logger.info("buf={}", buf)
        # wake up one waiting worker; it will read the buffer and reply
        with cv:
            cv.notify()
        logger.info("Result {}", oq.get())
    logger.info("Main process done")
    for p in ps:
        p.terminate()
        p.join()
        p.close()
    try:
        sm.close()
        sm.unlink()
    except BufferError:
        pass
    sync_man.shutdown()
    mem_man.shutdown()


if __name__ == "__main__":
    main()


from typing import (
    Any,
    Awaitable,
    Callable,
    Final,
    Generic,
    Optional,
    TypeVar,
    TypeVarTuple,
    TypedDict,
    Union,
    Unpack,
    Protocol,
    cast,
)

import anyio
import multiprocess as mp
from loguru import logger
from multiprocess.context import BaseContext, DefaultContext, assert_spawning
from multiprocess.managers import BaseManager, SharedMemoryManager, SyncManager
from multiprocess.pool import ApplyResult, Pool
from multiprocess.process import BaseProcess as Process
from multiprocess.queues import Empty, Full, Queue
from multiprocess.shared_memory import ShareableList, SharedMemory
from multiprocess.synchronize import Condition

# I don't know what to do now
T = TypeVar("T")


# https://superfastpython.com/multiprocessing-condition-variable-in-python/
class QueueProxy(Generic[T]):
    """
    An anyio type-safe queue for multiprocessing.

    This class provides an asynchronous interface to a multiprocessing Queue,
    allowing it to be used safely with anyio without blocking the event loop.

    Note
    ----
    Uses the default implementation of `__getstate__` and `__setstate__`
    when pickled across processes.
    """

    _q: Queue
    _ctx: BaseContext

    def __init__(self, queue: Queue, ctx: BaseContext):
        """
        Initialize the QueueProxy.

        :param queue: The multiprocessing Queue to wrap.
        :param ctx: The multiprocessing context used to create the Queue.
        """
        self._q = queue
        self._ctx = ctx

    @staticmethod
    def from_manager(manager: BaseManager, size: int = 0) -> "QueueProxy[T]":
        """
        Create a new QueueProxy from a multiprocessing Manager.

        :param manager: The Manager to create the Queue from.
        :param size: The maximum size of the Queue (default 0 for unlimited).
        :return: A new QueueProxy instance.
        """
        ctx = manager._ctx  # pylint: disable=protected-access
        return QueueProxy[T](manager.Queue(size), ctx=ctx)

    def put(self, item: T, block: bool = True, timeout: Optional[float] = None):
        """
        Put an item into the queue, synchronously.

        This method wraps `Queue.put()` and has the same behavior.
        """
        self._q.put(item, block=block, timeout=timeout)

    async def async_put(self, item: T):
        """
        Put an item into the queue asynchronously.

        If the queue is full, this method waits until a free slot is available
        before adding the item, without blocking the event loop.
        """
        while True:
            try:
                return self._q.put_nowait(item)
            except Full:
                await anyio.sleep(0)

    def get(self, block: bool = True, timeout: Optional[float] = None) -> T:
        """
        Get an item from the queue, synchronously.

        This method wraps `Queue.get()` and has the same behavior.
        """
        return cast(T, self._q.get(block=block, timeout=timeout))

    async def async_get(self) -> T:
        """
        Get an item from the queue asynchronously.

        If the queue is empty, this method waits until an item is available
        without blocking the event loop.
        """
        while True:
            try:
                return cast(T, self._q.get_nowait())
            except Empty:
                await anyio.sleep(0)

    def put_nowait(self, item: T):
        """Put an item into the queue if a free slot is immediately available."""
        self._q.put_nowait(item)

    def get_nowait(self) -> T:
        """Get an item from the queue if one is immediately available."""
        return cast(T, self._q.get_nowait())

    @property
    def queue(self) -> Queue:
        """
        Get the underlying multiprocessing Queue.

        Use this property when passing the Queue to a multiprocessing Process.
        """
        return self._q

    async def __aiter__(self):
        """
        Asynchronous iterator interface to get items from the queue.

        This allows using the queue with `async for` without blocking the event loop.
        """
        while True:
            try:
                el = self._q.get(block=False)
                yield cast(T, el)
            except Empty:
                # https://superfastpython.com/what-is-asyncio-sleep-zero/
                # yield control to the event loop
                await anyio.sleep(0)

    async def __anext__(self):
        return await self.async_get()
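

# A minimal usage sketch of QueueProxy (added for illustration; not part of
# the original gist, and `_producer` / `_queue_proxy_demo` are hypothetical
# names). Async code talks to the proxy, while the worker process receives
# the raw Queue via the `.queue` property, mirroring the docstrings above.
def _producer(raw: Queue):
    # a plain multiprocessing worker only ever sees the raw Queue
    for i in range(3):
        raw.put(i)


def _queue_proxy_demo():
    man = SyncManager()
    man.start()
    proxy = QueueProxy[int].from_manager(man, size=8)

    worker = mp.Process(target=_producer, args=(proxy.queue,))
    worker.start()

    async def _consume():
        total = 0
        # stop once no item has arrived for one second
        with anyio.move_on_after(1) as scope:
            async for item in proxy:
                total += item
                scope.deadline = anyio.current_time() + 1
        logger.info("consumed total={}", total)

    anyio.run(_consume)
    worker.join()
    man.shutdown()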


class ApplyResultLike(Protocol, Generic[T]):
    def ready(self) -> bool: ...
    def successful(self) -> bool: ...
    def get(self, timeout: Optional[float] = None) -> T: ...
    def wait(self, timeout: Optional[float] = None) -> None: ...


async def await_result(result: ApplyResultLike[T]) -> T:
    """
    Wrap an `ApplyResult` into an awaitable.
    """
    while not result.ready():
        await anyio.sleep(0)
    return result.get()
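

# A small sketch (not part of the original gist) showing how `await_result`
# bridges `Pool.apply_async()` into async code: submit work to the pool, then
# await completion inside anyio without blocking the event loop.
# `_square` and `_await_result_demo` are hypothetical names.
def _square(x: int) -> int:
    return x * x


def _await_result_demo():
    with Pool(processes=2) as pool:
        ar = pool.apply_async(_square, (7,))

        async def _run():
            value = await await_result(ar)
            logger.info("squared={}", value)

        anyio.run(_run)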


_a: Optional[int] = None
"""
this variable is expected to be unique in each process
"""


def init():
    global _a
    assert _a is None, "inited"
    _a = 1
    logger.info("init {}", _a)


def task(i: int, q: QueueProxy[int]):
    global _a
    assert _a is not None
    _a += i
    q.put(i, timeout=3)
    return _a


TIMEOUT = 5
QUEUE_SIZE = 24

T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")


def safe_apply_sync(
    pool: Pool,
    func: Callable[[Unpack[PosArgsT]], T_Retval],
    *args: Unpack[PosArgsT],
    callback: Optional[Callable[[T_Retval], None]] = None,
    error_callback: Optional[Callable[[Any], None]] = None,
) -> ApplyResultLike[T_Retval]:
    """
    A type-safe wrapper around `Pool.apply_async()`.
    """
    return pool.apply_async(
        func=func,
        args=args,
        callback=callback,
        error_callback=error_callback,
    )


def main():
    count: int = mp.cpu_count()
    logger.info("cpu count is {}", count)
    man = SyncManager()
    man.start()
    q = QueueProxy[int].from_manager(man, QUEUE_SIZE)
    p = Pool(processes=count, initializer=init)
    for _ in range(1_000):
        # https://superfastpython.com/multiprocessing-pool-asyncresult/
        ar = safe_apply_sync(p, task, 1, q)

    async def consumer():
        await q.async_put(10)
        first = await q.async_get()
        logger.info("first={}", first)
        acc = 0
        with anyio.move_on_after(TIMEOUT) as cancel_scope:
            async for i in q:
                acc += i
                # push the deadline back each time an item arrives
                cancel_scope.deadline = anyio.current_time() + TIMEOUT
        logger.info("acc={}", acc)

    anyio.run(consumer)
    p.close()


if __name__ == "__main__":
    main()