Last active
June 8, 2022 12:17
-
-
Save graingert/cacf3adc12ce542ba26eedf651e12296 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from distributed import default_client | |
import sys | |
import asyncio | |
import concurrent.futures | |
from concurrent import futures | |
import contextlib | |
from tornado.ioloop import IOLoop | |
import random | |
import time | |
import dask | |
from dask.distributed import Client, LocalCluster, as_completed | |
def client_workload(thread_num, j, data, *, delay_range=(2, 4)):
    """Simulate a task: sleep a random time, print a progress line, return the sleep time.

    Args:
        thread_num: Index of the submitting thread; only used in the printed prefix.
        j: Index of this task within the thread's batch; only used in the printed line.
        data: Payload whose ``len()`` is reported, or ``None`` (reported as ``<no data>``).
        delay_range: ``(low, high)`` bounds in seconds handed to ``random.uniform``.
            Defaults to ``(2, 4)``, the originally hard-coded range. Being keyword-only,
            it does not disturb positional callers such as ``client.submit(fn, t, j, d)``.

    Returns:
        The number of seconds slept (a float drawn from ``delay_range``).
    """
    low, high = delay_range
    stime = random.uniform(low, high)
    time.sleep(stime)
    size = "<no data>" if data is None else len(data)
    print(f"[t{thread_num}] -- {j} -- {stime:.2f} -- {size}")
    return stime
def _run_and_close_tornado(async_fn, /, *args, **kwargs):
    """Run ``async_fn(*args, **kwargs)`` with ``asyncio.run`` and then close its Tornado IOLoop.

    The IOLoop is captured from inside the running coroutine via ``IOLoop.current()``.
    It is then closed with ``all_fds=True`` once ``asyncio.run`` finishes, whether the
    coroutine returned or raised.

    Args:
        async_fn: Coroutine function to run; positional-only.
        *args, **kwargs: Forwarded to ``async_fn``.

    Returns:
        Whatever ``async_fn`` returns.

    Raises:
        Any exception from ``asyncio.run`` or from ``async_fn`` itself.
    """
    tornado_loop = None

    async def inner_fn():
        nonlocal tornado_loop
        tornado_loop = IOLoop.current()
        return await async_fn(*args, **kwargs)

    try:
        return asyncio.run(inner_fn())
    finally:
        # asyncio.run can fail before inner_fn ever runs (e.g. when called from a
        # thread that already has a running event loop). In that case no IOLoop was
        # captured, and calling .close() on None would raise AttributeError and
        # hide the original exception.
        if tornado_loop is not None:
            tornado_loop.close(all_fds=True)
@contextlib.contextmanager
def loop_in_thread():
    """Yield a Tornado IOLoop that is running on a dedicated worker thread.

    A single-thread ThreadPoolExecutor runs ``_run_and_close_tornado(run)``, which
    starts a fresh asyncio event loop via ``asyncio.run``. Inside that loop, ``run``
    publishes ``(IOLoop.current(), stop_event)`` through the ``loop_started`` future
    and then waits on ``stop_event``. The caller's ``with`` body executes while the
    loop is running. On exit (normal or exceptional), ``stop_event.set`` is
    scheduled onto the loop so ``run`` returns and the loop is torn down.

    Yields:
        The ``IOLoop`` running in the worker thread.

    Raises:
        Any exception raised by ``_run_and_close_tornado`` in the worker thread,
        whether it happened before the loop published itself (start-up failure) or
        while the loop was being closed.
    """
    # Resolved from inside the worker's event loop once it is running.
    loop_started = concurrent.futures.Future()
    # Leaving this with-block calls shutdown(wait=True), i.e. it waits for the
    # worker thread (and therefore the event loop) to finish.
    with concurrent.futures.ThreadPoolExecutor(
        1, thread_name_prefix="test IOLoop"
    ) as tpe:
        async def run():
            # Executes inside asyncio.run on the worker thread.
            io_loop = IOLoop.current()
            stop_event = asyncio.Event()
            loop_started.set_result((io_loop, stop_event))
            await stop_event.wait()
        # run asyncio.run in a thread and collect exceptions from *either*
        # the loop failing to start, or failing to close
        ran = tpe.submit(_run_and_close_tornado, run)
        for f in concurrent.futures.as_completed((loop_started, ran)):
            if f is loop_started:
                io_loop, stop_event = loop_started.result()
                try:
                    yield io_loop
                finally:
                    # stop_event belongs to the worker's loop, and asyncio primitives
                    # are not thread-safe, so the set() is handed to that loop via
                    # add_callback instead of being called from this thread.
                    io_loop.add_callback(stop_event.set)
            elif f is ran:
                # if this is the first iteration the loop failed to start
                # if it's the second iteration the loop has finished or
                # the loop failed to close and we need to raise the exception
                # (If ``ran`` ever finished first *without* raising, this would return
                # without yielding and contextlib would raise RuntimeError.)
                ran.result()
                return
def thread_evtloop(thread_num):
    """Run ten ``client_workload`` tasks on a private LocalCluster and print their total.

    Each call gets its own IOLoop thread (``loop_in_thread``) and its own 3-worker
    ``LocalCluster`` bound to that loop. A ``Client`` scatters one string payload,
    submits ten tasks that reference it, and sums the returned sleep durations as the
    tasks finish. The sum is printed with a ``[t<thread_num>]`` prefix.
    """
    with loop_in_thread() as loop:
        with LocalCluster(n_workers=3, loop=loop) as cluster:
            total = 0
            with Client(cluster.scheduler_address, loop=loop) as client:
                (scattered,) = client.scatter(["some long data..."])
                task_futures = [
                    client.submit(client_workload, thread_num, idx, scattered)
                    for idx in range(10)
                ]
                # Sanity check: the client just created is the one default_client() reports.
                assert default_client() is client
                total = sum(
                    value for _, value in as_completed(task_futures, with_results=True)
                )
            print(f"[t{thread_num}] {total}")
def main():
    """Run ``thread_evtloop`` in three threads at once and report each as it finishes.

    Each completed Future is printed and then ``.result()`` is called on it, so the
    first failing thread's exception propagates out of ``main``.

    Returns:
        0, used as the process exit status.
    """
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=3)
    with pool:
        pending = [pool.submit(thread_evtloop, idx) for idx in range(3)]
        for done in concurrent.futures.as_completed(pending):
            print(done)
            # Re-raise whatever that thread raised, if anything.
            done.result()
    return 0
if __name__ == '__main__':
    # Script entry point: main()'s return value becomes the process exit status.
    exit_code = main()
    sys.exit(exit_code)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment