Skip to content

Instantly share code, notes, and snippets.

@discosultan
Created May 15, 2019 17:46
Show Gist options
  • Save discosultan/b9663c2575b25a9ca02335e4e9617394 to your computer and use it in GitHub Desktop.
Save discosultan/b9663c2575b25a9ca02335e4e9617394 to your computer and use it in GitHub Desktop.
@asynccontextmanager
async def ws_connect_with_refresh(
        session: ClientSession,
        url: str,
        interval: int,
        loads: Callable[[str], Any],
        take_until: Callable[[Any, Any], bool]) -> AsyncIterator[AsyncIterable[Any]]:
    """Stream decoded websocket messages, reconnecting every *interval* seconds.

    The context manager yields an async iterable of decoded messages: each
    raw message's payload is extracted by ``_process_ws_msg`` and passed
    through *loads*.

    After streaming from a connection for *interval* seconds, a fresh
    connection to the same *url* is opened. The first message received on it
    (``new_data``) acts as a hand-over marker: messages are then read from
    the old connection and yielded while ``take_until(old_data, new_data)``
    is true. The first old message for which it is false is dropped, the old
    connection is closed, ``new_data`` is yielded, and streaming continues on
    the new connection.

    Args:
        session: aiohttp session used to open every connection.
        url: websocket endpoint.
        interval: seconds to stream from one connection before refreshing.
        loads: decoder applied to every message payload.
        take_until: ``(old_data, new_data) -> bool`` predicate deciding
            whether a message from the old connection is still yielded.

    Yields:
        An async iterable of decoded messages. Leaving the context closes the
        iterable and the current connection.
    """
    conn = session.ws_connect(url)
    ws = await conn.__aenter__()

    # aiohttp's ``async for`` over a websocket ends on these message types.
    # The hand-over loop below calls receive() directly, so it applies the
    # same rule itself.
    stop_types = (aiohttp.WSMsgType.CLOSE,
                  aiohttp.WSMsgType.CLOSING,
                  aiohttp.WSMsgType.CLOSED)

    async def inner() -> AsyncIterable[Any]:
        nonlocal conn, ws
        timeout_task = None
        receive_task = None
        try:
            while True:
                # Stream from the current connection until the refresh timer fires.
                timeout_task = asyncio.create_task(asyncio.sleep(interval))
                in_flight = None
                while True:
                    receive_task = asyncio.create_task(ws.receive())
                    done, _ = await asyncio.wait(
                        [receive_task, timeout_task],
                        return_when=asyncio.FIRST_COMPLETED)
                    if receive_task in done:
                        yield loads(_process_ws_msg(ws, receive_task.result()))
                    if timeout_task in done:
                        # If the timer won the race, a receive() is still pending
                        # on this socket. Remember it now, before any further
                        # await, so its message is not lost.
                        if receive_task not in done:
                            in_flight = receive_task
                        break

                _aiohttp_log.info('refreshing ws connection')
                # Open the replacement before touching conn/ws. If this fails,
                # they still refer to the live connection, which the outer
                # ``finally`` closes.
                new_conn = session.ws_connect(url)
                new_ws = await new_conn.__aenter__()
                old_conn, old_ws = conn, ws
                conn, ws = new_conn, new_ws
                try:
                    new_data = loads(_process_ws_msg(ws, await ws.receive()))
                    while True:
                        # aiohttp forbids concurrent receive() calls on one
                        # socket, so a pending receive is consumed first instead
                        # of starting another one.
                        if in_flight is not None:
                            old_msg = await in_flight
                            in_flight = None
                        else:
                            old_msg = await old_ws.receive()
                        if old_msg.type in stop_types:
                            break
                        old_data = loads(_process_ws_msg(old_ws, old_msg))
                        if not take_until(old_data, new_data):
                            break
                        yield old_data
                finally:
                    # Release the old connection even if the hand-over raised or
                    # the consumer stopped iterating part-way through it.
                    await old_ws.close()
                    await old_conn.__aexit__(None, None, None)
                yield new_data
        finally:
            # Do not leave the refresh timer or a receive() running once the
            # generator is finished or closed.
            for task in (timeout_task, receive_task):
                if task is not None and not task.done():
                    task.cancel()

    stream = inner()
    try:
        yield stream
    finally:
        try:
            # Run inner()'s cleanup (pending tasks, a half-finished hand-over)
            # before the connection it reads from is closed.
            await stream.aclose()
        finally:
            await ws.close()
            await conn.__aexit__(None, None, None)
def _process_ws_msg(ws: ClientWebSocketResponse, msg: aiohttp.WSMessage) -> str:
    """Return the payload of *msg*, failing if it signals shutdown or an error.

    Args:
        ws: connection *msg* was received on (not used by this implementation).
        msg: message obtained from the websocket.

    Returns:
        ``msg.data``.

    Raises:
        NotImplementedError: *msg* is a CLOSE, CLOSING, CLOSED or ERROR
            message. For ERROR, the exception carried in ``msg.data`` is
            chained as the cause.
    """
    # Note that ping message is by default automatically handled by aiohttp by sending pong.
    # Without this check, a close code (int) or an exception object would be
    # returned as if it were a payload and fail later inside the decoder.
    if msg.type in (aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.ERROR):
        _aiohttp_log.error('ws connection closed unexpectedly (%s)', msg)
        cause = msg.data if isinstance(msg.data, BaseException) else None
        raise NotImplementedError(':/') from cause
    return msg.data
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment