Last active
April 3, 2023 17:17
-
-
Save jrialland/0253606b71736abef9cd6c4cd1c8909e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.demo.rpc; | |
import com.fasterxml.jackson.databind.JavaType; | |
import com.fasterxml.jackson.databind.JsonNode; | |
import com.fasterxml.jackson.databind.ObjectMapper; | |
import com.fasterxml.jackson.databind.node.ArrayNode; | |
import com.fasterxml.jackson.databind.node.ObjectNode; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import java.io.Closeable; | |
import java.io.IOException; | |
import java.lang.reflect.*; | |
import java.util.*; | |
import java.util.concurrent.*; | |
import java.util.concurrent.atomic.AtomicBoolean; | |
import java.util.concurrent.atomic.AtomicInteger; | |
import java.util.regex.Matcher; | |
import java.util.regex.Pattern; | |
/**
 * Provides a transport-agnostic way of doing RPC (remote procedure calls).
 * Users of this class must provide:
 * <ul>
 * <li>a way to send messages (an instance of {@link MessageSender})</li>
 * <li>a way to receive messages (an instance of {@link MessageProvider})</li>
 * <li>a way to locate the implementations of services (an instance of
 * {@link ImplementationProvider})</li>
 * </ul>
 */
public class RpcHelper implements Closeable { | |
public static final String CODE = "code"; | |
public static final String MESSAGE = "message"; | |
public static final String ID = "id"; | |
public static final String JSONRPC = "jsonrpc"; | |
public static final String METHOD = "method"; | |
public static final String PARAMS = "params"; | |
public static final String ERROR = "error"; | |
private static final int DEFAULT_TIMEOUT_SECONDS = 7; | |
private static final int DEFAULT_DELAY_BETWEEN_NOTIFICATIONS_SECONDS = 5; | |
private static final Logger logger = LoggerFactory.getLogger(RpcHelper.class); | |
/** | |
* incremented each time a rpc request is done | |
*/ | |
private static final AtomicInteger counter = new AtomicInteger(0); | |
private static final Pattern methodNamePattern = Pattern.compile("(^.*)\\.(.+)$"); | |
/** | |
* used to create multithreaded tasks | |
*/ | |
private final ExecutorService executorService = Executors.newCachedThreadPool(new ThreadFactory() { | |
private final ThreadFactory defaultTf = Executors.defaultThreadFactory(); | |
@Override | |
public Thread newThread(Runnable r) { | |
Thread t = defaultTf.newThread(r); | |
t.setDaemon(true); | |
return t; | |
} | |
}); | |
/** | |
* caches the proxy objects that have been already created | |
*/ | |
private final Map<String, Object> proxies = new HashMap<>(); | |
private final ObjectMapper objectMapper = new ObjectMapper(); | |
/** | |
* handles message sending | |
*/ | |
private final MessageSender messageSender; | |
/** | |
* how to retrieve the real objects that are called | |
*/ | |
private final ImplementationProvider implementationProvider; | |
/** | |
* reference to the thread that uses the messageProvider to get incoming rpc | |
* requests | |
*/ | |
private final Future<?> readTask; | |
/** | |
* caches the methods | |
*/ | |
private final Map<String, Method> methods = new HashMap<>(); | |
/** | |
* stores the handlers that wait for the rpc responses | |
*/ | |
private final Map<Integer, ResponseHandler> responseHandlers = new ConcurrentHashMap<>(); | |
private final int delayBetweenNotificationsSeconds = DEFAULT_DELAY_BETWEEN_NOTIFICATIONS_SECONDS; | |
/**
 * Creates the helper and immediately submits the task that reads and
 * dispatches incoming messages.
 *
 * @param implementationProvider an object that is responsible for finding the
 *                               implementation for the called objects
 * @param messageSender          describes how to send message across a network,
 *                               so we can reach the called object
 * @param messageProvider        source of incoming messages; its
 *                               readMessage() is called in a loop
 */
public RpcHelper(ImplementationProvider implementationProvider, final MessageSender messageSender,
        final MessageProvider messageProvider) {
    this.implementationProvider = implementationProvider;
    this.messageSender = messageSender;
    readTask = executorService.submit(() -> readLoop(messageProvider));
}

/**
 * Reads messages one by one and dispatches them until the thread is interrupted.
 */
@SuppressWarnings("BusyWait")
private void readLoop(MessageProvider messageProvider) {
    while (!Thread.currentThread().isInterrupted()) {
        try {
            final IncomingMessage incomingMessage = messageProvider.readMessage();
            if (incomingMessage == null) {
                logger.warn("MessageProvider should be blocking and not return null messages.");
                // back off, then keep reading: one null message must not end the loop
                Thread.sleep(TimeUnit.SECONDS.toMillis(delayBetweenNotificationsSeconds));
                continue;
            }
            handleMessage(incomingMessage);
        } catch (InterruptedException e) {
            // cancellation request: restore the flag and stop reading
            Thread.currentThread().interrupt();
            return;
        } catch (RuntimeException e) {
            // the sender is unknown here, so sendError only logs the failure
            sendError(null, null, JsonRpcErrorCodes.ERR_INTERNAL_ERROR, "Internal error", e);
        }
    }
}

/**
 * Parses one incoming payload (a single JSON-RPC object or a batch array) and
 * dispatches each item. Any failure is reported to the sender as an error
 * response carrying the id of the item being processed, when known.
 */
private void handleMessage(IncomingMessage incomingMessage) {
    String from = null;
    Integer id = null;
    try {
        from = incomingMessage.getSender();
        logger.debug("Received message from {} : {}", from, incomingMessage.getPayload());
        final JsonNode jsonMessage = objectMapper.readTree(incomingMessage.getPayload());
        final Iterable<JsonNode> items = jsonMessage.isArray() ? jsonMessage : List.of(jsonMessage);
        for (JsonNode item : items) {
            // anything that is not tagged as JSON-RPC 2.0 is ignored
            if (!item.has(JSONRPC) || !"2.0".equals(item.get(JSONRPC).asText())) {
                continue;
            }
            if (item.has(ID)) {
                id = item.get(ID).asInt();
                if (item.has(METHOD)) {
                    handleRequest(from, id, item);
                } else {
                    handleResponse(id, item);
                }
            } else if (item.has(PARAMS)) {
                handleNotification(item);
            }
        }
    } catch (JsonRpcException e) {
        sendError(from, id, e.getCode(), e.getMessage(), e.getCause());
    } catch (Exception e) {
        sendError(from, id, JsonRpcErrorCodes.ERR_INTERNAL_ERROR, "Internal error", e);
    }
}

/**
 * Resolves the target of a request and schedules its invocation, together
 * with a task that sends progress notifications while the invocation runs.
 *
 * @throws JsonRpcException when the service or the method cannot be found
 */
private void handleRequest(String from, int id, JsonNode item) {
    final String methodName = item.get(METHOD).asText();
    final Object impl = findImpl(methodName);
    if (impl == null) {
        logger.error("[{}] service not found", from);
        throw new JsonRpcException(JsonRpcErrorCodes.ERR_METHOD_NOT_FOUND, "Method not found");
    }
    final ArrayNode params;
    if (item.has(PARAMS) && item.get(PARAMS).isArray()) {
        params = (ArrayNode) item.get(PARAMS);
    } else {
        logger.warn("[{}] 'params' is not an array", from);
        params = objectMapper.createArrayNode();
    }
    final Method method;
    try {
        method = findMethod(methodName, impl, params.size());
    } catch (RuntimeException e) {
        logger.error("[{}] method not found", from);
        throw e;
    }
    final CountDownLatch done = new CountDownLatch(1);
    // while the invocation is not finished, notify the caller so it does not time out
    executorService.submit(() -> notifyProgressUntilDone(from, id, methodName, done));
    // call the service
    executorService.submit(() -> {
        try {
            doInvoke(id, from, impl, method, params);
        } finally {
            done.countDown();
        }
    });
}

/**
 * Sends a "notify.progress" notification to the caller every
 * {@code delayBetweenNotificationsSeconds} seconds until {@code done} is
 * released or the thread is interrupted.
 */
private void notifyProgressUntilDone(String recipient, int id, String methodName, CountDownLatch done) {
    final long start = System.currentTimeMillis();
    try {
        // await() returns false each time the delay elapses before the invocation completes
        while (!done.await(delayBetweenNotificationsSeconds, TimeUnit.SECONDS)) {
            final ObjectNode data = objectMapper.createObjectNode();
            data.put(ID, id);
            data.put(MESSAGE, "invocation in progress");
            data.put(METHOD, methodName);
            data.put("elapsed", System.currentTimeMillis() - start);
            final ObjectNode notification = objectMapper.createObjectNode();
            notification.put(JSONRPC, "2.0");
            notification.put(METHOD, "notify.progress");
            notification.set(PARAMS, data);
            try {
                final String payload = objectMapper.writeValueAsString(notification);
                logger.debug("Sending notification to {}: {}", recipient, payload);
                messageSender.sendMessage(recipient, payload);
            } catch (IOException e) {
                logger.error("notification failed", e);
            }
        }
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
    }
}

/**
 * Hands a response to the handler registered for its id, if any. The handler
 * receives the "result" member for a success, or the node holding "code" and
 * "message" for a failure.
 */
private void handleResponse(int id, JsonNode item) {
    final JsonNode error;
    if (item.has("result")) {
        error = null;
    } else if (item.path(ERROR).has(CODE) && item.path(ERROR).has(MESSAGE)) {
        // this is the shape produced by sendError: code and message nested under "error"
        error = item.get(ERROR);
    } else if (item.has(CODE) && item.has(MESSAGE)) {
        // also accept code and message placed directly on the response
        error = item;
    } else {
        return;
    }
    final ResponseHandler responseHandler = responseHandlers.remove(id);
    if (responseHandler == null) {
        return;
    }
    if (error == null) {
        responseHandler.onResult(item.get("result"));
    } else {
        responseHandler.onError(error);
    }
}

/**
 * Treats a notification whose params carry an id as a keep-alive for the
 * pending request with that id. Notifications without such an id are ignored.
 */
private void handleNotification(JsonNode item) {
    final JsonNode idNode = item.path(PARAMS).get(ID);
    if (idNode == null) {
        return;
    }
    final ResponseHandler responseHandler = responseHandlers.get(idNode.asInt());
    if (responseHandler != null) {
        responseHandler.keepAlive();
    }
}
private static Throwable findRoot(Throwable t) { | |
Set<Throwable> seen = new HashSet<>(); | |
while (t.getCause() != null && !seen.contains(t.getCause())) { | |
seen.add(t); | |
if (t instanceof InvocationTargetException) { | |
t = ((InvocationTargetException) t).getTargetException(); | |
} else { | |
t = t.getCause(); | |
} | |
} | |
return t; | |
} | |
public int getTimeoutSeconds() { | |
return DEFAULT_TIMEOUT_SECONDS; | |
} | |
private void sendError(String respondTo, Integer id, int code, String message, Throwable cause) { | |
logger.error(String.format("[%s] jsonrpc error %d", respondTo, code), cause); | |
if (respondTo != null) { | |
try { | |
ObjectNode response = objectMapper.createObjectNode(); | |
response.put(JSONRPC, "2.0"); | |
if (id != null) { | |
response.put(ID, id); | |
} | |
ObjectNode error = objectMapper.createObjectNode(); | |
error.put(CODE, code); | |
error.put(MESSAGE, message); | |
if (cause != null) { | |
cause = findRoot(cause); | |
String exceptionMessage = cause.getMessage(); | |
if (exceptionMessage != null) { | |
error.put("cause", exceptionMessage); | |
} | |
StackTraceElement[] stackTrace = cause.getStackTrace(); | |
if (stackTrace != null) { | |
ArrayNode backtraceNode = objectMapper.createArrayNode(); | |
for (StackTraceElement stackTraceElement : stackTrace) { | |
int line = stackTraceElement.getLineNumber(); | |
if (line < 1) { | |
break; | |
} | |
String filename = stackTraceElement.getFileName(); | |
if (filename == null) { | |
break; | |
} | |
String methodName = stackTraceElement.getMethodName(); | |
String className = stackTraceElement.getClassName(); | |
methodName = className + "." + methodName; | |
ObjectNode stackTraceElementNode = objectMapper.createObjectNode(); | |
stackTraceElementNode.put("file", filename); | |
stackTraceElementNode.put("line", line); | |
stackTraceElementNode.put("method", methodName); | |
backtraceNode.add(stackTraceElementNode); | |
} | |
error.set("backtrace", backtraceNode); | |
} | |
} | |
response.set(ERROR, error); | |
messageSender.sendMessage(respondTo, objectMapper.writeValueAsString(response)); | |
} catch (Exception e) { | |
logger.error("could note send message", e); | |
} | |
} | |
} | |
private void doInvoke(int id, String respondTo, Object impl, Method method, ArrayNode params) { | |
Object[] args = new Object[params.size()]; | |
final Type[] parameterTypes = method.getGenericParameterTypes(); | |
try { | |
for (int i = 0, max = Math.min(params.size(), parameterTypes.length); i < max; i++) { | |
JavaType javaType = objectMapper.getTypeFactory().constructType(parameterTypes[i]); | |
args[i] = objectMapper.treeToValue(params.get(i), javaType); | |
} | |
} catch (Exception e) { | |
sendError(respondTo, id, JsonRpcErrorCodes.ERR_INVALID_PARAMS, "Invalid params", e); | |
return; | |
} | |
try { | |
Object result = method.invoke(impl, args); | |
ObjectNode response = objectMapper.createObjectNode(); | |
response.put(JSONRPC, "2.0"); | |
response.put(ID, id); | |
response.set("result", objectMapper.valueToTree(result)); | |
String payload = objectMapper.writer().writeValueAsString(response); | |
logger.debug(String.format("Sending message to %s : %s", respondTo, payload)); | |
messageSender.sendMessage(respondTo, payload); | |
} catch (Throwable e) { | |
sendError(respondTo, id, JsonRpcErrorCodes.ERR_SERVER_ERROR, "invocation failed", e); | |
} | |
} | |
private Object findImpl(String methodSpec) { | |
Matcher m = methodNamePattern.matcher(methodSpec); | |
if (m.matches()) { | |
String serviceName = m.group(1); | |
return implementationProvider.getImpl(serviceName); | |
} | |
throw new JsonRpcException(JsonRpcErrorCodes.ERR_METHOD_NOT_FOUND, "Method not found"); | |
} | |
private Method findMethod(String methodSpec, Object impl, int paramsCount) { | |
final String key = methodSpec + "/" + paramsCount; | |
return methods.computeIfAbsent(key, newKey -> { | |
Matcher m = methodNamePattern.matcher(methodSpec); | |
if (m.matches()) { | |
final String methodName = m.group(2); | |
for (Method candidate : impl.getClass().getMethods()) { | |
if (candidate.getName().equals(methodName) && candidate.getParameterTypes().length == paramsCount) { | |
return candidate; | |
} | |
} | |
} | |
throw new JsonRpcException(JsonRpcErrorCodes.ERR_METHOD_NOT_FOUND, "Method not found"); | |
}); | |
} | |
@Override | |
public void close() { | |
readTask.cancel(true); | |
} | |
@SuppressWarnings({"unchecked", "unused"}) | |
public <T> T createProxy(final String recipient, final String serviceName, Class<T> iface) { | |
if (!iface.isInterface()) { | |
throw new IllegalArgumentException("iface must be an interface"); | |
} | |
final String key = serviceName + "@" + recipient; | |
Object impl; | |
if ((impl = proxies.get(key)) != null) { | |
return (T) impl; | |
} | |
InvocationHandler invocationHandler = new InvocationHandler() { | |
final Object local = new Object() { | |
@Override | |
public String toString() { | |
return "proxy of '" + serviceName + "@" + recipient + "'"; | |
} | |
}; | |
@Override | |
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { | |
Class<?> declaringClass = method.getDeclaringClass(); | |
if (declaringClass.equals(Object.class) || declaringClass.equals(local.getClass())) { | |
return method.invoke(local, args); | |
} | |
int id = counter.incrementAndGet(); | |
ObjectNode request = objectMapper.createObjectNode(); | |
request.put(JSONRPC, "2.0"); | |
request.put(ID, id); | |
request.put(METHOD, serviceName + "." + method.getName()); | |
ArrayNode params = objectMapper.createArrayNode(); | |
for (Object arg : args) { | |
params.add(objectMapper.valueToTree(arg)); | |
} | |
request.set(PARAMS, params); | |
final ResponseHandler responseHandler = new ResponseHandler(method.getReturnType()); | |
responseHandlers.put(id, responseHandler); | |
messageSender.sendMessage(recipient, objectMapper.writeValueAsString(request)); | |
try { | |
return responseHandler.waitFor(); | |
} finally { | |
responseHandlers.remove(id); | |
} | |
} | |
}; | |
impl = Proxy.newProxyInstance(Thread.currentThread().getContextClassLoader(), new Class<?>[]{iface}, | |
invocationHandler); | |
proxies.put(key, impl); | |
return (T) impl; | |
} | |
private interface JsonRpcErrorCodes { | |
int ERR_METHOD_NOT_FOUND = -32601; | |
int ERR_INVALID_PARAMS = -32602; | |
int ERR_INTERNAL_ERROR = -32603; | |
int ERR_SERVER_ERROR = -32000; | |
} | |
/**
 * A message received from a remote party.
 */
public interface IncomingMessage {
    /**
     * @return the name of the sender; responses and notifications are sent
     *         back to this name
     */
    String getSender();

    /**
     * @return the message body, parsed as JSON by the receiver
     */
    String getPayload();
}
/**
 * Contract for delivering a message to a given recipient.
 */
public interface MessageSender {
    /**
     * Sends one message.
     *
     * @param recipient the name of the recipient
     * @param payload   the message to be sent
     * @throws IOException when the message cannot be sent
     */
    void sendMessage(String recipient, String payload) throws IOException;
}
public interface MessageProvider { | |
/** | |
* waits for incoming message (blocking) | |
* | |
* @return the next message that we receive | |
*/ | |
IncomingMessage readMessage() throws InterruptedException; | |
} | |
/**
 * Contract for locating the object that implements a named service.
 */
public interface ImplementationProvider {
    /**
     * @param serviceName the name of the requested service
     * @return the object whose methods are invoked for that service; a null
     *         return is reported to the caller as "Method not found"
     */
    Object getImpl(String serviceName);
}
private static class JsonRpcException extends RuntimeException { | |
private final int code; | |
public JsonRpcException(int code, String message) { | |
super(message); | |
this.code = code; | |
} | |
public int getCode() { | |
return code; | |
} | |
} | |
public static class SimpleIncomingMessage implements IncomingMessage { | |
private final String sender; | |
private final String payload; | |
public SimpleIncomingMessage(String sender, String payload) { | |
this.sender = sender; | |
this.payload = payload; | |
} | |
@Override | |
public String getSender() { | |
return sender; | |
} | |
@Override | |
public String getPayload() { | |
return payload; | |
} | |
} | |
class ResponseHandler { | |
private final Semaphore sem = new Semaphore(0); | |
private final Class<?> returnType; | |
boolean isDone = false; | |
private Object result; | |
private JsonRpcException exception; | |
public ResponseHandler(Class<?> returnType) { | |
this.returnType = returnType; | |
} | |
void onResult(JsonNode result) { | |
if (!this.returnType.equals(Void.TYPE)) { | |
try { | |
this.result = objectMapper.treeToValue(result, returnType); | |
} catch (Exception e) { | |
throw new RuntimeException(e); | |
} | |
} | |
isDone = true; | |
sem.release(); | |
} | |
void onError(JsonNode result) { | |
this.exception = new JsonRpcException(result.get(CODE).asInt(), result.get(MESSAGE).asText()); | |
isDone = true; | |
sem.release(); | |
} | |
void keepAlive() { | |
sem.release(); | |
} | |
Object waitFor() throws Exception { | |
final int t = getTimeoutSeconds(); | |
boolean waiting = true; | |
while (waiting) { | |
if (sem.tryAcquire(t, TimeUnit.SECONDS)) { | |
if (isDone) { | |
if (this.exception != null) { | |
throw this.exception; | |
} else { | |
return this.result; | |
} | |
} | |
} else { | |
waiting = false; | |
} | |
} | |
throw new TimeoutException(String.format("No response from remote after %d seconds", t)); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment