Skip to content

Instantly share code, notes, and snippets.

@linepro6
Last active December 22, 2023 07:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save linepro6/f51ac8930882ce8200f8a0ae795c214e to your computer and use it in GitHub Desktop.
Save linepro6/f51ac8930882ce8200f8a0ae795c214e to your computer and use it in GitHub Desktop.
Python JSON-RPC 2.0 client over WebSockets
class APIException(Exception):
    """Base error carrying a JSON-RPC error object.

    Attributes:
        code: integer JSON-RPC error code.
        message: short human-readable description of the error.
        data: optional extra information; ``None`` when not supplied.
    """

    def __init__(self, code: int, message: str, data: dict = None):
        # super().__init__ is not called: forwarding these fields to it would
        # make subclass instances (whose constructors take no/other arguments)
        # fail to copy or unpickle.
        self.code, self.message, self.data = code, message, data

    def __str__(self):
        # Rendered as "<message> (<code>)".
        return "{} ({})".format(self.message, self.code)
class ParseErrorException(APIException):
    """JSON-RPC error -32700: the received text could not be parsed as JSON."""

    def __init__(self):
        super().__init__(code=-32700, message="Parse Error")
class InvalidRequestException(APIException):
    """JSON-RPC error -32600: the message is not a valid Request object."""

    def __init__(self):
        super().__init__(code=-32600, message="Invalid Request")
class MethodNotFoundException(APIException):
    """JSON-RPC error -32601: no local API is registered under the requested name."""

    def __init__(self):
        super().__init__(code=-32601, message="Method Not Found")
class InvalidParamsException(APIException):
    """JSON-RPC error -32602: the supplied params do not fit the method."""

    def __init__(self):
        super().__init__(code=-32602, message="Invalid Params")
class InternalErrorException(APIException):
    """JSON-RPC error -32603, with the formatted traceback attached.

    Args:
        traceback: formatted traceback text, sent as ``data["traceback"]``.
    """

    def __init__(self, traceback: str):
        details = {"traceback": traceback}
        super().__init__(-32603, "Internal Error", data=details)
import asyncio
import websockets.client as ws_client
import websockets
from websockets.exceptions import ConnectionClosed
from loguru import logger
from typing import Callable, Union, Optional
import json
import traceback
import inspect
from functools import wraps
from .exceptions import *
def get_default_args(func):
    """Collect the default values declared in *func*'s signature.

    Returns:
        dict: ``{parameter_name: default}`` in signature order. Only
        parameters that declare a default are included, whether positional
        or keyword-only.
    """
    defaults = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.default is inspect.Parameter.empty:
            continue
        defaults[name] = param.default
    return defaults
class RaiseAPIException:
    """Context manager that translates selected exceptions into an API error.

    Any exception raised inside the ``with`` body whose type is one of
    *captures*, or a subclass of one, is replaced by *api_exception*. The
    original exception is chained as ``__cause__``. Other exceptions
    propagate unchanged, and a clean exit does nothing.

    Example::

        with RaiseAPIException(ParseErrorException, json.JSONDecodeError):
            data = json.loads(message)
    """

    def __init__(self, api_exception: "type[APIException]", *captures: "type[BaseException]"):
        # api_exception: exception class (or instance) raised in place of a
        #   captured error.
        # captures: exception types to translate; an empty tuple captures nothing.
        self.captures = captures
        self.api_exception = api_exception

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # issubclass (not an exact-type check) so subclasses of a captured
        # type are translated too, like an ``except`` clause would.
        if exc_type is not None and issubclass(exc_type, self.captures):
            raise self.api_exception from exc_val
        # Returning False lets any other exception propagate untouched.
        return False
class JsonRPCAPIClient:
    """JSON-RPC 2.0 peer that talks to a server over a single WebSocket.

    Traffic is bidirectional:

    * coroutine functions produced by :meth:`remote_api` send
      requests/notifications to the server;
    * functions registered with :meth:`local_api` can be called by the server.

    :meth:`run_forever` connects. Whenever the connection closes, it waits
    5 seconds and reconnects.

    Only single JSON objects are handled; batch requests (JSON arrays) are
    answered with "Invalid Request". Outgoing request ids are integers in
    ``[0, 9999999)``.
    """

    def __init__(self, uri: str):
        self.__uri = uri
        self.__seq = 0
        # method name -> local callable exposed to the peer
        self.__local_api_bind = {}
        # outgoing request id -> future resolved with the peer's reply dict
        self.__remote_api_future = {}
        # current websockets client connection (None until connected)
        self.__websocket = None
        # Done while a connection is up. Replaced by a pending future when it
        # drops, so senders wait for the reconnect. Created by __main inside
        # the running event loop.
        self.__websocket_connecting_future = None
        self.__connected_hook = None
        # Strong references to fire-and-forget tasks. asyncio keeps only weak
        # references, so an unreferenced task may vanish mid-execution.
        self.__background_tasks = set()

    def __spawn(self, coro):
        """Schedule *coro* as a background task and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self.__background_tasks.add(task)
        task.add_done_callback(self.__background_tasks.discard)
        return task

    async def __fetch_remote_reply(self, future, timeout=60):
        """Wait for the reply that resolves *future* and unwrap it.

        Returns:
            The ``result`` member of the reply.

        Raises:
            asyncio.TimeoutError: no reply arrived within *timeout* seconds.
            APIException: the peer answered with an ``error`` object.
            InternalErrorException: the reply is malformed.
        """
        reply = await asyncio.wait_for(future, timeout)
        if "error" in reply:
            error = reply["error"]
            try:
                code, message = error["code"], error["message"]
            except (KeyError, TypeError):
                # error object missing fields or not an object at all
                raise InternalErrorException(traceback.format_exc())
            raise APIException(code, message, error.get("data", None))
        if "result" in reply:
            return reply["result"]
        raise InternalErrorException(json.dumps(reply))

    def __generate_seq_id(self):
        """Return the next outgoing request id, wrapping back to 0 at 9999999."""
        s = self.__seq
        self.__seq = (self.__seq + 1) % 9999999
        return s

    async def __send_error(self, e: "APIException", msg_id=None):
        """Send *e* as a JSON-RPC error response for request *msg_id*.

        *msg_id* is sent as null when it is None.
        """
        error = {"code": e.code, "message": e.message}
        if e.data:
            error["data"] = e.data
        await self.__websocket_connecting_future
        await self.__websocket.send(json.dumps({"jsonrpc": "2.0", "error": error, "id": msg_id}))

    @staticmethod
    def __bind_params(func, params):
        """Turn JSON-RPC *params* into ``(args, kwargs)`` that fit *func*.

        Raises:
            InvalidParamsException: *params* is not null, an array or an
                object, or it does not match *func*'s signature.
        """
        if params is None:
            args, kwargs = (), {}
        elif isinstance(params, list):
            args, kwargs = tuple(params), {}
        elif isinstance(params, dict):
            args, kwargs = (), params
        else:
            raise InvalidParamsException()
        # Validate against the signature *before* calling, so a TypeError
        # raised inside the function body is not misreported as bad params.
        try:
            inspect.signature(func).bind(*args, **kwargs)
        except TypeError as exc:
            raise InvalidParamsException() from exc
        except ValueError:
            # No introspectable signature (e.g. some builtins): let the call decide.
            pass
        return args, kwargs

    async def __local_api_call(self, method: str, params: Optional[Union[list, dict]],
                               msg_id: Optional[Union[int, float, str]]):
        """Run local API *method* and send the result or error back.

        Nothing is sent when *msg_id* is None (a notification). Failures are
        still logged.
        """
        try:
            func = self.__local_api_bind.get(method, None)
            if not func:
                raise MethodNotFoundException()
            args, kwargs = self.__bind_params(func, params)
            if asyncio.iscoroutinefunction(func):
                result = await func(*args, **kwargs)
            else:
                result = func(*args, **kwargs)
            if msg_id is None:
                return  # notification: the peer expects no response
            # Serialise here so an unserialisable result becomes an Internal Error reply.
            payload = json.dumps({"jsonrpc": "2.0", "result": result, "id": msg_id})
        except APIException as e:
            error = e
        except Exception:
            trace_info = traceback.format_exc()
            logger.error(trace_info)
            error = InternalErrorException(trace_info)
        else:
            error = None
        if msg_id is None:
            return  # notifications never get a response, not even an error
        try:
            if error is None:
                await self.__websocket_connecting_future
                await self.__websocket.send(payload)
            else:
                await self.__send_error(error, msg_id)
        except ConnectionClosed:
            logger.warning("[JSONRPC] Connection closed before replying to {}", method)

    async def __handle_message(self, message):
        """Dispatch one incoming frame.

        A request or notification is handed to a local API in a background
        task. A reply resolves the matching pending remote call.
        """
        msg_id = None
        try:
            # json.loads on bytes raises UnicodeDecodeError for invalid UTF-8.
            with RaiseAPIException(ParseErrorException, json.JSONDecodeError, UnicodeDecodeError):
                data = json.loads(message)
            if not isinstance(data, dict):
                # Batches (arrays) and bare scalars are not supported.
                raise InvalidRequestException()
            msg_id = data.get("id", None)
            if "method" in data:
                if msg_id is not None and not isinstance(msg_id, (str, int, float)):
                    # JSON-RPC ids must be a string, number or null; don't echo anything else.
                    msg_id = None
                    raise InvalidRequestException()
                method = data["method"]
                if not isinstance(method, str):
                    raise InvalidRequestException()
                self.__spawn(self.__local_api_call(method, data.get("params", None), msg_id))
            elif "result" in data or "error" in data:
                future = None
                if isinstance(msg_id, (str, int, float)):
                    future = self.__remote_api_future.get(msg_id)
                # A done future means a duplicate reply or one for a timed-out call.
                if future is not None and not future.done():
                    future.set_result(data)
                else:
                    logger.warning("[JSONRPC] received malform data {}", message)
            else:
                logger.warning("[JSONRPC] received malform data {}", message)
        except APIException as e:
            await self.__send_error(e, msg_id)

    async def __main(self):
        """Connect, dispatch incoming messages, and reconnect forever."""
        loop = asyncio.get_running_loop()
        # Create the future on the loop that will await it; senders block on
        # it until the first connection is up.
        self.__websocket_connecting_future = loop.create_future()
        logger.info("[JSONRPC] Connecting to server " + self.__uri)
        async for websocket in ws_client.connect(self.__uri):
            logger.info("[JSONRPC] Connected to server " + self.__uri)
            self.__websocket = websocket
            if not self.__websocket_connecting_future.done():
                self.__websocket_connecting_future.set_result(None)
            if self.__connected_hook:
                self.__spawn(self.__connected_hook())
            try:
                async for message in websocket:
                    await self.__handle_message(message)
            except ConnectionClosed:
                pass
            # Reached both when ConnectionClosed was raised and when websockets
            # ended the message iterator on a normal close. Either way, make
            # senders wait for the next connection.
            self.__websocket_connecting_future = loop.create_future()
            logger.warning("[JSONRPC] Connection closed, reconnect in 5s...")
            await asyncio.sleep(5)

    def remote_api(self, notice=False):
        """Decorator declaring an API implemented by the remote peer.

        The decorated function's body is never executed. Only its name (used
        as the JSON-RPC method) and its default argument values matter. The
        result is a coroutine function that sends the call. Unless *notice*
        is true, it then waits up to 60 s for the reply and returns its
        ``result``.

        Call it with either positional arguments or keyword arguments, not
        both:

        * positional arguments are sent as a JSON array;
        * keyword arguments are sent as a JSON object, merged over the
          function's defaults.

        Args:
            notice (bool, optional): send a notification (no ``id``, no reply
                awaited). Defaults to False.

        The returned coroutine raises ValueError if positional and keyword
        arguments are mixed.
        """
        def decorator(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                if args and kwargs:
                    raise ValueError("You should only use positional arguments or keyword arguments, not mixed!")
                to_send_dict = {"jsonrpc": "2.0", "method": func.__name__}
                if args:
                    # Defaults are only merged for keyword calls; mixing them into
                    # a positional call would always trip the check above.
                    to_send_dict["params"] = list(args)
                else:
                    named = get_default_args(func)
                    named.update(kwargs)
                    if named:
                        to_send_dict["params"] = named
                if notice:
                    await self.__websocket_connecting_future
                    await self.__websocket.send(json.dumps(to_send_dict))
                    return None
                seq_id = self.__generate_seq_id()
                to_send_dict["id"] = seq_id
                # Register the reply future before sending, so a fast reply can
                # never arrive while nobody is waiting for it.
                future = asyncio.get_running_loop().create_future()
                self.__remote_api_future[seq_id] = future
                try:
                    await self.__websocket_connecting_future
                    await self.__websocket.send(json.dumps(to_send_dict))
                    return await self.__fetch_remote_reply(future)
                finally:
                    self.__remote_api_future.pop(seq_id, None)
            return wrapper
        return decorator

    def local_api(self, func: Callable):
        """Decorator exposing *func* to the peer under its ``__name__``.

        A later registration with the same name replaces the earlier one.
        Returns *func* unchanged.
        """
        self.__local_api_bind[func.__name__] = func
        return func

    def on_connected(self, async_func):
        """Decorator registering a coroutine function to run after each (re)connect.

        Only the most recently registered hook is kept. Returns *async_func*.
        """
        self.__connected_hook = async_func
        return async_func

    def run_forever(self):
        """Connect and serve forever on a fresh event loop; blocks the caller."""
        asyncio.run(self.__main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment