Created
April 2, 2016 19:09
-
-
Save JulienPalard/a95a7fcf0b4c21519b4b5c928b251ea2 to your computer and use it in GitHub Desktop.
Aggregation of asynchronous iterables
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
SOON = 'SOON' | |
PAIRED = 'PAIRED' | |
class AsyncZip: | |
"""Aggregates async iterables, like `zip` or `select`. | |
The current state is stored in `self.iterating` as the `__anext__()` | |
coroutine. | |
When a `__anext__()` is done, we're immediatly replacing it with the | |
next (simply calling again __anext__()) in the `iterating` attribute. | |
There is not option to ask AsyncZip to stop iterating to the | |
shortest iterable, just break yourself when you see an iterable is | |
exhausted (having `done()` to `True` and `exception()` to | |
`StopAsyncIteration`). | |
""" | |
def __init__(self, *asynchronous_iterables, loop=None, sync=SOON): | |
""" | |
sync can take two values: | |
PAIRED: like zip, to get values from each iterables for each iterations | |
SOON: like select, to get values as soon as possible, so basically | |
one value per iteration (the others in a pending state). | |
""" | |
self.sync = sync | |
if loop is not None: | |
self._loop = loop | |
else: | |
self._loop = asyncio.events.get_event_loop() | |
self.asynchronous_iterables = asynchronous_iterables | |
self.iterating = [asyncio.ensure_future(iterable.__anext__(), | |
loop=self._loop) | |
for iterable in asynchronous_iterables] | |
async def __aiter__(self): | |
return self | |
def should_wait(self): | |
if all(f.done() and isinstance(f.exception(), StopAsyncIteration) for | |
f in self.iterating): | |
return False | |
if self.sync == SOON: | |
return not any(f.done() and | |
not isinstance(f.exception(), StopAsyncIteration) for | |
f in self.iterating) | |
if self.sync == PAIRED: | |
return not all(f.done() and | |
not isinstance(f.exception(), StopAsyncIteration) for | |
f in self.iterating) | |
async def __anext__(self): | |
waiter = asyncio.futures.Future(loop=self._loop) | |
counter = 0 | |
def _on_completion(f): | |
nonlocal counter | |
counter -= 1 | |
if self.sync == SOON or counter <= 0: | |
if not waiter.done(): | |
waiter.set_result(None) | |
listenning_for = set() | |
for i, f in enumerate(self.iterating): | |
# Skip already empty iterators | |
if f.done() and isinstance(f.exception(), StopAsyncIteration): | |
continue | |
counter += 1 | |
listenning_for.add(i) | |
f.add_done_callback(_on_completion) | |
if self.should_wait(): | |
await waiter | |
results = [] | |
stop_async_iterations = 0 | |
for i, f in enumerate(self.iterating): | |
results.append(f) | |
if i not in listenning_for: | |
stop_async_iterations += 1 | |
continue | |
f.remove_done_callback(_on_completion) | |
if f.done() and isinstance(f.exception(), StopAsyncIteration): | |
stop_async_iterations += 1 | |
elif f.done(): | |
self.iterating[i] = asyncio.ensure_future( | |
self.asynchronous_iterables[i].__anext__(), | |
loop=self._loop) | |
if stop_async_iterations == len(self.iterating): | |
raise StopAsyncIteration | |
return results | |
if __name__ == '__main__': | |
class DummyAsyncIterable: | |
def __init__(self, items, latency=0): | |
self.items = items | |
self.latency = latency | |
async def __aiter__(self): | |
return self | |
async def __anext__(self): | |
try: | |
if self.latency != 0: | |
await asyncio.sleep(self.latency) | |
return self.items.pop(0) | |
except IndexError: | |
raise StopAsyncIteration | |
def _repr_task(task): | |
"""Just give: | |
- . if it's a StopAsyncIteration, | |
- ~ if waiting | |
- ? if another exception | |
- the result otherwise | |
""" | |
if task.done() and isinstance(task.exception(), StopAsyncIteration): | |
return '.' | |
if task.done() and isinstance(task.exception(), Exception): | |
print(type(task.exception())) | |
return '?' | |
if task.done(): | |
return task.result() | |
return '~' | |
async def test_AsyncZip(listA, listB, listC, lag, sync, | |
expect): | |
got = [] | |
async for items in AsyncZip(DummyAsyncIterable(list(listA[0]), | |
listA[1]), | |
DummyAsyncIterable(list(listB[0]), | |
listB[1]), | |
DummyAsyncIterable(list(listC[0]), | |
listC[1]), | |
sync=sync): | |
got.append(''.join(_repr_task(task) for task in items)) | |
await asyncio.sleep(lag) | |
print(', '.join(got)) | |
assert(', '.join(got) == expect) | |
asyncio.get_event_loop().run_until_complete(test_AsyncZip( | |
('123', 0), ('abcd', .1), ('ABCD', .25), .01, SOON, | |
'1~~, 2~~, 3~~, .a~, .b~, .~A, .c~, .d~, .~B, ..C, ..D')) | |
asyncio.get_event_loop().run_until_complete(test_AsyncZip( | |
('123', 0), ('abcd', .1), ('ABCD', .25), .47, SOON, | |
'1~~, 2aA, 3bB, .cC, .dD')) | |
asyncio.get_event_loop().run_until_complete(test_AsyncZip( | |
('123', 0), ('abcd', .1), ('ABCD', .3), .2, SOON, | |
'1~~, 2a~, 3bA, .c~, .dB, ..C, ..D')) | |
asyncio.get_event_loop().run_until_complete(test_AsyncZip( | |
('123', 0), ('abcd', .3), ('ABCD', .1), .2, SOON, | |
'1~~, 2~A, 3aB, .~C, .bD, .c., .d.')) | |
asyncio.get_event_loop().run_until_complete(test_AsyncZip( | |
('123', 0), ('abcd', .1), ('ABCD', .25), .01, PAIRED, | |
'1aA, 2bB, 3cC, .dD')) | |
asyncio.get_event_loop().run_until_complete(test_AsyncZip( | |
('12', 0), ('abcd', .1), ('ABCDE', .25), .01, PAIRED, | |
'1aA, 2bB, .cC, .dD, ..E')) | |
asyncio.get_event_loop().run_until_complete(test_AsyncZip( | |
('12', 0), ('abcd', .1), ('ABCDE', .25), .21, PAIRED, | |
'1aA, 2bB, .cC, .dD, ..E')) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment