Skip to content

Instantly share code, notes, and snippets.

@richardsheridan
Last active January 2, 2023 22:14
Show Gist options
  • Save richardsheridan/42d99cbcfcc1d77bc61890b2fae4bfaa to your computer and use it in GitHub Desktop.
Save richardsheridan/42d99cbcfcc1d77bc61890b2fae4bfaa to your computer and use it in GitHub Desktop.
map_concurrently_in_subthread_trio
import queue
import random
from functools import partial
from time import sleep, perf_counter
import trio
# Number of tokens in the shared limiter below, i.e. how many worker tasks
# may hold a slot at the same time.
CONCURRENCY_LIMIT = 8
# Module-level limiter: acquired by worker_task around each unit of work and
# also passed by amain to asyncify_iterator for its thread calls.
# NOTE(review): being module-level, this one object is used by every call to
# map_concurrently_in_subthread_trio in the process, each of which starts its
# own trio.run in another thread -- confirm that sharing is intended.
limiter = trio.CapacityLimiter(CONCURRENCY_LIMIT)
def sync_work(item):
    """Demo workload: block for ``item[1]`` seconds, then return ``item``.

    Args:
        item: An indexable whose element 1 is the number of seconds to sleep.

    Returns:
        The very same ``item`` object, unchanged.
    """
    sleep(item[1])
    return item
async def asyncify_iterator(iter, limiter=None):
    """Expose a blocking iterator as an async generator.

    Every ``next()`` call is executed through ``trio.to_thread.run_sync`` so
    that a slow iterator does not block the Trio event loop.

    Args:
        iter: A synchronous iterator (it is passed straight to ``next``, so a
            plain iterable will not do). The name shadows the builtin but is
            kept so existing keyword callers continue to work.
        limiter: Forwarded as the ``limiter`` argument of
            ``trio.to_thread.run_sync``.

    Yields:
        The items produced by ``iter``, in order.
    """
    # A private sentinel marks exhaustion, so any value the iterator produces
    # (including None) can be yielded without ambiguity.
    sentinel = object()
    while (
        result := await trio.to_thread.run_sync(next, iter, sentinel, limiter=limiter)
    ) is not sentinel:
        yield result
async def worker_task(i, func, send_chan, task_status=trio.TASK_STATUS_IGNORED):
    """Run ``func`` in a thread and send ``(i, result)`` on ``send_chan``.

    Args:
        i: Position of this work item in the input order; sent along with the
            result so the receiver can restore ordering.
        func: Zero-argument synchronous callable to execute in a thread.
        send_chan: Send channel owned by this task; it is closed when the
            task exits.
        task_status: Start handshake object supplied by ``nursery.start``.
            Defaults to ``trio.TASK_STATUS_IGNORED`` so the task can also be
            launched with ``nursery.start_soon``.
    """
    # The slot on the module-level ``limiter`` is held from here until the
    # result has been sent, so a slow receiver keeps the slot occupied.
    async with limiter, send_chan:
        # Signal "started" only after a slot was acquired: a caller using
        # ``nursery.start`` therefore waits until there is free capacity.
        task_status.started()
        # Concurrency is already bounded by ``limiter`` above; the thread call
        # gets its own fresh one-token limiter.
        result = await trio.to_thread.run_sync(func, limiter=trio.CapacityLimiter(1))
        await send_chan.send((i, result))
async def result_sorter(result_send, recv_chan):
    """Forward ``(index, result)`` pairs as bare results in index order.

    Pairs may arrive in any order. A result is forwarded as soon as every
    result with a smaller index has been forwarded; until then it is buffered.

    Args:
        result_send: Object with an awaitable ``send(value)``; receives the
            results in order 0, 1, 2, ...
        recv_chan: Async context manager / async iterable yielding
            ``(index, result)`` pairs. It is closed when this task exits.
    """
    pending = {}  # index -> result, for results that arrived early
    next_index = 0  # index of the next result to forward
    async with recv_chan:
        async for index, result in recv_chan:
            pending[index] = result
            # Drain everything that has become contiguous with what was
            # already forwarded.
            while next_index in pending:
                await result_send.send(pending.pop(next_index))
                next_index += 1
async def amain(run_data_queue, func, items, args, kwargs):
    """Trio entry point: map ``func`` over ``items`` using worker threads.

    Args:
        run_data_queue: Thread-safe queue used to hand control objects to the
            thread that started this run. Exactly one tuple
            ``(trio_token, cancel_callable, receive_callable)`` is put on it.
        func: Synchronous callable invoked as ``func(item, *args, **kwargs)``.
        items: Iterable of work items; it is iterated in a worker thread.
        args: Extra positional arguments for ``func``.
        kwargs: Extra keyword arguments for ``func``.
    """
    # Both channels are unbuffered (size 0), so every send waits for a
    # matching receive.
    send_chan, recv_chan = trio.open_memory_channel(0)
    result_send, result_recv = trio.open_memory_channel(0)
    async with trio.open_nursery() as nursery:
        # Give the outside thread what it needs to pull results out of this
        # run and to cancel it.
        run_data_queue.put(
            (
                trio.lowlevel.current_trio_token(),
                nursery.cancel_scope.cancel,
                result_recv.receive,
            )
        )
        nursery.start_soon(result_sorter, result_send, recv_chan)
        item_aiter = asyncify_iterator(iter(items), limiter)
        i = 0
        async for item in item_aiter:
            # ``nursery.start`` returns once the worker calls
            # ``task_status.started()``; each worker gets its own clone of
            # the send channel.
            await nursery.start(
                worker_task, i, partial(func, item, *args, **kwargs), send_chan.clone()
            )
            i += 1
        # Close the original send end while still inside the nursery: the
        # receiving side only reaches end-of-channel once this handle and
        # every worker clone are closed.
        send_chan.close()
def map_concurrently_in_subthread_trio(func, items, args=(), kwargs=None):
    """Lazily map ``func`` over ``items`` via a Trio run in another thread.

    This is a generator function: nothing is started until the first value is
    requested. It then launches ``trio.run(amain, ...)`` in a separate thread
    and pulls results from that run one at a time.

    Args:
        func: Synchronous callable invoked as ``func(item, *args, **kwargs)``.
        items: Iterable of work items.
        args: Extra positional arguments for ``func``.
        kwargs: Extra keyword arguments for ``func``; ``None`` (the default)
            means no keyword arguments.

    Yields:
        Values received from the Trio run, one per ``next()`` call.

    Returns:
        The unwrapped outcome of the Trio run (as the generator's return
        value).

    Raises:
        Whatever exception the Trio run ended with, re-raised by ``unwrap()``.
    """
    # Avoid a shared mutable default argument.
    if kwargs is None:
        kwargs = {}
    run_data_queue = queue.SimpleQueue()

    def trio_main():
        return trio.run(amain, run_data_queue, func, items, args, kwargs)

    def deliver(result):
        # Called with the outcome of trio_main once the run has ended.
        run_data_queue.put(result)

    trio.lowlevel.start_thread_soon(trio_main, deliver)
    # First message on the queue is the control tuple put by amain.
    token, cancel, receive = run_data_queue.get()
    while True:
        try:
            value = trio.from_thread.run(receive, trio_token=token)
        except (trio.RunFinishedError, trio.Cancelled):
            break  # don't unwrap here, to avoid chaining exceptions
        try:
            yield value
        except BaseException:
            # The consumer closed the generator or threw into it: ask the
            # Trio run to cancel, wait for it to end, then re-raise.
            try:
                token.run_sync_soon(cancel)
            except trio.RunFinishedError:
                # The run is already over; nothing left to cancel.
                pass
            run_data_queue.get().unwrap()
            raise
    # The run has ended; surface its result or exception.
    return run_data_queue.get().unwrap()
if __name__ == "__main__":
    # Demo: 100 items, each sleeping for random.random() seconds (a value in
    # [0, 1)), consumed by a deliberately slow loop.
    t0 = perf_counter()
    for x in map_concurrently_in_subthread_trio(
        sync_work, ((i, random.random()) for i in range(100))
    ):
        print(x)
        sleep(0.1)  # simulate a slow consumer
    t = perf_counter()
    # Running the sleeps one after another would take about 50 s on average.
    print(t - t0, "is less than 50")
@richardsheridan
Copy link
Author

Fixed items 1 and 3 above. It also now responds to backpressure from the queue when results are being consumed slowly.

@richardsheridan
Copy link
Author

Now uses trio.from_thread.run* to improve readability 100x and make result_sender fully natively async.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment