Last active
December 22, 2023 07:03
-
-
Save linepro6/f51ac8930882ce8200f8a0ae795c214e to your computer and use it in GitHub Desktop.
Python JSON-RPC over WebSockets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class APIException(Exception):
    """Base class for JSON-RPC 2.0 errors.

    Attributes:
        code: integer JSON-RPC error code.
        message: short human-readable description.
        data: optional extra payload attached to the error (None if absent).
    """

    def __init__(self, code: int, message: str, data: "dict | None" = None):
        # Forward the constructor arguments to Exception so ``args`` and
        # ``repr()`` are informative and the base class round-trips through
        # pickle (Exception.__reduce__ rebuilds the instance from ``args``).
        super().__init__(code, message, data)
        self.code = code
        self.message = message
        self.data = data

    def __str__(self):
        return f"{self.message} ({self.code})"
class ParseErrorException(APIException):
    """JSON-RPC 2.0 "Parse Error" (code -32700): the received text is not valid JSON."""

    def __init__(self):
        super().__init__(code=-32700, message="Parse Error")
class InvalidRequestException(APIException):
    """JSON-RPC 2.0 "Invalid Request" (code -32600): the message is not a valid request object."""

    def __init__(self):
        super().__init__(code=-32600, message="Invalid Request")
class MethodNotFoundException(APIException):
    """JSON-RPC 2.0 "Method Not Found" (code -32601): no handler exists for the requested method."""

    def __init__(self):
        super().__init__(code=-32601, message="Method Not Found")
class InvalidParamsException(APIException):
    """JSON-RPC 2.0 "Invalid Params" (code -32602): the params do not fit the method."""

    def __init__(self):
        super().__init__(code=-32602, message="Invalid Params")
class InternalErrorException(APIException):
    """JSON-RPC 2.0 "Internal Error" (code -32603).

    The supplied text is attached to the error as ``data={"traceback": text}``.
    """

    def __init__(self, traceback: str):
        detail = {"traceback": traceback}
        super().__init__(code=-32603, message="Internal Error", data=detail)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import websockets.client as ws_client | |
import websockets | |
from websockets.exceptions import ConnectionClosed | |
from loguru import logger | |
from typing import Callable, Union, Optional | |
import json | |
import traceback | |
import inspect | |
from functools import wraps | |
from .exceptions import * | |
def get_default_args(func):
    """Return ``{name: default}`` for every parameter of *func* that declares a default.

    Parameters without a default (including ``*args``/``**kwargs``) are omitted.
    """
    defaults = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.default is param.empty:
            continue
        defaults[name] = param.default
    return defaults
class RaiseAPIException:
    """Context manager that translates selected exceptions into an API error.

    Any exception raised inside the ``with`` body that is an instance of one of
    ``captures`` (subclasses included, exactly like an ``except`` clause) is
    replaced by ``api_exception``, chained to the original via ``raise ... from``.
    Every other exception propagates unchanged.

    Args:
        api_exception: exception class (or instance) to raise instead.
        *captures: exception classes to translate.
    """

    def __init__(self, api_exception: "type[APIException]", *captures: "type[BaseException]"):
        self.captures = captures
        self.api_exception = api_exception

    def __enter__(self):
        # Called on entering the context; nothing to set up.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # exc_type / exc_val / exc_tb describe the exception raised in the
        # body (all None when the body completed normally).
        if exc_type is not None and issubclass(exc_type, self.captures):
            raise self.api_exception from exc_val
        # Returning False lets any other exception propagate unchanged.
        return False
class JsonRPCAPIClient:
    """JSON-RPC 2.0 peer running over a (re)connecting WebSocket client.

    Both directions share one connection:

    * functions decorated with :meth:`local_api` are callable by the peer;
    * functions decorated with :meth:`remote_api` become async stubs that send
      a request (or notification) to the peer and return its result.

    Args:
        uri: WebSocket URI to connect to.
        request_timeout: seconds to wait for the reply to a remote call before
            ``asyncio.TimeoutError`` is raised (default 60).
        reconnect_delay: seconds to sleep after a connection closes before
            reconnecting (default 5).
    """

    def __init__(self, uri: str, request_timeout: float = 60, reconnect_delay: float = 5):
        self.__uri = uri
        self.__request_timeout = request_timeout
        self.__reconnect_delay = reconnect_delay
        self.__seq = 0
        # method name -> local callable exposed to the peer
        self.__local_api_bind = {}
        # request id -> future resolved with the peer's raw response dict
        self.__remote_api_future = {}
        self.__websocket: ws_client.WebSocketClientProtocol = None
        # Resolved while a connection is open, pending while disconnected.
        # Created inside the running loop by __main().
        self.__websocket_connecting_future = None
        self.__connected_hook = None
        # Strong references to fire-and-forget tasks so they cannot be garbage
        # collected mid-execution (see the asyncio.create_task documentation).
        self.__background_tasks = set()

    # ------------------------------------------------------------ internals

    def __spawn(self, coro):
        """Schedule *coro* as a task and keep a reference until it finishes."""
        task = asyncio.create_task(coro)
        self.__background_tasks.add(task)
        task.add_done_callback(self.__background_tasks.discard)
        return task

    async def __wait_connected(self):
        """Wait until a WebSocket connection is open."""
        future = self.__websocket_connecting_future
        if future is None:
            raise RuntimeError("JsonRPCAPIClient is not running; call run_forever() first")
        # shield(): cancelling one waiter must not cancel the future shared by
        # all other waiters and resolved by __main().
        await asyncio.shield(future)

    def __generate_seq_id(self):
        """Return the next request id (wraps around below 9999999)."""
        s = self.__seq
        self.__seq = (self.__seq + 1) % 9999999
        return s

    async def __send_error(self, e: APIException, msg_id=None):
        """Send *e* to the peer as a JSON-RPC error response for *msg_id*."""
        to_send_dict = {
            "jsonrpc": "2.0",
            "error": {
                "code": e.code,
                "message": e.message
            },
            "id": msg_id
        }
        if e.data:
            to_send_dict["error"]["data"] = e.data
        await self.__wait_connected()
        await self.__websocket.send(json.dumps(to_send_dict))

    async def __reply_error(self, e: APIException, msg_id):
        """Report a failed local call, unless the call was a notification."""
        if msg_id is None:
            # JSON-RPC 2.0: the server MUST NOT reply to a notification.
            logger.warning(f"[JSONRPC] error in notification handler: {e}")
            return
        await self.__send_error(e, msg_id)

    @staticmethod
    def __unwrap_reply(reply: dict):
        """Return the ``result`` of a response dict, or raise its ``error``."""
        if "error" in reply:
            error = reply["error"]
            try:
                code, message = error["code"], error["message"]
                data = error.get("data", None)
            except (KeyError, TypeError) as exc:
                # Malformed error object: surface it as an internal error.
                raise InternalErrorException(traceback.format_exc()) from exc
            raise APIException(code, message, data)
        if "result" in reply:
            return reply["result"]
        raise InternalErrorException(json.dumps(reply))

    async def __call_remote(self, request: dict):
        """Send *request* with a fresh id and return the peer's result."""
        msg_id = self.__generate_seq_id()
        request["id"] = msg_id
        # Register the reply future BEFORE sending: send() may yield while
        # draining, and a reply dispatched meanwhile would otherwise be
        # dropped as unmatched.
        future = asyncio.get_running_loop().create_future()
        self.__remote_api_future[msg_id] = future
        try:
            await self.__wait_connected()
            await self.__websocket.send(json.dumps(request))
            reply = await asyncio.wait_for(future, self.__request_timeout)
        finally:
            self.__remote_api_future.pop(msg_id, None)
        return self.__unwrap_reply(reply)

    @staticmethod
    def __bind_params(func, params):
        """Turn JSON-RPC ``params`` into ``(args, kwargs)`` for *func*.

        Raises:
            InvalidParamsException: params is not null/array/object, or does
                not match func's signature.
        """
        if params is None:
            args, kwargs = (), {}
        elif isinstance(params, list):
            args, kwargs = tuple(params), {}
        elif isinstance(params, dict):
            args, kwargs = (), params
        else:
            raise InvalidParamsException()
        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):
            # Not introspectable (e.g. some builtins): let the call decide.
            return args, kwargs
        # Validate against the signature up front, so a TypeError raised
        # *inside* the handler is reported as an internal error, not as
        # invalid params.
        try:
            signature.bind(*args, **kwargs)
        except TypeError as exc:
            raise InvalidParamsException() from exc
        return args, kwargs

    async def __local_api_call(self, method: str, params: Optional[Union[list, dict]], msg_id: Optional[int]):
        """Run a local API for the peer and send back its result or error."""
        try:
            func = self.__local_api_bind.get(method, None)
            if func is None:
                raise MethodNotFoundException()
            args, kwargs = self.__bind_params(func, params)
            result = func(*args, **kwargs)
            # Handles plain functions, async functions and any callable that
            # returns an awaitable.
            if inspect.isawaitable(result):
                result = await result
            if msg_id is None:
                return  # notification: no response is sent
            await self.__wait_connected()
            await self.__websocket.send(json.dumps({
                "jsonrpc": "2.0",
                "result": result,
                "id": msg_id
            }))
        except APIException as e:
            await self.__reply_error(e, msg_id)
        except Exception:
            trace_info = traceback.format_exc()
            logger.error(trace_info)
            await self.__reply_error(InternalErrorException(trace_info), msg_id)

    def __resolve_remote_reply(self, msg_id, data: dict, message):
        """Hand a response to the remote call waiting on *msg_id*, if any."""
        if msg_id is None:
            logger.warning("[JSONRPC] received malform data {}", message)
            return
        try:
            future = self.__remote_api_future.get(msg_id, None)
        except TypeError:  # unhashable id (e.g. a list) cannot be one of ours
            future = None
        if future is None:
            logger.warning("[JSONRPC] received malform data {}", message)
        elif not future.done():
            # done() guards against duplicate replies: set_result() on a done
            # future raises InvalidStateError and would kill the receive loop.
            future.set_result(data)

    async def __handle_message(self, message):
        """Dispatch one incoming frame: a request/notification or a response."""
        try:
            data = json.loads(message)
        except ValueError:  # JSONDecodeError, or bytes that are not valid UTF-8
            await self.__send_error(ParseErrorException(), None)
            return
        if not isinstance(data, dict):
            # Arrays (batches) and bare scalars are not supported requests.
            await self.__send_error(InvalidRequestException(), None)
            return
        msg_id = data.get("id", None)
        if "method" in data:
            method = data["method"]
            if not isinstance(method, str):
                await self.__send_error(InvalidRequestException(), msg_id)
                return
            self.__spawn(self.__local_api_call(method, data.get("params", None), msg_id))
        elif "result" in data or "error" in data:
            self.__resolve_remote_reply(msg_id, data, message)
        else:
            logger.warning("[JSONRPC] received malform data {}", message)

    async def __main(self):
        """Connect (and reconnect forever), dispatching every incoming message."""
        loop = asyncio.get_running_loop()
        # Created here so the future belongs to the loop that actually runs us.
        self.__websocket_connecting_future = loop.create_future()
        logger.info("[JSONRPC] Connecting to server " + self.__uri)
        async for websocket in ws_client.connect(self.__uri):
            logger.info("[JSONRPC] Connected to server " + self.__uri)
            self.__websocket = websocket
            if self.__websocket_connecting_future.done():
                self.__websocket_connecting_future = loop.create_future()
            self.__websocket_connecting_future.set_result(None)
            if self.__connected_hook:
                self.__spawn(self.__connected_hook())
            try:
                async for message in websocket:
                    await self.__handle_message(message)
            except ConnectionClosed:
                pass
            # Reached after an abnormal close (ConnectionClosed) and after a
            # clean close, where the message iterator simply ends. Either way,
            # senders must now wait for the next connection.
            # NOTE(review): calls still awaiting replies from the old
            # connection are not failed here; they hit request_timeout.
            logger.warning(f"[JSONRPC] Connection closed, reconnect in {self.__reconnect_delay}s...")
            self.__websocket_connecting_future = loop.create_future()
            await asyncio.sleep(self.__reconnect_delay)

    # ------------------------------------------------------------- public API

    def remote_api(self, notice=False):
        """Decorator declaring an API implemented by the remote peer.

        The decorated function's body is never executed (``pass`` is enough);
        only its name and declared default values are used. Call the resulting
        async stub with EITHER positional arguments (sent as a JSON array) OR
        keyword arguments (sent as a JSON object, merged over the declared
        defaults), never both.

        Args:
            notice (bool, optional): send a notification: no id, no reply is
                awaited, and the stub returns None. Defaults to False.
        """
        def decorator(func):
            method = func.__name__
            defaults = get_default_args(func)

            @wraps(func)
            async def wrapper(*args, **kwargs):
                if args and kwargs:
                    raise ValueError("You should only use positional arguments or keyword arguments, not mixed!")
                to_send_dict = {
                    "jsonrpc": "2.0",
                    "method": method
                }
                if args:
                    to_send_dict["params"] = args
                else:
                    # Defaults are merged only in the by-name form; merging
                    # them into a positional call would make it look "mixed".
                    named = {**defaults, **kwargs}
                    if named:
                        to_send_dict["params"] = named
                if notice:
                    await self.__wait_connected()
                    await self.__websocket.send(json.dumps(to_send_dict))
                    return None
                return await self.__call_remote(to_send_dict)

            return wrapper

        return decorator

    def local_api(self, func: Callable):
        """Decorator exposing *func* to the peer under ``func.__name__``.

        *func* may be a plain or an async function; it is returned unchanged.
        """
        self.__local_api_bind[func.__name__] = func
        return func

    def on_connected(self, async_func):
        """Decorator registering a coroutine function started after every (re)connection."""
        self.__connected_hook = async_func
        return async_func

    def run_forever(self):
        """Run the client on a fresh event loop; blocks while it keeps reconnecting."""
        # asyncio.run() creates, runs and closes its own loop. The connection
        # future is created inside __main(), so it belongs to that loop.
        asyncio.run(self.__main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment