Skip to content

Instantly share code, notes, and snippets.

@dankrause
Last active March 10, 2023 18:09
import builtins
import collections
import contextlib
import copy
import dis
import json
import socket
import struct
import types
# Reverse of dis.COMPILER_FLAG_NAMES: flag name (e.g. "VARARGS") -> flag bit.
_COMPILER_FLAG_MAP = dict(
    (flag_name, flag_bit) for flag_bit, flag_name in dis.COMPILER_FLAG_NAMES.items()
)

# Every message on the wire is prefixed with its length packed as "!L".
_HEADER_STRUCT = struct.Struct("!L")

# Lower-case builtin classes (int, str, dict, ...) keyed by name, used to turn
# annotation names received from the server back into real types.
_ANNOTATION_MAP = {
    attr_name: attr_value
    for attr_name, attr_value in ((n, getattr(builtins, n)) for n in dir(builtins))
    if type(attr_value) is type and attr_name[0].islower()
}
def _get_annotation_type(type_name):
    """Map a builtin type name (e.g. ``"int"``) to the type itself.

    Names that are not lower-case builtin classes are returned unchanged.
    """
    return _ANNOTATION_MAP.get(type_name, type_name)
class RPCClientException(Exception):
    """Raised when the server reports an error or sends an unusable reply."""
class RPCClientNamespace:
    """Client-side proxy for one namespace exposed by an RPC server.

    On construction the namespace manifest (fetched from the server when not
    supplied) is turned into attributes. ``"func"`` entries become stub
    functions that forward calls to the server. ``"manifest"`` entries become
    nested namespaces.
    """

    # Order in which CPython expects parameters in co_varnames: positional
    # (positional-only first, by signature order), keyword-only, *args, **kwargs.
    _PARAM_KIND_ORDER = {
        "POSITIONAL_ONLY": 0,
        "POSITIONAL_OR_KEYWORD": 0,
        "KEYWORD_ONLY": 1,
        "VAR_POSITIONAL": 2,
        "VAR_KEYWORD": 3,
    }

    def __init__(self, sock, manifest=None, name=None, func=None):
        """Build the namespace.

        :param sock: connected socket used for requests.
        :param manifest: this namespace's manifest; requested from the server
            when omitted or empty.
        :param name: path of this namespace (list of names) below its parent.
        :param func: path of a server-side object (e.g. ``[class, id]``) that
            is prefixed to every call made through this namespace.
        """
        self._name = name or []
        self._sock = sock
        self._manifest = manifest or self._send_raw_message("manifest")
        self._func = func
        for item_name, value in self._manifest.items():
            item_type = value.get("_meta", {}).get("type", None)
            if item_name == "_meta" or item_type is None:
                continue
            if item_type == "func":
                setattr(self, item_name, self._deserialize_function(value))
            elif item_type == "manifest":
                namespace = self._build_nested_namespace(sock, value, [*self._name, item_name])
                setattr(self, item_name, namespace)

    @staticmethod
    def _recv_exactly(sock, size):
        """Read exactly *size* bytes from *sock*.

        A single ``recv`` may return fewer bytes than requested, so this loops.

        :raises ConnectionError: if the peer closes the connection first.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                raise ConnectionError(
                    f"connection closed with {remaining} of {size} bytes still expected"
                )
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _send_to_socket(self, msg):
        """Send *msg* as one length-prefixed UTF-8 frame and return the reply text."""
        # The header must carry the encoded byte length, not the character
        # count; they differ as soon as *msg* contains non-ASCII text.
        payload = msg.encode("utf-8")
        self._sock.sendall(_HEADER_STRUCT.pack(len(payload)) + payload)
        header = self._recv_exactly(self._sock, _HEADER_STRUCT.size)
        (msg_len,) = _HEADER_STRUCT.unpack(header)
        return self._recv_exactly(self._sock, msg_len).decode("utf-8")

    def _send_raw_message(self, msg):
        """Send *msg* (JSON-serializable) and unpack the server's reply.

        :returns: the ``"response"`` value. If the reply is a ``"manifest"``
            (an object created on the server), returns a new namespace for it.
        :raises RPCClientException: with the server's ``"error"`` payload, or
            for a reply that is not valid JSON.
        """
        raw_response = self._send_to_socket(json.dumps(msg))
        try:
            response = json.loads(raw_response)
        except json.JSONDecodeError:
            response = {"error": {"msg": "Response is not valid json", "invalid_response": raw_response}}
        if "response" in response:
            return response["response"]
        if "manifest" in response:
            return RPCClientNamespace(self._sock, response["manifest"], func=response["func"])
        raise RPCClientException(response.get("error", "unknown error"))

    def _dispatch(self, args_dict, func_def):
        """Forward a stub call to the server.

        :param args_dict: the stub's bound arguments (its ``locals()``).
        :param func_def: the serialized function description.
        """
        args = []
        kwargs = {}
        params = func_def.get("params", {})
        # Once *args is present, the preceding positional-or-keyword
        # parameters must also be passed positionally, or Python reports
        # "multiple values" for them on the server side.
        has_var_positional = any(p["kind"] == "VAR_POSITIONAL" for p in params.values())
        for param_name, param in params.items():
            if param_name not in args_dict:
                continue
            value = args_dict[param_name]
            kind = param["kind"]
            if kind == "POSITIONAL_ONLY" or (kind == "POSITIONAL_OR_KEYWORD" and has_var_positional):
                args.append(value)
            elif kind == "VAR_POSITIONAL":
                args.extend(value)
            elif kind == "VAR_KEYWORD":
                kwargs.update(value)
            else:
                kwargs[param_name] = value
        msg = {
            "func": [*self._name, func_def["name"]],
            "args": args,
            "kwargs": kwargs,
        }
        if self._func:
            msg["func"] = [*self._func, *msg["func"]]
        return self._send_raw_message(msg)

    def _build_nested_namespace(self, _sock, _manifest, _name):
        """Create the child namespace for *_manifest*.

        If the server marked the namespace as callable (``_meta._call``), the
        child is an instance of a fresh subclass whose ``__call__`` forwards
        to the server. The child inherits this namespace's ``_func`` prefix,
        so calls made from inside a server-side object stay routed to it.
        """
        if "_call" in _manifest["_meta"]:
            class CallableRPCClientNamespace(RPCClientNamespace):
                pass

            CallableRPCClientNamespace.__call__ = self._deserialize_function(
                _manifest["_meta"]["_call"], _has_self=True
            )
            return CallableRPCClientNamespace(_sock, _manifest, _name, func=self._func)
        return RPCClientNamespace(_sock, _manifest, _name, func=self._func)

    def _deserialize_function(self, _func_def, _has_self=False):
        """Build a local stub with the remote function's signature.

        The stub's code object is rewritten so that its parameters, defaults,
        annotations, name and docstring match *_func_def*. Calling the stub
        forwards the bound arguments to :meth:`_dispatch`.

        :param _func_def: dict produced by the server's ``serialize_function``.
        :param _has_self: prepend a ``self`` parameter so the stub can be used
            as a method (``__call__`` of a callable namespace).
        """
        _func_def = copy.deepcopy(_func_def)
        if _has_self:
            self_arg = {"self": {"kind": "POSITIONAL_OR_KEYWORD"}}
            _func_def["params"] = {**self_arg, **_func_def.get("params", {})}
        params = _func_def.get("params", {})
        varnames = tuple(sorted(params, key=lambda p: self._PARAM_KIND_ORDER[params[p]["kind"]]))
        arg_counts = collections.Counter(param["kind"] for param in params.values())
        # Positional defaults belong in __defaults__; keyword-only defaults in
        # __kwdefaults__. Mixing them would shift defaults onto the wrong names.
        defaults = tuple(
            param["default"]
            for param in params.values()
            if "default" in param and param["kind"] in ("POSITIONAL_ONLY", "POSITIONAL_OR_KEYWORD")
        ) or None
        kwdefaults = {
            param_name: param["default"]
            for param_name, param in params.items()
            if "default" in param and param["kind"] == "KEYWORD_ONLY"
        } or None
        annotations = {
            param_name: _get_annotation_type(param["annotation"])
            for param_name, param in params.items()
            if "annotation" in param
        }

        def wrapper():
            return self._dispatch(locals(), _func_def)

        code_obj = wrapper.__code__
        flags = code_obj.co_flags
        if "VAR_POSITIONAL" in arg_counts:
            flags |= _COMPILER_FLAG_MAP["VARARGS"]
        if "VAR_KEYWORD" in arg_counts:
            flags |= _COMPILER_FLAG_MAP["VARKEYWORDS"]
        # NOTE(review): this positional types.CodeType call (with co_lnotab)
        # matches the CPython 3.8-3.10 constructor. Later versions changed both
        # the constructor and the local-variable layout, so this stub-building
        # technique presumably needs reworking there — TODO confirm.
        # co_argcount counts ALL positional parameters, positional-only included.
        new_code_obj = types.CodeType(
            arg_counts["POSITIONAL_ONLY"] + arg_counts["POSITIONAL_OR_KEYWORD"],
            arg_counts["POSITIONAL_ONLY"],
            arg_counts["KEYWORD_ONLY"],
            code_obj.co_nlocals + len(params),
            code_obj.co_stacksize,
            flags,
            code_obj.co_code,
            code_obj.co_consts,
            code_obj.co_names,
            varnames,
            code_obj.co_filename,
            _func_def["name"],
            code_obj.co_firstlineno,
            code_obj.co_lnotab,
            code_obj.co_freevars,
        )
        wrapper.__name__ = _func_def["name"]
        wrapper.__code__ = new_code_obj
        wrapper.__defaults__ = defaults
        wrapper.__kwdefaults__ = kwdefaults
        wrapper.__annotations__ = annotations
        if "return" in _func_def:
            wrapper.__annotations__["return"] = _get_annotation_type(_func_def["return"])
        if "doc" in _func_def:
            wrapper.__doc__ = _func_def["doc"]
        if _has_self:
            # The stub's closure shares this dict, so dropping "self" keeps the
            # local receiver out of the arguments sent to the server.
            del _func_def["params"]["self"]
        return wrapper
class RPCClient:
    """Factory for client sessions against the RPC server at *address*."""

    def __init__(self, address):
        self._address = address

    def get_session(self):
        """Connect and return the root namespace of a new session.

        The socket is closed again if connecting or fetching the manifest
        fails, so a failed attempt does not leak it.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect(self._address)
            return RPCClientNamespace(sock)
        except BaseException:
            sock.close()
            raise

    @staticmethod
    def close_session(session):
        """Close the socket behind *session*."""
        session._sock.close()

    @contextlib.contextmanager
    def session(self):
        """Context manager yielding a session that is always closed on exit.

        This includes exits caused by an exception in the ``with`` body.
        """
        session = self.get_session()
        try:
            yield session
        finally:
            self.close_session(session)
import collections
import functools
import inspect
import json
import socketserver
import struct
import types
# Frame header shared with the client: payload length as a 4-byte
# network-order (big-endian) unsigned integer.
_HEADER_STRUCT = struct.Struct("!L")
class RPCServerException(Exception):
    """Raised for RPC-level failures such as unknown namespaces."""
def _annotation_to_json(annotation):
    """Render an annotation as a JSON-safe string.

    Classes are sent by ``__name__``. String annotations (e.g. produced by
    ``from __future__ import annotations``) are sent unchanged. Anything else
    (``None``, typing constructs, ...) is sent as its ``str()``.
    """
    if isinstance(annotation, type):
        return annotation.__name__
    if isinstance(annotation, str):
        return annotation
    return str(annotation)


def serialize_function(func, name=None):
    """Describe *func*'s call signature as a JSON-serializable dict.

    The result always has ``"_meta"`` (``{"type": "func"}``) and ``"name"``.
    It also has ``"doc"`` when *func* has a docstring, ``"return"`` when the
    return value is annotated, and ``"params"`` when there are parameters.
    ``"params"`` maps each parameter name to a dict with its ``"kind"`` (an
    ``inspect.Parameter`` kind name) and, when present, its ``"annotation"``
    and ``"default"``.

    :param func: the callable to describe (anything ``inspect.signature`` accepts).
    :param name: the name the callable is exposed under.
    :raises ValueError, TypeError: propagated from ``inspect.signature``.
    """
    response = {
        "_meta": {"type": "func"},
        "name": name,
    }
    if func.__doc__:
        response["doc"] = func.__doc__
    spec = inspect.signature(func)
    if spec.return_annotation is not inspect.Signature.empty:
        # Previously `.__name__` was read unconditionally, which crashed for
        # `-> None` and for string annotations.
        response["return"] = _annotation_to_json(spec.return_annotation)
    if spec.parameters:
        response["params"] = {}
        for param_name, param in spec.parameters.items():
            entry = {}
            if param.annotation is not inspect.Parameter.empty:
                entry["annotation"] = _annotation_to_json(param.annotation)
            if param.default is not inspect.Parameter.empty:
                default = param.default
                entry["default"] = default.__name__ if isinstance(default, type) else default
            entry["kind"] = param.kind.name
            response["params"][param_name] = entry
    return response
class RPCServerNamespace:
    """A named collection of RPC-callable objects and child namespaces.

    ``_registry`` maps exposed names to functions, bound methods, callable
    instances or child ``RPCServerNamespace`` objects. Its ``"_meta"`` entry
    describes this namespace itself. ``_obj_registry`` maps a client address
    to the namespaces of objects created for that client, keyed by ``id()``.
    ``wrapped_obj``, when set, is what calling the namespace itself invokes.
    """

    def __init__(self, name, wrapped_obj=None):
        self._registry = {"_meta": {"type": "manifest", "name": name}}
        self._obj_registry = collections.defaultdict(dict)
        self.wrapped_obj = wrapped_obj

    def _call_wrapped(self, args, kwargs, _client_address):
        """Call the wrapped object: instantiate a class, or call a callable.

        :raises RPCServerException: if nothing callable is wrapped.
        """
        if isinstance(self.wrapped_obj, type):
            return self._instance_init(self.wrapped_obj, args, kwargs, _client_address)
        if callable(self.wrapped_obj):
            return self.wrapped_obj(*args, **kwargs)
        raise RPCServerException("Namespace not callable")

    def _instance_init(self, func_obj, args, kwargs, _client_address):
        """Instantiate *func_obj* and wrap the instance in a new namespace.

        Each attribute of the instance that has a ``__name__`` is exposed under
        its ``__rpcname__`` (falling back to ``__name__``). Attributes are
        skipped if that name starts with ``_`` or the function was marked with
        :meth:`ignore`. A ``_delete`` function is added so the client can drop
        the server's reference to the object.
        """
        obj = func_obj(*args, **kwargs)
        instance_namespace = RPCServerNamespace(id(obj), wrapped_obj=obj)
        for attr_name in dir(obj):
            attr = getattr(obj, attr_name)
            if attr is type(obj):
                continue  # __class__
            func = getattr(attr, "__func__", attr)
            if not hasattr(func, "__name__"):
                continue
            rpcname = getattr(func, "__rpcname__", func.__name__)
            if rpcname.startswith("_") or hasattr(func, "__rpcignore__"):
                continue
            instance_namespace.register(rpcname, func=attr)

        @instance_namespace.register()
        def _delete():
            # Remove the entry `call` stored for this client, keyed by id().
            del self._obj_registry[_client_address][id(instance_namespace)]
            return True

        return instance_namespace

    def _build_class_namespace(self, cls):
        """Wrap *cls* in a callable namespace exposing its class/static methods."""
        class_namespace = RPCServerNamespace(cls.__name__, wrapped_obj=cls)
        for attr_name, member in cls.__dict__.items():
            if isinstance(member, (classmethod, staticmethod)):
                class_namespace.register(func=getattr(cls, attr_name))
        return class_namespace

    def ignore(self, obj):
        """Decorator: mark *obj* so it is not exposed on created instances."""
        setattr(getattr(obj, "__func__", obj), "__rpcignore__", True)
        return obj

    def register(self, name=None, namespace=None, func=None):
        """Expose *func* under *name* (default: its own name) in *namespace*.

        Without *func* this returns a decorator, so it can be used as
        ``@ns.register()``, ``@ns.register("alias")`` or
        ``@ns.register(namespace="child")``. *namespace* is either a single
        child-namespace name or a sequence of names forming a path. Missing
        namespaces are created. Classes are exposed as callable namespaces;
        staticmethods are exposed through their underlying function.

        :returns: *func* unchanged (or the decorator when *func* is None).
        :raises ValueError: if *name* does not start with a letter.
        :raises TypeError: for built-in functions and raw classmethod objects.
        """
        if name is not None and not name[:1].isalpha():
            raise ValueError("Namespaces must start with a letter")
        if func is None:
            return functools.partial(self.register, name, namespace)
        if isinstance(namespace, str):
            # A bare string names ONE child namespace. Indexing it directly
            # would walk it character by character ("abc" -> a/b/c).
            namespace = [namespace] if namespace else None
        if namespace is not None:
            target, remaining_path = self._get_next_namespace(namespace, create=True)
            if target is not self:
                return target.register(name, remaining_path, func)
        if isinstance(func, types.BuiltinFunctionType):
            raise TypeError("Built-in functions not supported")
        if isinstance(func, classmethod):
            raise TypeError("Cannot register classmethods directly")
        if isinstance(func, staticmethod):
            func.__func__.__rpcname__ = name or func.__func__.__name__
            self._registry[func.__func__.__rpcname__] = func.__func__
        elif isinstance(func, type):
            # isinstance (not `type(func) is type`) so classes with a custom
            # metaclass are also exposed as class namespaces.
            func.__rpcname__ = name or func.__name__
            self._registry[func.__rpcname__] = self._build_class_namespace(func)
        elif hasattr(func, "__qualname__"):
            if hasattr(func, "__self__"):
                # Bound method: tag the underlying function, keep the binding.
                func.__func__.__rpcname__ = name or func.__func__.__name__
                self._registry[func.__func__.__rpcname__] = func
            else:
                # Plain function.
                func.__rpcname__ = name or func.__name__
                self._registry[func.__rpcname__] = func
        elif callable(func) and name is not None:
            # Callable instance: it has no __name__ to fall back on.
            func.__rpcname__ = name
            self._registry[name] = func
        return func

    def register_instance(self, instance, name, methods=None):
        """Expose methods of *instance* in a new child namespace *name*.

        *methods* may be a list of bound methods (registered under their own
        names) or a dict mapping exposed names to callables. If it is empty or
        None, every bound method whose name starts with a letter is exposed.
        If *instance* is callable, calling the namespace calls it.

        :returns: the new namespace.
        """
        wrapped_obj = instance if callable(instance) else None
        instance_namespace = self.add_namespace(name, wrapped_obj)
        if not methods:
            methods = []
            for attr_name in dir(instance):
                attr = getattr(instance, attr_name)
                if isinstance(attr, types.MethodType) and attr_name[0].isalpha():
                    methods.append(attr)
        if hasattr(methods, "items"):
            for exposed_name, method in methods.items():
                instance_namespace.register(exposed_name, func=method)
        else:
            for method in methods:
                instance_namespace.register(func=method)
        return instance_namespace

    def _get_next_namespace(self, path, create=False, client_address=None):
        """Resolve the first element of *path* relative to this namespace.

        Returns ``(namespace, rest)``:

        * ``(self, None)`` for an empty path;
        * ``(child, remaining_path_or_None)`` when the head is the ``int`` id
          of an object created for *client_address*, or the name of a child
          namespace (created first when *create* is true and it is missing);
        * ``(self, head)`` when *path* is a single name that is not a child
          namespace, i.e. the name of something to call here.

        :raises RPCServerException: if the head cannot be resolved.
        """
        if path is None or not len(path):
            return self, None
        head, rest = path[0], path[1:] or None
        if isinstance(head, int):
            # .get() so a lookup never creates an empty per-client entry.
            client_objects = self._obj_registry.get(client_address, {})
            if head in client_objects:
                return client_objects[head], rest
        elif head in self._registry:
            if isinstance(self._registry[head], RPCServerNamespace):
                return self._registry[head], rest
            if len(path) == 1:
                return self, head
        elif create:
            namespace = RPCServerNamespace(head)
            self._registry[head] = namespace
            return namespace, rest
        elif len(path) == 1:
            return self, head
        raise RPCServerException(f"Invalid namespace: {head}")

    def call(self, func, args, kwargs, _client_address):
        """Invoke the callable at path *func* with *args* and *kwargs*.

        ``func=None`` calls this namespace itself. Namespaces produced by the
        call (e.g. class instances) are remembered for *_client_address*, so
        later paths can address them by ``id()``.

        :returns: the call's result, possibly an ``RPCServerNamespace``.
        """
        if func is not None:
            namespace, func = self._get_next_namespace(func, client_address=_client_address)
            if namespace is not self:
                return namespace.call(func, args, kwargs, _client_address)
        func_obj = self if func is None else self._registry[func]
        if isinstance(func_obj, type):
            result = self._instance_init(func_obj, args, kwargs, _client_address)
        elif isinstance(func_obj, RPCServerNamespace):
            result = func_obj._call_wrapped(args, kwargs, _client_address)
        else:
            result = func_obj(*args, **kwargs)
        if isinstance(result, RPCServerNamespace):
            self._obj_registry[_client_address][id(result)] = result
        return result

    def add_namespace(self, name, wrapped_obj=None):
        """Create, register and return a child namespace called *name*."""
        new_ns = RPCServerNamespace(name, wrapped_obj)
        self._registry[name] = new_ns
        return new_ns

    def remove_namespace(self, name):
        """Remove the registry entry *name*.

        :raises RPCServerException: if *name* is not registered.
        """
        if name not in self._registry:
            raise RPCServerException("Invalid namespace name")
        del self._registry[name]

    def clear_objects(self, client_address):
        """Forget every object created for *client_address*, recursively."""
        self._obj_registry.pop(client_address, None)
        for entry in self._registry.values():
            if isinstance(entry, RPCServerNamespace):
                entry.clear_objects(client_address)

    @property
    def manifest(self):
        """JSON-serializable description of this namespace for clients.

        The manifest contains ``"_meta"``: a copy of this namespace's
        metadata, plus a ``"_call"`` signature when the wrapped object is
        callable. It also has one ``serialize_function`` entry per registered
        callable and one nested manifest per child namespace.
        """
        # Copy so building a manifest never mutates the shared registry dict.
        meta = dict(self._registry["_meta"])
        if callable(self.wrapped_obj):
            meta["_call"] = serialize_function(self.wrapped_obj, meta["name"])
        manifest = {"_meta": meta}
        manifest.update(
            (key, serialize_function(val, key))
            for key, val in self._registry.items()
            if callable(val)
        )
        manifest.update(
            (key, val.manifest)
            for key, val in self._registry.items()
            if isinstance(val, RPCServerNamespace)
        )
        return manifest
class RPCServer(socketserver.ThreadingTCPServer, RPCServerNamespace):
    """Threaded TCP server whose root namespace is exposed over RPC.

    Each frame, in both directions, is a ``_HEADER_STRUCT`` length header
    (4-byte big-endian unsigned) followed by that many bytes of UTF-8 JSON.
    """

    def __init__(self, address):
        RPCServerNamespace.__init__(self, "_root")

        class _RPCServerHandler(socketserver.StreamRequestHandler):
            def handle(handler_self):
                try:
                    while True:
                        # rfile.read(n) blocks until n bytes or EOF, unlike a
                        # single recv(), which may return a partial frame.
                        header = handler_self.rfile.read(_HEADER_STRUCT.size)
                        if len(header) < _HEADER_STRUCT.size:
                            break  # clean EOF, or peer vanished mid-header
                        (msg_len,) = _HEADER_STRUCT.unpack(header)
                        raw_msg = handler_self.rfile.read(msg_len)
                        if len(raw_msg) < msg_len:
                            break  # peer vanished mid-message
                        response_json = self._handle_message(raw_msg, handler_self.client_address)
                        handler_self.request.sendall(
                            _HEADER_STRUCT.pack(len(response_json)) + response_json
                        )
                finally:
                    # Release this client's objects even if the connection
                    # broke with an exception.
                    self.clear_objects(handler_self.client_address)

        socketserver.ThreadingTCPServer.__init__(self, address, _RPCServerHandler)

    def _handle_message(self, raw_msg, client_address):
        """Process one request frame and return the encoded JSON reply.

        The request is either the JSON string ``"manifest"`` or an object with
        ``func``, ``args`` and ``kwargs``. The reply is ``{"response": ...}``,
        or ``{"manifest": ..., "func": path}`` when the call produced a new
        namespace, or ``{"error": message}`` on any exception.
        """
        try:
            msg = json.loads(str(raw_msg, "utf-8"))
            if msg == "manifest":
                raw_response = self.manifest
            else:
                raw_response = self.call(**msg, _client_address=client_address)
            if isinstance(raw_response, RPCServerNamespace):
                reply = {"manifest": raw_response.manifest, "func": [*msg["func"], id(raw_response)]}
            else:
                reply = {"response": raw_response}
            return bytes(json.dumps(reply), "utf-8")
        except Exception as e:
            return bytes(json.dumps({"error": str(e)}), "utf-8")
from server import RPCServer, RPCServerNamespace
from client import RPCClientNamespace, RPCClientException
# Test fixtures. In this script the server is never started with
# serve_forever(); requests are routed straight into its message handler.
address = ("127.0.0.1", 4512)
server = RPCServer(address=address)
dumb_store = dict()
ns, broken_ns = server.add_namespace("x"), server.add_namespace("y")
@server.register()
def store(name: str, value: str) -> bool:
    # Save value under name in the in-memory store; always reports success.
    dumb_store.update({name: value})
    return True
@ns.register("remove")
def delete(name: str) -> str:
    # Exposed as x.remove: pop the stored value, or "" if the key is absent.
    return dumb_store.pop(name, "")
@server.register()
def get(name: str) -> str:
    # Look up a stored value; "" when the key is absent.
    return dumb_store.get(name, "")
@ns.register()
def error():
    # Always fails; the server turns the exception text into an error reply.
    message = "error"
    raise Exception(message)
@server.register("existing_namespace", "x")
def bar():
    # Exposed as x.existing_namespace (explicit name plus existing namespace).
    result = "bar"
    return result
@server.register(namespace="z")
def baz():
    # Namespace "z" does not exist beforehand; registration creates it.
    result = "bar"
    return result
@broken_ns.register()
def broken():
    # Lives in namespace "y", which this script removes from the server after
    # the client session is built.
    return None
# Not referenced anywhere else in this script.
widget_registry = dict()
@server.register()
class Widget:
    # Remote class: each client call to Widget(...) creates a server-side
    # instance whose public methods become callable through the client.
    def __init__(self, name):
        self.name = name

    def get_name(self):
        return getattr(self, "name")

    def set_name(self, name):
        setattr(self, "name", name)

    @classmethod
    def get_class_name(cls):
        # Reachable on the class namespace and on instances alike.
        return cls.__name__

    # Also registered at the root as "return_true". register() hands back the
    # staticmethod object unchanged, so it stays a static method here.
    @server.register("return_true")
    @staticmethod
    def return_true():
        return True

    @server.ignore
    def ignore_me(self):
        # Flagged with __rpcignore__, so it is not exposed on instances.
        return None

    def _also_ignored(self):
        # The leading underscore keeps it off the instance namespace.
        return None
class CallMe:
    # Callable fixture: calling it prefixes the text with the current name.
    def __init__(self, name):
        self._name = name

    def __call__(self, text):
        return "{}: {}".format(self._name, text)

    def set_name(self, name):
        self._name = name

    def get_name(self):
        return self._name

    def _reverse_name(self):
        backwards = reversed(self._name)
        return "".join(backwards)
# One callable instance exposed four ways: directly, and through the three
# register_instance() variants (all methods, a method list, a renaming dict).
call_me = CallMe("foo")
server.register("call_me", func=call_me)
call_me_1 = server.register_instance(call_me, "call_me_1")
call_me_2 = server.register_instance(call_me, "call_me_2", [call_me.get_name])
call_me_3 = server.register_instance(
    call_me,
    "call_me_3",
    {"get": call_me.get_name, "set": call_me.set_name, "reverse": call_me._reverse_name},
)


def _loopback_send(self, msg):
    # Bypass the network: hand the request straight to the server's message
    # handler, using "" as the client address.
    return server._handle_message(bytes(msg, "utf-8"), "")


RPCClientNamespace._send_to_socket = _loopback_send
session = RPCClientNamespace(None)
# Drop namespace "y" only after the client fetched the manifest, so session.y
# exists locally while the server no longer knows it.
server._registry.pop("y")
widget_foo = session.Widget("fiz")
if __name__ == "__main__":

    def _expect_rpc_error(call, check):
        # Require call() to fail with RPCClientException whose first argument
        # satisfies check(); returning normally is itself a failure.
        try:
            call()
        except RPCClientException as exc:
            assert check(exc.args[0])
        else:
            assert False

    # Key/value functions at the root and in namespaces "x" and "z".
    assert session.store("foo", "bar") == True
    assert session.get("foo") == "bar"
    assert session.store("foo", "baz") == True
    assert session.get("foo") == "baz"
    assert session.x.remove("foo") == "baz"
    assert session.x.existing_namespace() == "bar"
    assert session.z.baz() == "bar"

    # Remote class instances, classmethods and staticmethods.
    widget_foo = session.Widget("foo")
    widget_bar = session.Widget("bar")
    widget_foo.set_name("Foo")
    assert widget_foo.get_class_name() == "Widget"
    assert session.Widget.get_class_name() == "Widget"
    assert session.return_true() is True
    assert session.Widget.return_true() is True
    assert widget_foo.return_true() is True
    assert widget_foo.get_name() == "Foo"
    assert widget_bar.get_name() == "bar"
    assert not hasattr(widget_foo, "ignore_me")
    assert not hasattr(widget_foo, "_also_ignored")

    # The shared CallMe instance seen through all four registrations.
    assert session.call_me("bar") == "foo: bar"
    assert session.call_me_1("bar") == "foo: bar"
    session.call_me_1.set_name("baz")
    assert session.call_me("bar") == "baz: bar"
    assert session.call_me_1("bar") == "baz: bar"
    assert session.call_me_2("bar") == "baz: bar"
    assert session.call_me_3("bar") == "baz: bar"
    assert not hasattr(session.call_me_2, "set_name")
    assert not hasattr(session.call_me_3, "set_name")
    assert session.call_me_2.get_name() == "baz"
    assert session.call_me_3.get() == "baz"
    assert session.call_me_3.reverse() == "zab"
    session.call_me_3.set("baz 2")
    assert session.call_me("bar") == "baz 2: bar"
    assert session.call_me_1("bar") == "baz 2: bar"
    assert session.call_me_2("bar") == "baz 2: bar"
    assert session.call_me_3("bar") == "baz 2: bar"

    # Released objects, remote exceptions and stale namespaces all surface
    # as RPCClientException on the client.
    widget_foo._delete()
    _expect_rpc_error(widget_foo.get_name, lambda msg: msg.startswith("Invalid namespace: "))
    _expect_rpc_error(session.x.error, lambda msg: msg == "error")
    _expect_rpc_error(session.y.broken, lambda msg: msg == "Invalid namespace: y")

    print("\nAll tests passed.\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment