Skip to content

Instantly share code, notes, and snippets.

@jrialland
Created December 17, 2021 08:55
Show Gist options
  • Save jrialland/28c379bfec781a658f56ff4412253926 to your computer and use it in GitHub Desktop.
Save jrialland/28c379bfec781a658f56ff4412253926 to your computer and use it in GitHub Desktop.
import sys, traceback, logging, json
from flask import Flask, request, Response, url_for
class JsonRpc:
    """Expose Python callables as JSON-RPC 2.0 endpoints on a Flask app.

    Every registered callable is served at ``/jsonrpc/<name>`` and expects a
    JSON-RPC 2.0 request body whose ``method`` equals ``<name>``.
    ``/jsonrpc/api.js`` serves a generated JavaScript client that defines one
    promise-returning function per endpoint on a scope object (``window``
    unless the ``scope`` query argument names another one).

    Usage::

        rpc = JsonRpc(app)

        @rpc.endpoint
        def add(a, b):
            return a + b

        rpc.register_service("calc", Calculator())
    """

    # JSON-RPC "server error" code reported for every failure.
    _SERVER_ERROR = -32000

    # Head of the generated client; @JSONRPC_ROOT_PATH@ is replaced by a JS
    # string literal holding the URL of the /jsonrpc/ route.
    _JS_PRELUDE = """(function(_scope) {
    let _counter = 0;
    const _rpc_call = (endpoint, args) => {
        const fetchArgs = {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
                'Accept': 'application/json'
            },
            body: JSON.stringify({
                jsonrpc: '2.0',
                method: endpoint,
                params: args,
                id: _counter++
            })
        };
        // Network failures and non-JSON bodies reject the returned promise.
        return fetch(@JSONRPC_ROOT_PATH@ + endpoint, fetchArgs)
            .then(response => response.json().then(data => {
                // Test for an error member rather than a truthy result so that
                // results such as 0, false, "" or null resolve normally.
                if (response.status === 200 && data && !data.error) {
                    return data.result;
                }
                const error = (data && data.error) || {message: 'HTTP ' + response.status};
                console.error(error);
                throw new Error(error.message);
            }));
    };
    _scope = _scope || {};
"""

    def __init__(self, server, save_js_to=None):
        """Register the JSON-RPC routes on *server*.

        :param server: the Flask application (used through its ``route``
            decorator).
        :param save_js_to: optional file path; the first generated JavaScript
            client is also written there.
        """
        self.server = server
        self.endpoints = {}  # endpoint name -> callable
        self.services = {}  # service name -> set of exposed method names
        self.save_js_to = save_js_to
        self.js = None  # first JavaScript client generated (the one saved)
        # Root route; the generated client resolves its URL with url_for("_base").
        server.route("/jsonrpc/")(JsonRpc._base)
        # A bound method (instead of an anonymous lambda) gives the Flask
        # endpoint a meaningful name rather than the collision-prone "<lambda>".
        server.route("/jsonrpc/<endpoint>", methods=["GET", "POST"])(
            self._handle_jsonrpc_request
        )

    @staticmethod
    def _base():
        """View for ``/jsonrpc/``; exists so ``url_for`` can build the root path."""
        return ""

    def _handle_jsonrpc_request(self, endpoint):
        """Flask view for ``/jsonrpc/<endpoint>``: serve api.js or run a call."""
        if endpoint == "api.js":
            return self._handle_jsonrpc_js()
        # silent=True: a missing or malformed JSON body yields None, which
        # _dispatch reports as a JSON-RPC error instead of raising here.
        payload = request.get_json(silent=True)
        body, status = self._dispatch(endpoint, payload)
        return Response(body, status=status, mimetype="application/json")

    def _dispatch(self, endpoint, payload):
        """Execute one JSON-RPC 2.0 request and serialize the response.

        :param endpoint: endpoint name taken from the URL.
        :param payload: the decoded request body (``None`` if it was not JSON).
        :returns: ``(json_text, http_status)``: 200 with a ``result`` member on
            success, 500 with an ``error`` member on any failure.
        """
        try:
            # Explicit checks instead of ``assert``, which ``python -O`` strips.
            if not isinstance(payload, dict):
                raise ValueError("request body must be a JSON object")
            if payload.get("jsonrpc") != "2.0":
                raise ValueError("unsupported or missing 'jsonrpc' version")
            if "id" not in payload:
                raise ValueError("missing request 'id'")
            if payload.get("method") != endpoint:
                raise ValueError(f"'method' does not match endpoint {endpoint!r}")
            if endpoint not in self.endpoints:
                raise LookupError(f"unknown endpoint {endpoint!r}")
            fn = self.endpoints[endpoint]
            params = payload.get("params", [])
            if isinstance(params, dict):
                result = fn(**params)
            elif isinstance(params, list):
                result = fn(*params)
            else:
                raise ValueError("'params' must be an array or an object")
            # Serialize inside the try so a non-JSON-serializable result is
            # reported as a JSON-RPC error.
            body = {"jsonrpc": "2.0", "id": payload["id"], "result": result}
            return json.dumps(body), 200
        except Exception as exc:
            logging.exception("json rpc failure")
            error = {
                "jsonrpc": "2.0",
                "error": {
                    "code": self._SERVER_ERROR,
                    "message": str(exc),
                    # NOTE(review): the full server traceback is sent to the
                    # client, which exposes internals; consider dropping it
                    # outside development.
                    "data": traceback.format_exception(
                        type(exc), exc, exc.__traceback__
                    ),
                },
            }
            if isinstance(payload, dict) and "id" in payload:
                error["id"] = payload["id"]
            return json.dumps(error), 500

    @staticmethod
    def _is_js_path(text):
        """Return True if *text* is a dotted path of JS identifiers (e.g. ``window.app``)."""
        return bool(text) and all(
            part.replace("$", "_").isidentifier() for part in text.split(".")
        )

    @staticmethod
    def _js_call(name):
        """Return JS source for a function forwarding its arguments to endpoint *name*."""
        return (
            "function() { return _rpc_call("
            + json.dumps(name)
            + ", Array.from(arguments)); }"
        )

    def _build_js(self, root_path, scope):
        """Build the JavaScript client source.

        :param root_path: URL prefix of the endpoints (e.g. ``/jsonrpc/``).
        :param scope: JS expression for the object receiving the API.
        :raises ValueError: if *scope* is not a dotted identifier path; it is
            spliced into executable JavaScript, so anything else is rejected.
        """
        if not self._is_js_path(scope):
            raise ValueError(f"invalid javascript scope {scope!r}")
        parts = [self._JS_PRELUDE.replace("@JSONRPC_ROOT_PATH@", json.dumps(root_path))]
        for service_name, methods in self.services.items():
            # json.dumps yields a correctly escaped JS string literal.
            quoted = json.dumps(service_name)
            target = f"_scope[{quoted}]"
            parts.append(f"\t{target} = {{}};\n")
            parts.append(f'\t{target}["toString"] = () => {quoted};\n')
            for method_name in sorted(methods):
                call = self._js_call(f"{service_name}.{method_name}")
                parts.append(f"\t{target}[{json.dumps(method_name)}] = {call}\n")
        for fn_name in self.endpoints:
            # Dotted names belong to services and were emitted above.
            if "." not in fn_name:
                parts.append(f"\t_scope[{json.dumps(fn_name)}] = {self._js_call(fn_name)}\n")
        parts.append("}(" + scope + "));")
        return "".join(parts)

    def _generate_js(self, scope=None):
        """Generate the JavaScript client for the current Flask request.

        :param scope: object receiving the API; defaults to the ``scope`` query
            argument, else ``window``.
        :raises ValueError: if the scope is not a dotted identifier path.
        """
        if scope is None:
            scope = request.args.get("scope", "window")
        return self._build_js(url_for(JsonRpc._base.__name__), scope)

    def _handle_jsonrpc_js(self):
        """Serve the JavaScript client; on first generation also save it to disk."""
        try:
            js = self._generate_js()
        except ValueError as exc:
            return Response(str(exc), status=400, mimetype="text/plain")
        if self.js is None:
            self.js = js
            if self.save_js_to:
                try:
                    # Must be opened for writing; read mode made f.write() fail.
                    with open(self.save_js_to, "w", encoding="utf-8") as f:
                        f.write(js)
                    logging.info("jsonrpc javascript api saved to %s", self.save_js_to)
                    self.save_js_to = None
                except OSError:
                    logging.exception("error while saving jsonrpc javascript api")
        return Response(js, mimetype="text/javascript")

    def _decorate(self, fn, key):
        """Record *fn* under endpoint name *key* and return it unchanged."""
        self.endpoints[key] = fn
        logging.info("registered jsonrpc endpoint '%s' -> %r", key, fn)
        return fn

    def endpoint(self, *args, **kwargs):
        """Register a callable as an RPC endpoint; usable as a decorator.

        Accepted forms::

            @rpc.endpoint                 # name taken from fn.__name__
            @rpc.endpoint("custom.name")  # explicit name
            @rpc.endpoint(name="custom")  # explicit name, keyword form
            @rpc.endpoint()               # same as the bare form

        :returns: the function itself (bare form) or a decorator.
        """
        name = kwargs.get("name")
        if args and not isinstance(args[0], str):
            fn = args[0]
            return self._decorate(fn, name or fn.__name__)
        if args:
            name = args[0]
        return lambda fn: self._decorate(fn, name or fn.__name__)

    def register_service(self, service_name, impl):
        """Expose each callable attribute of *impl* as ``<service_name>.<attr>``.

        Attributes whose names start with ``__`` are skipped.
        NOTE(review): single-underscore methods are still exposed over RPC.
        """
        methods = set()
        for method_name in dir(impl):
            if method_name.startswith("__"):
                continue
            method = getattr(impl, method_name)
            if callable(method):
                methods.add(method_name)
                self.endpoint(f"{service_name}.{method_name}")(method)
        self.services[service_name] = methods
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment