Forked from petri/gist:46bb6b8eac8ea5b8f01c90e7414d7951
Created September 27, 2020 18:42
Save HQJaTu/1f6b2ed67ae53b24fd3f350f611e7eb1 to your computer and use it in GitHub Desktop.
Python asyncio ordering test - single-connection multiplex version with asyncio.Queue
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# vim: autoindent tabstop=4 shiftwidth=4 expandtab softtabstop=4 filetype=python
# Proof-of-Concept for https://stackoverflow.com/q/64017656/1548275
# Do Python asyncio Streams maintain order over multiple writers and readers?
import sys | |
import argparse | |
import logging | |
import asyncio | |
from random import randrange | |
from pprint import pprint | |
# Module-level logger; its level is raised to DEBUG by _setup_logger().
log = logging.getLogger(name=__name__)

# TCP port used by both --server and --client unless --port overrides it.
DEFAULT_TCP_PORT = 8888
def _setup_logger():
    """Route log records to stderr and enable DEBUG output.

    Installs a formatted stderr handler on the root logger, then sets this
    module's logger and the "asyncio" logger to DEBUG. Records from both
    loggers reach the root handler through normal propagation.
    """
    log_formatter = logging.Formatter("%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s")
    console_handler = logging.StreamHandler(sys.stderr)
    console_handler.setFormatter(log_formatter)
    # Note: `propagate` is a Logger attribute, not a Handler one, so nothing
    # is set on the handler here; propagation to the root is what we want.
    logging.getLogger().addHandler(console_handler)
    log.setLevel(logging.DEBUG)
    logging.getLogger("asyncio").setLevel(logging.DEBUG)
def run_server(loop, tcp_port, delay_max):
    """Serve the length-prefixed test protocol on 127.0.0.1 until Ctrl+C.

    :param loop: event loop (not yet running) to run the server on.
    :param tcp_port: TCP port to listen on.
    :param delay_max: random-delay bound in milliseconds, forwarded to every
        connection handler.
    :raises ValueError: if delay_max is negative.
    """
    if delay_max < 0:
        raise ValueError("Delay cannot be negative!")

    log.info("Starting server in localhost TCP-port: %d", tcp_port)
    # The `loop=` parameter of asyncio.start_server() was removed in Python
    # 3.10; the coroutine binds to whichever loop runs it, i.e. *loop* below.
    coro = asyncio.start_server(lambda r, w: _server_coro(r, w, delay_max), '127.0.0.1', tcp_port)
    server = loop.run_until_complete(coro)

    host, port = server.sockets[0].getsockname()[:2]
    log.info("Serving on: %s:%d", host, port)
    try:
        # Serve requests until Ctrl+C is pressed
        loop.run_forever()
    except KeyboardInterrupt:
        log.debug("Stop serving on keyboard interrupt.")
    finally:
        # Stop accepting new connections on every exit path.
        server.close()
async def _server_coro(reader, writer, delay_max):
    """Handle one client connection of the length-prefixed test protocol.

    Each request is a 4-byte big-endian length followed by that many bytes of
    UTF-8 text. Each request is answered by its own _server_responder_task.
    After starting a responder, the handler waits until at least one pending
    task (a responder or the initial 1-second timer) has finished before it
    reads the next request. Returns when the peer closes or resets the
    connection, after letting in-flight responses finish.

    :param reader: asyncio.StreamReader of the connection.
    :param writer: asyncio.StreamWriter of the connection.
    :param delay_max: random-delay bound forwarded to each responder.
    """
    # asyncio.wait() rejects bare coroutines since Python 3.11, so the timer
    # coroutine is wrapped in a task.
    timer = asyncio.create_task(asyncio.sleep(1))
    pending = {timer}
    try:
        while True:
            try:
                # readexactly() either delivers the full byte count or raises
                # IncompleteReadError at EOF; read(n) may return short data and
                # returns b'' forever once the peer has closed.
                header = await reader.readexactly(4)
                data = await reader.readexactly(int.from_bytes(header, 'big'))
            except (asyncio.IncompleteReadError, ConnectionError):
                # Peer closed (or reset) the connection.
                return
            message = data.decode('UTF-8')
            responder_task = asyncio.create_task(_server_responder_task(writer, message, delay_max))
            pending.add(responder_task)
            _done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
    finally:
        timer.cancel()
        # Let in-flight responses finish (or fail) before closing the stream.
        await asyncio.gather(*pending, return_exceptions=True)
        writer.close()
async def _server_responder_task(writer, message, delay_max):
    """Answer one request after a random delay.

    Sends ``"Got: '<message>'"`` encoded as UTF-8 and prefixed with its
    4-byte big-endian length.

    :param writer: asyncio.StreamWriter of the client connection.
    :param message: decoded request text.
    :param delay_max: exclusive upper bound of the random delay in ms;
        values <= 0 mean "respond without delay".
    """
    addr = writer.get_extra_info('peername')
    # randrange(0) raises ValueError, yet run_server() accepts delay_max == 0.
    delay = randrange(delay_max) if delay_max > 0 else 0
    log.debug("Received %s from %s", message, addr)
    message_out = "Got: '%s'" % message
    await asyncio.sleep(delay / 1000)
    data_out = message_out.encode('UTF-8')
    data_out = len(data_out).to_bytes(4, byteorder='big') + data_out
    log.debug("Sending after delay of %d ms: %s", delay, message_out)
    writer.write(data_out)
    await writer.drain()
def run_client(loop, tcp_port, count_connections):
    """Drive the test client on *loop* until its coroutine completes."""
    client = _client_coro(loop, tcp_port, count_connections)
    loop.run_until_complete(client)
async def _client_coro(loop, tcp_port, count_connections):
    """Send count_connections requests over one connection and await all replies.

    Responses are paired with requests by position, using a FIFO queue of
    futures that the sender tasks fill and the reader task drains.

    :param loop: kept for backward compatibility and unused. The `loop=`
        parameters of asyncio.Queue and asyncio.open_connection were removed in
        Python 3.10; both now bind to the running loop.
    :param tcp_port: server port on 127.0.0.1.
    :param count_connections: number of requests to multiplex.
    """
    # FIFO of response futures, in the order the requests were written.
    queue = asyncio.Queue()
    log.info("Running client to localhost TCP-port: %d", tcp_port)
    reader, writer = await asyncio.open_connection('127.0.0.1', tcp_port)
    try:
        # Step 1:
        # Create sending tasks. asyncio starts tasks in creation order, so each
        # sender writes its request and enqueues its future before the reader
        # task below first runs. The reader loops only while the queue is
        # non-empty, so it must be created last.
        tasks = [asyncio.create_task(_client_sender_task(writer, conn_idx, queue))
                 for conn_idx in range(count_connections)]
        # Step 2:
        # Create the response reader task.
        tasks.append(asyncio.create_task(_client_response_reader_task(reader, queue)))
        # Step 3:
        # Run the prepared tasks; gather() re-raises the first failure instead
        # of leaving it unretrieved as asyncio.wait() would.
        await asyncio.gather(*tasks)
    finally:
        # Step 4:
        # Done with the sockets
        log.debug('Close the socket')
        writer.close()
        try:
            await writer.wait_closed()
        except ConnectionError:
            # Best effort: the peer may already have dropped the connection.
            pass
async def _client_sender_task(writer, conn_idx, queue):
    """Send request number conn_idx + 1 and wait for its matching response.

    The request is ``"Test <conn_idx + 1>"`` encoded as UTF-8 and prefixed
    with its 4-byte big-endian length. A future for the response is put on
    *queue*; the response reader task resolves it.

    :param writer: asyncio.StreamWriter shared by all senders.
    :param conn_idx: zero-based request index.
    :param queue: unbounded asyncio.Queue of response futures.
    """
    message_out = "Test %d" % (conn_idx + 1)
    log.debug('Send: %s', message_out)
    data_out = message_out.encode('UTF-8')
    data_out = len(data_out).to_bytes(4, byteorder='big') + data_out
    response_future = asyncio.get_running_loop().create_future()
    # Responses are matched to requests purely by position, so the write and
    # the enqueue must happen with no suspension point in between. Otherwise
    # another sender could interleave and the queue order would diverge from
    # the wire order. put_nowait() never suspends.
    writer.write(data_out)
    queue.put_nowait(response_future)
    log.debug("Task %d queuing for response", conn_idx + 1)
    # Honour transport flow control only after the request is registered.
    await writer.drain()
    message_in = await response_future
    log.debug('Task %d received: %s', conn_idx + 1, message_in)
async def _client_response_reader_task(reader, queue):
    """Read framed responses and resolve the queued request futures in order.

    The N-th frame read from *reader* resolves the N-th future taken from
    *queue*. The loop runs only while the queue is non-empty, so every request
    future must already be queued when this task starts.

    If the connection ends before all responses arrive, the current future and
    every future still queued are failed with ConnectionError, so that no
    sender waits forever.

    :param reader: asyncio.StreamReader of the client connection.
    :param queue: asyncio.Queue of futures, one per outstanding request.
    """
    while not queue.empty():
        response = await queue.get()
        try:
            # readexactly() guarantees a whole frame or raises at EOF; read(n)
            # may legally return fewer than n bytes.
            header = await reader.readexactly(4)
            data = await reader.readexactly(int.from_bytes(header, 'big'))
        except (asyncio.IncompleteReadError, ConnectionError) as exc:
            log.error("Connection lost while waiting for responses: %s", exc)
            unanswered = [response]
            while not queue.empty():
                unanswered.append(queue.get_nowait())
            for waiter in unanswered:
                if not waiter.done():
                    waiter.set_exception(
                        ConnectionError("Connection closed before a response was received"))
            return
        message_in = data.decode('UTF-8')
        log.debug('Received: %s', message_in)
        # The sender may have been cancelled meanwhile; set_result() on a done
        # future would raise InvalidStateError.
        if not response.done():
            response.set_result(message_in)
    log.debug("Done reading client responses in _client_response_reader_task()")
def main():
    """Command-line entry point: run the ordering test as server or client.

    Exits with status 2 when the arguments are missing or inconsistent.
    """
    parser = argparse.ArgumentParser(
        description='Python asyncio ordering test: single-connection multiplexed client/server')
    parser.add_argument('--server', action='store_true',
                        help='Run as a test server')
    parser.add_argument('--client', type=int,
                        metavar="CONNECTION-COUNT",
                        help='Run as a test client. Argument: number of client connections to make towards server.')
    parser.add_argument('--port', '-p',
                        default=DEFAULT_TCP_PORT, type=int,
                        help="TCP-port for server to listen or client to connect. Default: %d" % DEFAULT_TCP_PORT)
    parser.add_argument('--delay-max', type=int,
                        metavar="MILLISECONDS",
                        help="Random delay max. in [ms]. Both client and server.")
    args = parser.parse_args()

    _setup_logger()

    # sys.exit() rather than the site-provided exit(), which is not
    # guaranteed to exist (e.g. under `python -S`).
    if not args.server and not args.client:
        log.error("Need either --server or --client!")
        sys.exit(2)
    if args.delay_max is None:
        log.error("Need --delay-max!")
        sys.exit(2)
    if args.client and args.client < 0:
        log.error("--client connection count needs to be positive integer!")
        sys.exit(2)

    # Init async I/O. The loop is created explicitly because calling
    # asyncio.get_event_loop() with no running loop is deprecated.
    async_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(async_loop)
    try:
        if args.server:
            run_server(async_loop, args.port, args.delay_max)
        elif args.client:
            run_client(async_loop, args.port, args.client)
        else:
            raise ValueError("Internal: Duh?")
    finally:
        # In a nice and calm fashion, shut down any possible tasks that are pending.
        _shutdown_loop(async_loop)
    log.info("Done.")


def _shutdown_loop(loop):
    """Cancel the tasks still pending on *loop*, let them unwind, then close it."""
    # asyncio.Task.all_tasks() was removed in Python 3.9.
    pending = asyncio.all_tasks(loop)
    for task in pending:
        task.cancel()
    if pending:
        # Give the cancelled tasks a chance to run their cleanup code.
        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
    loop.run_until_complete(loop.shutdown_asyncgens())
    loop.close()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.