Skip to content

Instantly share code, notes, and snippets.

@graingert
Last active June 8, 2022 12:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save graingert/cacf3adc12ce542ba26eedf651e12296 to your computer and use it in GitHub Desktop.
Save graingert/cacf3adc12ce542ba26eedf651e12296 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
from distributed import default_client
import sys
import asyncio
import concurrent.futures
from concurrent import futures
import contextlib
from tornado.ioloop import IOLoop
import random
import time
import dask
from dask.distributed import Client, LocalCluster, as_completed
def client_workload(thread_num, j, data, *, min_delay=2, max_delay=4):
    """Simulate one unit of work: sleep for a random time, then report it.

    The duration is drawn with ``random.uniform(min_delay, max_delay)``.
    After sleeping, one progress line is printed:
    ``[t<thread_num>] -- <j> -- <seconds, 2 d.p.> -- <size>``.

    Args:
        thread_num: Label of the submitting thread; used only in the output.
        j: Index of the task within its batch; used only in the output.
        data: Payload handed to the task. ``len(data)`` is reported, or
            ``"<no data>"`` when ``data`` is ``None``.
        min_delay: Lower bound of the simulated work time, in seconds.
            The default of 2 matches the original hard-coded value.
        max_delay: Upper bound of the simulated work time, in seconds.
            The default of 4 matches the original hard-coded value.

    Returns:
        The number of seconds slept, as returned by ``random.uniform``.
    """
    stime = random.uniform(min_delay, max_delay)
    time.sleep(stime)
    size = "<no data>" if data is None else len(data)
    print(f"[t{thread_num}] -- {j} -- {stime:.2f} -- {size}")
    return stime
def _run_and_close_tornado(async_fn, /, *args, **kwargs):
tornado_loop = None
async def inner_fn():
nonlocal tornado_loop
tornado_loop = IOLoop.current()
return await async_fn(*args, **kwargs)
try:
return asyncio.run(inner_fn())
finally:
tornado_loop.close(all_fds=True)
@contextlib.contextmanager
def loop_in_thread():
    """Run a tornado IOLoop on a dedicated background thread for a ``with`` body.

    A single-thread executor (thread name prefix ``"test IOLoop"``) runs
    ``_run_and_close_tornado(run)``. The ``run`` coroutine publishes that
    thread's current ``IOLoop`` and an ``asyncio.Event`` through the
    ``loop_started`` future. It then waits on the event, which keeps the
    loop alive while the caller uses it.

    Yields:
        The ``IOLoop`` running on the background thread.

    Exit behaviour:

    * On exit, the event is set through ``io_loop.add_callback`` so the
      background coroutine returns.
    * ``ran.result()`` then re-raises any error from running or closing the
      loop.
    * If the background side fails before publishing the loop, ``ran``
      completes first and its error is raised before anything is yielded.
    * If the ``with`` body itself raises, that exception propagates. The loop
      is still told to stop, and leaving the executor's ``with`` waits for
      the thread. An error from closing the loop is not reported in that
      case.
    """
    # Completed from the background thread with (io_loop, stop_event).
    loop_started = concurrent.futures.Future()
    with concurrent.futures.ThreadPoolExecutor(
        1, thread_name_prefix="test IOLoop"
    ) as tpe:

        async def run():
            # Executes on the background thread, inside asyncio.run.
            io_loop = IOLoop.current()
            stop_event = asyncio.Event()
            loop_started.set_result((io_loop, stop_event))
            # Block here until the caller's exit schedules stop_event.set.
            await stop_event.wait()

        # run asyncio.run in a thread and collect exceptions from *either*
        # the loop failing to start, or failing to close
        ran = tpe.submit(_run_and_close_tornado, run)
        for f in concurrent.futures.as_completed((loop_started, ran)):
            if f is loop_started:
                io_loop, stop_event = loop_started.result()
                try:
                    yield io_loop
                finally:
                    # This runs on the caller's thread, not the loop's, so the
                    # set() is handed to the loop via add_callback (tornado
                    # documents add_callback as safe to call from any thread).
                    io_loop.add_callback(stop_event.set)
            elif f is ran:
                # if this is the first iteration the loop failed to start
                # if it's the second iteration the loop has finished or
                # the loop failed to close and we need to raise the exception
                ran.result()
                return
def thread_evtloop(thread_num):
    """Run a private dask cluster and client on a dedicated event loop.

    Meant to be called concurrently from several threads. Each call:

    * obtains its own IOLoop from ``loop_in_thread``;
    * starts a three-worker ``LocalCluster`` on that loop and connects a
      ``Client`` to it;
    * scatters one small payload and submits ten ``client_workload`` tasks
      that reference it;
    * prints ``[t<thread_num>] <sum of task results>``.

    Args:
        thread_num: Label for this thread. It is forwarded to every task and
            used as the prefix of the printed total.
    """
    with loop_in_thread() as loop, LocalCluster(n_workers=3, loop=loop) as cluster:
        with Client(cluster.scheduler_address, loop=loop) as client:
            # One input item goes in; unpacking insists exactly one future
            # comes back.
            [payload_future] = client.scatter(["some long data..."])
            tasks = [
                client.submit(client_workload, thread_num, task_idx, payload_future)
                for task_idx in range(10)
            ]
            # The client built in this thread must be what default_client()
            # hands back.
            assert default_client() is client
            # Accumulate explicitly (not via sum()) to keep plain left-to-right
            # float addition in completion order.
            total = 0
            for _future, value in as_completed(tasks, with_results=True):
                total += value
        print(f"[t{thread_num}] {total}")
def main():
    """Run ``thread_evtloop`` in three threads at once and wait for them all.

    Each future is printed as soon as it finishes. Its result is then
    fetched, so an exception raised in that thread propagates out of here.

    Returns:
        0, used by the caller as the process exit status.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as pool:
        pending = [pool.submit(thread_evtloop, idx) for idx in range(3)]
        for done in concurrent.futures.as_completed(pending):
            print(done)
            done.result()
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment