Skip to content

Instantly share code, notes, and snippets.

@jrialland
Last active April 3, 2023 17:17
Show Gist options
  • Save jrialland/0253606b71736abef9cd6c4cd1c8909e to your computer and use it in GitHub Desktop.
Save jrialland/0253606b71736abef9cd6c4cd1c8909e to your computer and use it in GitHub Desktop.
package com.demo.rpc;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.Closeable;
import java.io.IOException;
import java.lang.reflect.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* Provides a transport-agnostic way of doing RPC (remote procedure calls).
* Users of this class must provide:
* <ul>
* <li>a way of sending messages (an instance of {@link MessageSender})</li>
* <li>a way of receiving messages (an instance of {@link MessageProvider})</li>
* <li>a way of locating the implementations of the services (an instance of
* {@link ImplementationProvider})</li>
* </ul>
*/
public class RpcHelper implements Closeable {
// Names of the JSON members that this class reads and writes.
// "code" and "message" are the members of the error object built by sendError().
public static final String CODE = "code";
public static final String MESSAGE = "message";
// "id", "jsonrpc", "method" and "params" are the members of the requests built by the proxies.
public static final String ID = "id";
public static final String JSONRPC = "jsonrpc";
public static final String METHOD = "method";
public static final String PARAMS = "params";
// member of a response under which sendError() stores the error object
public static final String ERROR = "error";
// value returned by getTimeoutSeconds()
private static final int DEFAULT_TIMEOUT_SECONDS = 7;
// initial value of delayBetweenNotificationsSeconds
private static final int DEFAULT_DELAY_BETWEEN_NOTIFICATIONS_SECONDS = 5;
private static final Logger logger = LoggerFactory.getLogger(RpcHelper.class);
/**
* incremented each time a rpc request is done; its value is used as the id of
* the request. Being static, it is shared by every RpcHelper instance.
*/
private static final AtomicInteger counter = new AtomicInteger(0);
// splits "service.method" : group 1 is the service name, group 2 is the method name
private static final Pattern methodNamePattern = Pattern.compile("(^.*)\\.(.+)$");
/**
* used to create multithreaded tasks; every thread of the pool is a daemon
* thread
*/
private final ExecutorService executorService = Executors.newCachedThreadPool(new ThreadFactory() {
private final ThreadFactory defaultTf = Executors.defaultThreadFactory();
@Override
public Thread newThread(Runnable r) {
Thread t = defaultTf.newThread(r);
t.setDaemon(true);
return t;
}
});
/**
* caches the proxy objects that have been already created, the key being
* "serviceName@recipient"
*/
private final Map<String, Object> proxies = new HashMap<>();
// used for parsing and for generating every json message
private final ObjectMapper objectMapper = new ObjectMapper();
/**
* handles message sending
*/
private final MessageSender messageSender;
/**
* how to retrieve the real objects that are called
*/
private final ImplementationProvider implementationProvider;
/**
* handle on the task that uses the messageProvider to get incoming rpc
* messages; it is cancelled by close()
*/
private final Future<?> readTask;
/**
* caches the methods, the key being the json-rpc method name followed by "/"
* and the number of parameters
*/
private final Map<String, Method> methods = new HashMap<>();
/**
* stores the handlers that wait for the rpc responses, indexed by request id
*/
private final Map<Integer, ResponseHandler> responseHandlers = new ConcurrentHashMap<>();
// delay between two progress notifications, also used as a pause when the
// MessageProvider returns null
private final int delayBetweenNotificationsSeconds = DEFAULT_DELAY_BETWEEN_NOTIFICATIONS_SECONDS;
/**
 * Creates the helper and immediately starts the background task that reads
 * and dispatches the incoming messages.
 *
 * @param implementationProvider an object that is responsible for finding the
 *                               implementation for the called objects
 * @param messageSender          describes how to send message across a network,
 *                               so we can reach the called object
 * @param messageProvider        source of the incoming messages, its
 *                               readMessage() will be called in a loop
 */
public RpcHelper(ImplementationProvider implementationProvider, final MessageSender messageSender,
        final MessageProvider messageProvider) {
    this.implementationProvider = implementationProvider;
    this.messageSender = messageSender;
    readTask = executorService.submit(() -> readLoop(messageProvider));
}

/**
 * Reads messages until the current thread is interrupted.
 * A null message only pauses the loop: the provider is polled again afterwards.
 */
@SuppressWarnings("BusyWait")
private void readLoop(MessageProvider messageProvider) {
    while (!Thread.currentThread().isInterrupted()) {
        final IncomingMessage incomingMessage;
        try {
            incomingMessage = messageProvider.readMessage();
        } catch (InterruptedException e) {
            // keep the interrupted status visible and stop reading
            Thread.currentThread().interrupt();
            return;
        } catch (RuntimeException e) {
            logger.error("could not read incoming message", e);
            continue;
        }
        if (incomingMessage == null) {
            logger.warn("MessageProvider should be blocking and not return null messages.");
            try {
                // slows the loop down so a non-blocking provider does not spin
                Thread.sleep(TimeUnit.SECONDS.toMillis(delayBetweenNotificationsSeconds));
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
            continue;
        }
        try {
            handleMessage(incomingMessage);
        } catch (RuntimeException e) {
            // a single faulty message must never stop the reader
            logger.error("unexpected failure while handling a message", e);
        }
    }
}

/** Parses a message and dispatches each item it contains (a batch is a json array of items). */
private void handleMessage(IncomingMessage incomingMessage) {
    String from = null;
    final JsonNode jsonMessage;
    try {
        from = incomingMessage.getSender();
        logger.debug("Received message from {} : {}", from, incomingMessage.getPayload());
        jsonMessage = objectMapper.readTree(incomingMessage.getPayload());
    } catch (Exception e) {
        sendError(from, null, JsonRpcErrorCodes.ERR_INTERNAL_ERROR, "Internal error", e);
        return;
    }
    if (jsonMessage == null) {
        // nothing to dispatch when the payload holds no json content
        return;
    }
    final Iterator<JsonNode> it = jsonMessage.isArray() ? jsonMessage.iterator()
            : List.of(jsonMessage).iterator();
    while (it.hasNext()) {
        handleItem(from, it.next());
    }
}

/** Dispatches one json-rpc item: request, response or keep-alive notification. */
private void handleItem(String from, JsonNode item) {
    // anything that is not flagged as json-rpc 2.0 is ignored
    if (!item.has(JSONRPC) || !item.get(JSONRPC).asText().equals("2.0")) {
        return;
    }
    if (item.has(ID)) {
        final int id = item.get(ID).asInt();
        if (item.has(METHOD)) {
            // only requests are answered with an error message
            try {
                handleRequest(from, id, item);
            } catch (JsonRpcException e) {
                sendError(from, id, e.getCode(), e.getMessage(), e.getCause());
            } catch (Exception e) {
                sendError(from, id, JsonRpcErrorCodes.ERR_INTERNAL_ERROR, "Internal error", e);
            }
        } else {
            try {
                handleResponse(id, item);
            } catch (Exception e) {
                // a response is never answered : the failure is only logged
                logger.error("[{}] could not process response {}", from, id, e);
            }
        }
    } else if (item.has(PARAMS)) {
        handleKeepAlive(item);
    }
}

/**
 * Resolves the target of a request, then runs the invocation on the executor together
 * with a task that sends progress notifications while the invocation is running.
 */
private void handleRequest(String from, int id, JsonNode item) {
    final String methodName = item.get(METHOD).asText();
    final Object impl = findImpl(methodName);
    if (impl == null) {
        logger.error("[{}] service not found", from);
        throw new JsonRpcException(JsonRpcErrorCodes.ERR_METHOD_NOT_FOUND, "Method not found");
    }
    final ArrayNode params;
    if (item.has(PARAMS) && item.get(PARAMS).isArray()) {
        params = (ArrayNode) item.get(PARAMS);
    } else {
        logger.warn("[{}] 'params' is not an array", from);
        params = objectMapper.createArrayNode();
    }
    final Method method;
    try {
        method = findMethod(methodName, impl, params.size());
    } catch (RuntimeException e) {
        logger.error("[{}] method not found", from);
        throw e;
    }
    final AtomicBoolean done = new AtomicBoolean(false);
    // while the invocation is not finished, send notifications to the caller, so we
    // prevent timeouts
    executorService.submit(() -> notifyProgressUntilDone(from, id, methodName, done));
    // call the service
    executorService.submit(() -> {
        try {
            doInvoke(id, from, impl, method, params);
        } finally {
            done.set(true);
        }
    });
}

/** Periodically sends a "notify.progress" notification until {@code done} is set. */
@SuppressWarnings("BusyWait")
private void notifyProgressUntilDone(String recipient, int id, String methodName, AtomicBoolean done) {
    final long start = System.currentTimeMillis();
    while (true) {
        try {
            Thread.sleep(TimeUnit.SECONDS.toMillis(delayBetweenNotificationsSeconds));
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return;
        }
        if (done.get()) {
            return;
        }
        final ObjectNode notification = objectMapper.createObjectNode();
        notification.put(JSONRPC, "2.0");
        notification.put(METHOD, "notify.progress");
        final ObjectNode data = objectMapper.createObjectNode();
        data.put(ID, id);
        data.put(MESSAGE, "invocation in progress");
        data.put(METHOD, methodName);
        data.put("elapsed", System.currentTimeMillis() - start);
        notification.set(PARAMS, data);
        try {
            final String payload = objectMapper.writeValueAsString(notification);
            logger.debug("Sending notification to {}: {}", recipient, payload);
            messageSender.sendMessage(recipient, payload);
        } catch (IOException e) {
            logger.error("notification failed", e);
        }
    }
}

/** Hands a result or an error over to the handler that waits for the given request id. */
private void handleResponse(int id, JsonNode item) {
    if (item.has("result")) {
        final ResponseHandler responseHandler = responseHandlers.remove(id);
        if (responseHandler != null) {
            // the handler converts the value itself, not the whole envelope
            responseHandler.onResult(item.get("result"));
        }
    } else if (item.has(ERROR) || (item.has(CODE) && item.has(MESSAGE))) {
        final ResponseHandler responseHandler = responseHandlers.remove(id);
        if (responseHandler != null) {
            responseHandler.onError(toErrorNode(item));
        }
    }
}

/**
 * Builds a node that always has a "code" and a "message", from either the "error" member
 * of the response (the layout produced by sendError) or from the response itself.
 */
private ObjectNode toErrorNode(JsonNode item) {
    final JsonNode details = item.has(ERROR) ? item.get(ERROR) : item;
    final ObjectNode normalized = objectMapper.createObjectNode();
    normalized.put(CODE, details.path(CODE).asInt(JsonRpcErrorCodes.ERR_INTERNAL_ERROR));
    normalized.put(MESSAGE, details.path(MESSAGE).asText("Unknown error"));
    return normalized;
}

/** Signals the handler designated by params.id, if any, that the remote is still working. */
private void handleKeepAlive(JsonNode item) {
    final JsonNode idNode = item.get(PARAMS).path(ID);
    if (idNode.isMissingNode() || idNode.isNull()) {
        // a notification that does not designate a pending request
        return;
    }
    final ResponseHandler responseHandler = responseHandlers.get(idNode.asInt());
    if (responseHandler != null) {
        responseHandler.keepAlive();
    }
}
/**
 * Walks down the chain of causes of a throwable and returns the last element reached.
 * The walk stops when there is no further cause or when the next cause has already
 * been visited, which protects against cyclic chains.
 *
 * @param t the throwable to start from, must not be null
 * @return the deepest throwable that was reached
 */
private static Throwable findRoot(Throwable t) {
    final Set<Throwable> visited = new HashSet<>();
    Throwable current = t;
    for (Throwable next = current.getCause(); next != null && !visited.contains(next); next = current
            .getCause()) {
        visited.add(current);
        current = (current instanceof InvocationTargetException)
                ? ((InvocationTargetException) current).getTargetException()
                : next;
    }
    return current;
}
/**
* @return the number of seconds that a caller waits for a signal from the remote
* side (a response or a keep-alive notification) before the call fails
* with a timeout; this implementation always returns the default value
*/
public int getTimeoutSeconds() {
return DEFAULT_TIMEOUT_SECONDS;
}
/**
 * Logs a json-rpc error and, when the peer is known, sends it as an error response.
 * Failures that happen while building or sending the response are logged, never thrown.
 *
 * @param respondTo the peer to notify, nothing is sent when null
 * @param id        the id of the failed request, omitted from the response when null
 * @param code      the json-rpc error code
 * @param message   the short description of the error
 * @param cause     the exception behind the error, may be null; the message and the stack
 *                  trace of its root cause are added to the response
 */
private void sendError(String respondTo, Integer id, int code, String message, Throwable cause) {
    logger.error("[{}] jsonrpc error {}", respondTo, code, cause);
    if (respondTo == null) {
        return;
    }
    try {
        final ObjectNode response = objectMapper.createObjectNode();
        response.put(JSONRPC, "2.0");
        if (id != null) {
            response.put(ID, id);
        }
        final ObjectNode error = objectMapper.createObjectNode();
        error.put(CODE, code);
        error.put(MESSAGE, message);
        if (cause != null) {
            final Throwable root = findRoot(cause);
            final String exceptionMessage = root.getMessage();
            if (exceptionMessage != null) {
                error.put("cause", exceptionMessage);
            }
            final StackTraceElement[] stackTrace = root.getStackTrace();
            if (stackTrace != null) {
                error.set("backtrace", toBacktrace(stackTrace));
            }
        }
        response.set(ERROR, error);
        messageSender.sendMessage(respondTo, objectMapper.writeValueAsString(response));
    } catch (Exception e) {
        logger.error("could not send message", e);
    }
}

/**
 * Converts the leading frames of a stack trace into json objects (file, line, method).
 * The conversion stops at the first frame that has no line number or no file name.
 */
private ArrayNode toBacktrace(StackTraceElement[] stackTrace) {
    final ArrayNode backtraceNode = objectMapper.createArrayNode();
    for (StackTraceElement frame : stackTrace) {
        final int line = frame.getLineNumber();
        final String filename = frame.getFileName();
        if (line < 1 || filename == null) {
            break;
        }
        final ObjectNode frameNode = objectMapper.createObjectNode();
        frameNode.put("file", filename);
        frameNode.put("line", line);
        frameNode.put("method", frame.getClassName() + "." + frame.getMethodName());
        backtraceNode.add(frameNode);
    }
    return backtraceNode;
}
/**
 * Converts the json parameters to the types declared by the method, invokes the method
 * on the implementation and sends the outcome to the caller: a result message when the
 * call succeeds, an error message when the parameters cannot be converted or when the
 * invocation fails.
 *
 * @param id        the id of the request being served
 * @param respondTo the peer that receives the result or the error
 * @param impl      the object on which the method is invoked
 * @param method    the method to invoke
 * @param params    the json parameters of the request
 */
private void doInvoke(int id, String respondTo, Object impl, Method method, ArrayNode params) {
    final Object[] arguments = new Object[params.size()];
    final Type[] declaredTypes = method.getGenericParameterTypes();
    try {
        final int convertible = Math.min(params.size(), declaredTypes.length);
        int index = 0;
        while (index < convertible) {
            final JavaType target = objectMapper.getTypeFactory().constructType(declaredTypes[index]);
            arguments[index] = objectMapper.treeToValue(params.get(index), target);
            index++;
        }
    } catch (Exception conversionFailure) {
        sendError(respondTo, id, JsonRpcErrorCodes.ERR_INVALID_PARAMS, "Invalid params", conversionFailure);
        return;
    }
    try {
        final Object returned = method.invoke(impl, arguments);
        final ObjectNode envelope = objectMapper.createObjectNode();
        envelope.put(JSONRPC, "2.0");
        envelope.put(ID, id);
        envelope.set("result", objectMapper.valueToTree(returned));
        final String serialized = objectMapper.writer().writeValueAsString(envelope);
        logger.debug(String.format("Sending message to %s : %s", respondTo, serialized));
        messageSender.sendMessage(respondTo, serialized);
    } catch (Throwable invocationFailure) {
        sendError(respondTo, id, JsonRpcErrorCodes.ERR_SERVER_ERROR, "invocation failed", invocationFailure);
    }
}
/**
 * Looks up the implementation that serves a json-rpc method name.
 *
 * @param methodSpec the method name, in the form "service.method"
 * @return whatever the implementation provider returns for the service name
 * @throws JsonRpcException when the name does not have the expected form
 */
private Object findImpl(String methodSpec) {
    final Matcher specMatcher = methodNamePattern.matcher(methodSpec);
    if (!specMatcher.matches()) {
        throw new JsonRpcException(JsonRpcErrorCodes.ERR_METHOD_NOT_FOUND, "Method not found");
    }
    return implementationProvider.getImpl(specMatcher.group(1));
}
/**
 * Finds the public method of the implementation that has the requested name and number
 * of parameters. Lookups are cached by method name and number of parameters.
 *
 * @param methodSpec  the method name, in the form "service.method"
 * @param impl        the object whose class is searched
 * @param paramsCount the number of parameters of the wanted method
 * @return the first matching method
 * @throws JsonRpcException when the name is malformed or when no method matches
 */
private Method findMethod(String methodSpec, Object impl, int paramsCount) {
    return methods.computeIfAbsent(methodSpec + "/" + paramsCount, cacheKey -> {
        final Matcher specMatcher = methodNamePattern.matcher(methodSpec);
        if (!specMatcher.matches()) {
            throw new JsonRpcException(JsonRpcErrorCodes.ERR_METHOD_NOT_FOUND, "Method not found");
        }
        final String wanted = specMatcher.group(2);
        return Arrays.stream(impl.getClass().getMethods())
                .filter(m -> m.getParameterCount() == paramsCount && m.getName().equals(wanted))
                .findFirst()
                .orElseThrow(() -> new JsonRpcException(JsonRpcErrorCodes.ERR_METHOD_NOT_FOUND,
                        "Method not found"));
    });
}
/**
 * Stops reading incoming messages and releases the executor.
 * The read task is interrupted; the tasks that were already submitted (running
 * invocations and their progress notifications) are left to complete, but no new task
 * is accepted afterwards.
 */
@Override
public void close() {
    readTask.cancel(true);
    // without this the pool stays usable forever and is never released
    executorService.shutdown();
}
/**
 * Returns an object that implements {@code iface} by turning every call into a json-rpc
 * request sent to {@code recipient}, then blocking until the response arrives.
 * Proxies are cached by service name and recipient; the cache is a plain map and this
 * method does no synchronization of its own.
 *
 * @param recipient   the peer that hosts the service
 * @param serviceName the name of the service, used as prefix of the json-rpc method name
 * @param iface       the interface implemented by the proxy
 * @param <T>         the type of the interface
 * @return the proxy
 * @throws IllegalArgumentException when {@code iface} is not an interface
 */
@SuppressWarnings({"unchecked", "unused"})
public <T> T createProxy(final String recipient, final String serviceName, Class<T> iface) {
    if (!iface.isInterface()) {
        throw new IllegalArgumentException("iface must be an interface");
    }
    final String key = serviceName + "@" + recipient;
    final Object cached = proxies.get(key);
    // a cached proxy is reused only when it implements the requested interface
    if (iface.isInstance(cached)) {
        return (T) cached;
    }
    final String description = "proxy of '" + serviceName + "@" + recipient + "'";
    InvocationHandler invocationHandler = new InvocationHandler() {
        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            // equals, hashCode and toString are answered locally, never sent to the remote
            if (method.getDeclaringClass().equals(Object.class)) {
                final String name = method.getName();
                if (name.equals("equals")) {
                    return proxy == args[0];
                }
                if (name.equals("hashCode")) {
                    return System.identityHashCode(proxy);
                }
                if (name.equals("toString")) {
                    return description;
                }
                throw new UnsupportedOperationException(method.toString());
            }
            final int id = counter.incrementAndGet();
            final ObjectNode request = objectMapper.createObjectNode();
            request.put(JSONRPC, "2.0");
            request.put(ID, id);
            request.put(METHOD, serviceName + "." + method.getName());
            final ArrayNode params = objectMapper.createArrayNode();
            // args is null when the interface method takes no argument
            if (args != null) {
                for (Object arg : args) {
                    params.add(objectMapper.valueToTree(arg));
                }
            }
            request.set(PARAMS, params);
            final String payload = objectMapper.writeValueAsString(request);
            final ResponseHandler responseHandler = new ResponseHandler(method.getReturnType());
            // registered before sending, so that a fast response always finds its handler
            responseHandlers.put(id, responseHandler);
            try {
                messageSender.sendMessage(recipient, payload);
                return responseHandler.waitFor();
            } finally {
                // also reached when sending fails, so the handler never leaks
                responseHandlers.remove(id);
            }
        }
    };
    final Object created = Proxy.newProxyInstance(Thread.currentThread().getContextClassLoader(),
            new Class<?>[]{iface}, invocationHandler);
    proxies.put(key, created);
    return (T) created;
}
/**
* Error codes placed in the "code" member of the error messages sent by this class.
* NOTE(review): the values look like the predefined codes of the JSON-RPC 2.0
* specification - confirm against the specification before changing them.
*/
private interface JsonRpcErrorCodes {
// no service or no method matches the requested method name
int ERR_METHOD_NOT_FOUND = -32601;
// the json parameters could not be converted to the types declared by the method
int ERR_INVALID_PARAMS = -32602;
// unexpected failure while an incoming message was being processed
int ERR_INTERNAL_ERROR = -32603;
// the invocation of the implementation (or the sending of its result) failed
int ERR_SERVER_ERROR = -32000;
}
/**
* A message received from a remote peer, as returned by a {@link MessageProvider}.
*/
public interface IncomingMessage {
/**
* @return the name of the peer that sent the message; the results, errors and
* progress notifications produced for this message are sent to that name
*/
String getSender();
/**
* @return the content of the message, which is parsed as json
*/
String getPayload();
}
/**
* Contract for implementing how to send messages to a given recipient.
* Every outgoing request, result, error and progress notification goes through it.
*/
public interface MessageSender {
/**
* Sends one message.
*
* @param recipient the name of the recipient
* @param payload the message to be sent, a json document serialized as a string
* @throws IOException declared for implementations to report a message that could
* not be sent
*/
void sendMessage(String recipient, String payload) throws IOException;
}
/**
* Contract for implementing how incoming messages are received.
* The method is called in a loop by a background task of the RpcHelper.
*/
public interface MessageProvider {
/**
* waits for incoming message (blocking). Implementations are expected to block
* until a message is available rather than return null: a null value is
* reported with a warning.
*
* @return the next message that we receive
* @throws InterruptedException when the wait is interrupted; the reading task
* then ends
*/
IncomingMessage readMessage() throws InterruptedException;
}
/**
* interface that is responsible for providing implementations for the required
* service name.
*/
public interface ImplementationProvider {
/**
* @param serviceName the name of the service, i.e. the part of the json-rpc method
* name that precedes its last dot
* @return the object whose public methods serve the requests made to the service;
* when null is returned the request is answered with a "Method not found"
* error
*/
Object getImpl(String serviceName);
}
/**
* Unchecked exception that carries a json-rpc error code along with its message.
* It is thrown locally when a request cannot be served, and it is also the exception
* raised to the caller of a proxy when the remote side answers with an error.
*/
private static class JsonRpcException extends RuntimeException {
// the json-rpc error code
private final int code;
/**
* @param code the json-rpc error code
* @param message the description of the error
*/
public JsonRpcException(int code, String message) {
super(message);
this.code = code;
}
/**
* @return the json-rpc error code
*/
public int getCode() {
return code;
}
}
/**
* Basic implementation of {@link IncomingMessage} that simply holds the two values
* given to its constructor.
*/
public static class SimpleIncomingMessage implements IncomingMessage {
private final String sender;
private final String payload;
/**
* @param sender the name of the peer that sent the message
* @param payload the content of the message
*/
public SimpleIncomingMessage(String sender, String payload) {
this.sender = sender;
this.payload = payload;
}
@Override
public String getSender() {
return sender;
}
@Override
public String getPayload() {
return payload;
}
}
class ResponseHandler {
private final Semaphore sem = new Semaphore(0);
private final Class<?> returnType;
boolean isDone = false;
private Object result;
private JsonRpcException exception;
public ResponseHandler(Class<?> returnType) {
this.returnType = returnType;
}
void onResult(JsonNode result) {
if (!this.returnType.equals(Void.TYPE)) {
try {
this.result = objectMapper.treeToValue(result, returnType);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
isDone = true;
sem.release();
}
void onError(JsonNode result) {
this.exception = new JsonRpcException(result.get(CODE).asInt(), result.get(MESSAGE).asText());
isDone = true;
sem.release();
}
void keepAlive() {
sem.release();
}
Object waitFor() throws Exception {
final int t = getTimeoutSeconds();
boolean waiting = true;
while (waiting) {
if (sem.tryAcquire(t, TimeUnit.SECONDS)) {
if (isDone) {
if (this.exception != null) {
throw this.exception;
} else {
return this.result;
}
}
} else {
waiting = false;
}
}
throw new TimeoutException(String.format("No response from remote after %d seconds", t));
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment