Created
September 8, 2018 13:56
-
-
Save chrahunt/fc7add6ece26df4533061eedebf321b1 to your computer and use it in GitHub Desktop.
Websocket test library example.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import functools | |
import inspect | |
import os | |
import re | |
import subprocess | |
import sys | |
import threading | |
from robot.api import logger | |
from robot.api.deco import keyword | |
import websockets | |
sys.path.append(os.path.dirname(__file__)) | |
import messages_pb2 | |
# Seconds per accepted unit name; names are matched case-insensitively.
_TIME_UNITS = {
    'ms': 0.001,
    'msec': 0.001,
    'millisecond': 0.001,
    'milliseconds': 0.001,
    's': 1,
    'sec': 1,
    'secs': 1,
    'second': 1,
    'seconds': 1,
    'm': 60,
    'min': 60,
    'mins': 60,
    'minute': 60,
    'minutes': 60,
    'h': 3600,
    'hr': 3600,
    'hrs': 3600,
    'hour': 3600,
    'hours': 3600,
}

# A non-negative decimal number followed by an optional alphabetic unit,
# with surrounding whitespace allowed. Used with fullmatch, so nothing else
# may trail the unit.
_DURATION_RE = re.compile(
    r'\s*(?P<value>\d+(?:\.\d*)?|\.\d+)\s*(?P<units>[A-Za-z]*)\s*')


def parse_units(spec):
    """
    Convert a duration such as ``'500ms'``, ``'2 min'`` or ``'1.5h'`` to seconds.

    Arguments:
        (str | int | float) spec  a number optionally followed by a unit from
            ``_TIME_UNITS`` (case-insensitive); a bare number means seconds.
    Returns:
        (float) the duration in seconds.
    Raises:
        ValueError if ``spec`` is not a duration or uses an unknown unit.
    """
    if isinstance(spec, (int, float)):
        return float(spec)
    match = _DURATION_RE.fullmatch(spec)
    if match is None:
        # Fail loudly: get_message would treat a None result as "no timeout"
        # and block forever.
        raise ValueError(f'Invalid duration: {spec!r}')
    units = match.group('units').lower() or 's'
    try:
        scale = _TIME_UNITS[units]
    except KeyError:
        raise ValueError(
            f'Unknown time unit {units!r} in duration {spec!r}') from None
    return float(match.group('value')) * scale
class Client(object):
    """
    One websocket connection plus a buffer of the server messages it received.

    A receive loop (``_recv``) is scheduled on ``loop`` as soon as the client
    is constructed; it parses every incoming frame into a
    ``messages_pb2.ServerMessage`` and queues it. ``recv`` hands queued
    messages out, skipping any that a registered filter claims. ``recv`` and
    ``add_filter`` are coroutines and are meant to run on ``loop``.
    """

    def __init__(self, id, ws, loop=None):
        """
        Args:
            id  identifier used as the prefix of this client's console log lines
            (websocket) ws the connected websocket
            (asyncio.AbstractEventLoop) loop  loop that services ``ws``;
                defaults to ``asyncio.get_event_loop()`` evaluated per call
        """
        if loop is None:
            # Resolved per call; a default expression in the signature would
            # run once at import time and pin every default-constructed client
            # to whatever loop existed then.
            loop = asyncio.get_event_loop()
        self._id = id
        self.ws = ws
        self._loop = loop
        # Available incoming messages that haven't been filtered out.
        if sys.version_info < (3, 10):
            # Before 3.10 a queue binds to a loop when constructed, and this
            # constructor may run outside the loop thread, so pass it in.
            self._queue = asyncio.Queue(loop=self._loop)
        else:
            # 3.10 removed ``loop=``; the queue binds to the loop that first
            # awaits it, i.e. the loop ``recv`` runs on.
            self._queue = asyncio.Queue()
        # Filters that pre-remove incoming messages.
        self._filters = []
        self._recv_fut = asyncio.run_coroutine_threadsafe(
            self._recv(), self._loop)
        # Without this, an exception that ends the receive loop would sit
        # unobserved in the future and the client would silently go quiet.
        self._recv_fut.add_done_callback(self._on_recv_done)

    async def add_filter(self, callback):
        """
        Register a filter applied to every message ``recv`` dequeues.

        ``callback(msg)`` returns a boolean; a truthy result means the filter
        claims the message, which is then discarded and never returned by
        ``recv``. This is a coroutine: it has no effect unless awaited.
        """
        self._filters.append(callback)

    async def recv(self):
        """Wait for and return the next queued message that no filter claims."""
        while True:
            msg = await self._queue.get()
            if not any(f(msg) for f in self._filters):
                return msg

    async def _recv(self):
        """
        Inner receive loop.

        Parses each websocket frame as a ``messages_pb2.ServerMessage`` and
        queues it for ``recv``; returns once the connection is closed.
        """
        while True:
            logger.console(f'{self._id}: Trying receive')
            try:
                data = await self.ws.recv()
            except websockets.ConnectionClosed:
                # Qualified through the package: this module does not import
                # the exception name directly.
                logger.console(f'{self._id}: Disconnected')
                return
            logger.console(f'{self._id}: Received')
            msg = messages_pb2.ServerMessage()
            msg.ParseFromString(data)
            self._queue.put_nowait(msg)

    def _on_recv_done(self, fut):
        """Report a receive loop that ended with an exception rather than a close."""
        if fut.cancelled():
            return
        exc = fut.exception()
        if exc is not None:
            logger.console(f'{self._id}: Receive loop failed: {exc!r}')
class ClientLibrary(object):
    """
    Robot Framework keyword library for driving app servers and websocket clients.

    Client library has one event loop that runs in a separate thread. Keywords
    run on the caller's thread and submit websocket work to that loop through
    ``_run``, blocking until it finishes.
    """

    # Seconds stop_server waits for a server to exit after SIGTERM before
    # killing it.
    _SERVER_STOP_TIMEOUT = 10

    def __init__(self):
        # client_id -> Client
        self.clients = {}
        # server_id -> subprocess.Popen
        self.servers = {}
        self._loop = asyncio.new_event_loop()
        # Daemon thread, so a loop that is never stopped does not keep the
        # interpreter alive at exit.
        self._thread = threading.Thread(
            target=self._loop.run_forever, name='receiver')
        self._thread.daemon = True
        self._thread.start()

    def start_server(self, server_id, config):
        """
        Launch the app binary as ``app -c <config>`` and register it.

        Arguments:
            (str) server_id  key used by stop_server
            (str) config     passed verbatim as the ``-c`` argument
                             (presumably a config file path — confirm)
        """
        app = os.path.join(
            os.path.dirname(__file__),
            '..', '..', '..', 'build-clang', 'bin', 'app')
        # TODO: Dynamic port
        cmd = [app, '-c', config]
        # NOTE(review): both pipes are only drained in stop_server; a server
        # that writes more than the OS pipe buffer holds will block until then.
        self.servers[server_id] = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    def stop_server(self, server_id):
        """
        Stop a server started by start_server and log its output.

        Sends SIGTERM and waits up to ``_SERVER_STOP_TIMEOUT`` seconds; a
        server still running after that is killed. The server is then
        unregistered.

        Raises:
            AssertionError if no server is registered under server_id.
        """
        server = self._get_server(server_id)
        # SIGTERM
        server.terminate()
        try:
            stdout, stderr = server.communicate(
                timeout=self._SERVER_STOP_TIMEOUT)
        except subprocess.TimeoutExpired:
            # It ignored SIGTERM: kill it so the suite cannot hang here.
            server.kill()
            stdout, stderr = server.communicate()
        del self.servers[server_id]
        # TODO: Fail based on return code.
        logger.debug(f'Stdout: {stdout}')
        logger.debug(f'Stderr: {stderr}')

    def start_client(self, client_id, config):
        """
        Connect a websocket client to the local server and register it.

        Arguments:
            (str) client_id  key used by the other client keywords
            config           currently unused; the port is fixed at 8080
        """
        # TODO: Dynamic port
        #port = config['server']['websocket']['port']
        port = 8080
        address = f'ws://localhost:{port}'
        ws = self._run(websockets.client.connect(address, loop=self._loop))
        self.clients[client_id] = Client(client_id, ws, loop=self._loop)

    def stop_client(self, client_id):
        """
        Close a client's websocket and unregister it.

        Raises:
            AssertionError if no client is registered under client_id.
        """
        client = self._get_client(client_id)
        self._run(client.ws.close())
        del self.clients[client_id]

    def make_message(self, message_type):
        """
        Helper method for creating a message in user keywords.

        Returns a new, empty ``messages_pb2.ClientMessage``.
        NOTE(review): ``message_type`` is currently ignored; presumably it was
        meant to select or pre-populate the message kind — confirm.
        """
        return messages_pb2.ClientMessage()

    def send_message(self, client_id, message):
        """
        Serialize a message and send it over the client's websocket.

        Arguments:
            (str) client_id
            (messages_pb2.ClientMessage) message
        """
        client = self._get_client(client_id)
        self._run(client.ws.send(message.SerializeToString()))

    def get_message(self, client_id, timeout=None):
        """
        Wait for and return a message from a client with the given id.

        Arguments:
            (str) client_id
            (str) timeout as string, parsed by parse_units; empty or None
                waits indefinitely
        Returns:
            (message_pb2.ServerMessage)
        Raises:
            concurrent.futures.TimeoutError if no message arrives in time.
        """
        client = self._get_client(client_id)
        if timeout:
            timeout = parse_units(timeout)
        return self._run(client.recv(), timeout=timeout)

    def _run(self, coro, timeout=None):
        """
        Dispatch an awaitable to the event loop and block for its result.

        Arguments:
            coro     any awaitable (coroutine or object with ``__await__``)
            (float) timeout  seconds to wait; None waits indefinitely
        Raises:
            concurrent.futures.TimeoutError if ``timeout`` elapses first; the
            work is cancelled on the loop before the error propagates.
        """
        import concurrent.futures

        # run_coroutine_threadsafe only accepts coroutine objects, so wrap
        # other awaitables (presumably what websockets.client.connect returns).
        async def wrapper():
            return await coro

        fut = asyncio.run_coroutine_threadsafe(wrapper(), self._loop)
        try:
            return fut.result(timeout)
        except concurrent.futures.TimeoutError:
            # An abandoned task keeps running: a timed-out recv() would later
            # dequeue the next message and drop it. Cancel it instead.
            if fut.cancel():
                raise
            # It completed between the timeout and cancel(); keep its outcome.
            return fut.result()

    def _get_server(self, server_id):
        """Return the registered server, failing the test if there is none."""
        try:
            return self.servers[server_id]
        except KeyError:
            raise AssertionError(
                f'Could not find server with id {server_id}') from None

    def _get_client(self, client_id):
        """Return the registered client, failing the test if there is none."""
        try:
            return self.clients[client_id]
        except KeyError:
            raise AssertionError(
                f'Could not find client with id {client_id}') from None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment