Skip to content

Instantly share code, notes, and snippets.

@pitrou
Last active October 18, 2017 17:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pitrou/719e73c1df51e817d618186833a6e2cc to your computer and use it in GitHub Desktop.
Save pitrou/719e73c1df51e817d618186833a6e2cc to your computer and use it in GitHub Desktop.
import struct
try:
from time import perf_counter as clock
except ImportError:
from time import time as clock
import asyncio
#import uvloop
#asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
class FrameDecoder:
    """
    Framing layer mixin with custom buffering logic.

    Each frame on the wire is an 8-byte little-endian unsigned length
    header followed by that many payload bytes.  Subclasses implement
    message_received(msg), which is called once per complete frame with
    the payload as a bytes object.
    """

    # Fixed-size length prefix.  The native 'L' format is 8 bytes only on
    # LP64 platforms (4 bytes on Windows), which broke framing there; '<Q'
    # is always 8 bytes and has the same layout as native 'L' on
    # little-endian 64-bit hosts.
    _HEADER = struct.Struct('<Q')

    def connection_made(self, transport):
        self.transport = transport
        self._header = bytearray()  # length-header bytes collected so far
        # Payload size of the frame in progress; None while the header is
        # still being collected.
        self._msg_size = None

    def data_received(self, data):
        """Feed raw bytes from the transport into the frame decoder.

        Copes with any split of the stream: partial headers, partial
        payloads and several frames in one chunk.  Payload slices are kept
        as memoryviews, so the decoder copies payload bytes only once, in
        the final join.
        """
        data = memoryview(data)
        header_size = self._HEADER.size
        while True:
            if self._msg_size is None:
                nhead = min(len(data), header_size - len(self._header))
                self._header += data[:nhead]
                data = data[nhead:]
                if len(self._header) < header_size:
                    return  # header still incomplete; wait for more bytes
                self._msg_size = self._HEADER.unpack(self._header)[0]
                self._buffers = []
                self._buf_size = 0
                self._header.clear()
            remaining = self._msg_size - self._buf_size
            if len(data) < remaining:
                # Partial payload: stash it and wait for the rest.
                if data:
                    self._buffers.append(data)
                    self._buf_size += len(data)
                return
            # Frame complete.  This also covers zero-length payloads, which
            # are delivered immediately even if no further bytes arrive.
            self._buffers.append(data[:remaining])
            data = data[remaining:]
            msg = b''.join(self._buffers)
            self._buffers = None
            self._msg_size = None
            self._buf_size = None
            self.message_received(msg)
            if not data:
                return
            # Iterate (rather than recurse) over further frames in this
            # chunk so that many small frames cannot hit the recursion limit.

    def message_received(self, msg):
        """Handle one complete frame payload (bytes).  Must be overridden."""
        raise NotImplementedError

    def send_message(self, msg):
        """Write msg to the transport as one frame: length header + payload."""
        # Writing header and payload separately avoids building a
        # concatenated copy of a possibly very large payload.
        self.transport.write(self._HEADER.pack(len(msg)))
        self.transport.write(msg)
class BenchServerProtocol(FrameDecoder, asyncio.Protocol):
    """Echo server for the benchmark.

    Every complete frame handed over by FrameDecoder is logged by size and
    sent straight back to the peer as a new frame.
    """

    def message_received(self, msg):
        # Report how big the payload was, then echo it back unchanged.
        nbytes = len(msg)
        print('server', nbytes)
        self.send_message(msg)

    def connection_lost(self, exc):
        # Report the teardown together with whatever exc asyncio passed in.
        print('The client closed the connection:', exc)
class BenchClientProtocol(FrameDecoder, asyncio.Protocol):
    """Benchmark client: lets a caller wait until one reply frame arrives."""

    def __init__(self):
        super().__init__()
        self._evt_done = asyncio.Event()
        # Recorded when the connection drops before any reply arrived, so
        # that wait_until_complete() fails instead of waiting forever.
        self._lost_early = False
        self._lost_exc = None

    def connection_lost(self, exc):
        print('The server closed the connection:', exc)
        if not self._evt_done.is_set():
            self._lost_early = True
            self._lost_exc = exc
            self._evt_done.set()

    def message_received(self, msg):
        print('client', len(msg))
        self._evt_done.set()

    async def wait_until_complete(self):
        """Wait until the first reply frame has been received.

        Raises:
            ConnectionError: the connection was lost before any reply
                arrived (chained to the exception given to connection_lost,
                if any).
        """
        await self._evt_done.wait()
        if self._lost_early:
            raise ConnectionError(
                'connection lost before a reply was received'
            ) from self._lost_exc
async def f(niters=5, size=100 * 1000**2, host='127.0.0.1', port=8000):
    """Benchmark round-trip throughput of the framed echo protocol.

    Starts a BenchServerProtocol server on (host, port).  Then, niters
    times, it opens a fresh client connection, sends one message of
    `size` bytes and waits for the echo.  Finally it prints the elapsed
    time and the resulting rate in MB/s.

    Args:
        niters: number of round trips (one new connection each).
        size: payload size in bytes (default 100 MB).
        host, port: address the benchmark server listens on.
    """
    data = b"x" * size
    loop = asyncio.get_event_loop()
    server = await loop.create_server(BenchServerProtocol, host, port)
    try:
        start = clock()
        for _ in range(niters):  # must match the niters used in the rate
            transport, client = await loop.create_connection(
                BenchClientProtocol, host, port)
            try:
                client.send_message(data)
                await client.wait_until_complete()
            finally:
                # Close each client so iterations do not leak connections.
                transport.close()
        end = clock()
    finally:
        server.close()
        await server.wait_closed()
    dt = end - start
    rate = len(data) * niters / dt
    print("duration: %s => rate: %d MB/s"
          % (dt, rate / 1e6))
if __name__ == "__main__":
    # Create and install the loop explicitly: asyncio.get_event_loop() with
    # no current loop is deprecated and raises RuntimeError on recent
    # Pythons.  set_event_loop() makes it the current loop, for code that
    # looks the loop up via asyncio.get_event_loop() (as f() does).
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(f())
    finally:
        loop.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment