Skip to content

Instantly share code, notes, and snippets.

@rlindooren
Last active August 21, 2019 12:17
Show Gist options
  • Save rlindooren/df53abb3ea047caa7f725b4929dc3341 to your computer and use it in GitHub Desktop.
Save rlindooren/df53abb3ea047caa7f725b4929dc3341 to your computer and use it in GitHub Desktop.
Passing SLF4J MDC (Mapped Diagnostic Context) data to VAVR Future threads
import io.vavr.control.Option;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.MDC;
import javax.annotation.Nonnull;
import java.util.*;
import java.util.concurrent.*;
import java.util.stream.Collectors;
/**
 * {@link ExecutorService} decorator that hands the logging mapped diagnostic context (MDC) of the
 * submitting thread over to the worker thread that runs each task.
 *
 * <p>Every task is wrapped at submission time, and the wrapper takes a copy of the submitting
 * thread's MDC right then. When the task runs, those entries are put into the worker thread's MDC
 * on top of whatever it already contained. Once the task finishes, normally or exceptionally, the
 * worker's MDC is cleared and its earlier entries are put back.
 *
 * <p>Lifecycle methods ({@code shutdown}, {@code awaitTermination}, ...) delegate to an underlying
 * cached thread pool. The single shared instance is obtained through {@link #getInstance()}.
 */
@Slf4j
public class ContextAwareExecutorService implements ExecutorService {

    /** Lazily created shared instance; guarded by the class lock taken in {@link #getInstance()}. */
    private static ContextAwareExecutorService sharedInstance;

    /** Pool that actually executes the wrapped tasks. */
    private final ExecutorService delegate;

    private ContextAwareExecutorService() {
        // Cached thread pool. The original note says this matches VAVR's default executor.
        // NOTE(review): that holds only for some VAVR versions - confirm for the version in use.
        delegate = Executors.newCachedThreadPool();
    }

    /**
     * Returns the shared instance, creating it on the first call.
     *
     * @return the process-wide context-aware executor
     */
    static synchronized ContextAwareExecutorService getInstance() {
        if (sharedInstance != null) {
            return sharedInstance;
        }
        sharedInstance = new ContextAwareExecutorService();
        return sharedInstance;
    }

    // ------------------------------------------------------------------ MDC helpers

    /** Copies the current thread's MDC; empty when the MDC adapter returns no map. */
    private static Option<Map<String, String>> snapshotMdc() {
        return Option.of(MDC.getCopyOfContextMap());
    }

    /** Puts every entry of the given snapshot (if present) into the current thread's MDC. */
    private static void putIntoMdc(final Option<Map<String, String>> snapshot) {
        for (final Map<String, String> entries : snapshot) {
            entries.forEach((key, value) -> MDC.put(key, value));
        }
    }

    /**
     * Adds the submitter's MDC entries to the current (worker) thread's MDC.
     *
     * @param submitterMdc MDC captured on the thread that submitted the task
     * @return the worker thread's MDC as it was before the entries were added
     */
    private static Option<Map<String, String>> overlayMdc(final Option<Map<String, String>> submitterMdc) {
        final Option<Map<String, String>> workerMdcBefore = snapshotMdc();
        putIntoMdc(submitterMdc);
        return workerMdcBefore;
    }

    /** Wipes the current thread's MDC and puts back the entries of an earlier snapshot. */
    private static void resetMdc(final Option<Map<String, String>> workerMdcBefore) {
        MDC.clear();
        putIntoMdc(workerMdcBefore);
    }

    // ------------------------------------------------------------------ task wrappers

    /** Runs the wrapped runnable with the submitter's MDC in place on the worker thread. */
    private static class ContextAwareRunnable implements Runnable {
        private final Runnable task;
        private final Option<Map<String, String>> submitterMdc;

        ContextAwareRunnable(final Runnable runnable) {
            task = Objects.requireNonNull(runnable);
            submitterMdc = snapshotMdc();
        }

        @Override
        public void run() {
            final Option<Map<String, String>> workerMdcBefore = overlayMdc(submitterMdc);
            try {
                task.run();
            } finally {
                resetMdc(workerMdcBefore);
            }
        }
    }

    /** Calls the wrapped callable with the submitter's MDC in place on the worker thread. */
    private static class ContextAwareCallable<V> implements Callable<V> {
        private final Callable<V> task;
        private final Option<Map<String, String>> submitterMdc;

        ContextAwareCallable(final Callable<V> callable) {
            task = Objects.requireNonNull(callable);
            submitterMdc = snapshotMdc();
        }

        @Override
        public V call() throws Exception {
            final Option<Map<String, String>> workerMdcBefore = overlayMdc(submitterMdc);
            try {
                return task.call();
            } finally {
                resetMdc(workerMdcBefore);
            }
        }
    }

    /** Wraps each task in iteration order, capturing the calling thread's MDC right now. */
    private static <T> List<Callable<T>> wrapAll(final Collection<? extends Callable<T>> tasks) {
        final List<Callable<T>> wrapped = new ArrayList<>(tasks.size());
        for (final Callable<T> task : tasks) {
            wrapped.add(new ContextAwareCallable<>(task));
        }
        return wrapped;
    }

    // ------------------------------------------------------------------ task submission

    @Override
    public void execute(@Nonnull Runnable command) {
        delegate.execute(new ContextAwareRunnable(command));
    }

    @Override
    @Nonnull
    public java.util.concurrent.Future<?> submit(@Nonnull Runnable task) {
        return delegate.submit(new ContextAwareRunnable(task));
    }

    @Override
    @Nonnull
    public <T> java.util.concurrent.Future<T> submit(@Nonnull Runnable task, T result) {
        return delegate.submit(new ContextAwareRunnable(task), result);
    }

    @Override
    @Nonnull
    public <T> java.util.concurrent.Future<T> submit(@Nonnull Callable<T> task) {
        return delegate.submit(new ContextAwareCallable<>(task));
    }

    @Override
    @Nonnull
    public <T> List<java.util.concurrent.Future<T>> invokeAll(@Nonnull Collection<? extends Callable<T>> tasks) throws InterruptedException {
        return delegate.invokeAll(wrapAll(tasks));
    }

    @Override
    @Nonnull
    public <T> List<java.util.concurrent.Future<T>> invokeAll(@Nonnull Collection<? extends Callable<T>> tasks, long timeout, @Nonnull TimeUnit unit) throws InterruptedException {
        return delegate.invokeAll(wrapAll(tasks), timeout, unit);
    }

    @Override
    @Nonnull
    public <T> T invokeAny(@Nonnull Collection<? extends Callable<T>> tasks) throws ExecutionException, InterruptedException {
        return delegate.invokeAny(wrapAll(tasks));
    }

    @Override
    @Nonnull
    public <T> T invokeAny(@Nonnull Collection<? extends Callable<T>> tasks, long timeout, @Nonnull TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
        return delegate.invokeAny(wrapAll(tasks), timeout, unit);
    }

    // ------------------------------------------------------------------ lifecycle

    @Override
    public void shutdown() {
        delegate.shutdown();
    }

    @Override
    @Nonnull
    public List<Runnable> shutdownNow() {
        return delegate.shutdownNow();
    }

    @Override
    public boolean isShutdown() {
        return delegate.isShutdown();
    }

    @Override
    public boolean isTerminated() {
        return delegate.isTerminated();
    }

    @Override
    public boolean awaitTermination(long timeout, @Nonnull TimeUnit unit) throws InterruptedException {
        return delegate.awaitTermination(timeout, unit);
    }
}
import io.vavr.CheckedFunction0;
import io.vavr.concurrent.Future;
import io.vavr.control.Try;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
/**
* Uses the {@link ContextAwareExecutorService} as underlying executor service for futures
*/
public class ContextAwareFuture {
private static final ExecutorService executorService = ContextAwareExecutorService.getInstance();
public static <T> Future<T> fromTry(final Try<? extends T> result) {
return Future.fromTry(executorService, result);
}
public static <T> Future<T> successful(T result) {
return Future.successful(executorService, result);
}
public static <T> Future<T> of(CheckedFunction0<? extends T> computation) {
return Future.of(executorService, computation);
}
public static <T> Future<T> ofCallable(Callable<? extends T> computation) {
return Future.of(executorService, computation::call);
}
public static <T> Future<T> fromCompletableFuture(CompletableFuture<T> future) {
// When using this method the MDC will not be passed on the executor thread
// most likely because future.handle is used instead of future.handleAsync(...executorService...)
//return Future.fromCompletableFuture(executorService, future);
// Using `get` feels counter intuitive as it is blocking
// but it is executed asynchronously on a separate executor thread
// and this approach allows the MDC thread data to be passed on.
return fromTry(Try.of(future::get).recoverWith(error -> Try.failure(error.getCause())));
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment