Last active
August 21, 2019 12:17
-
-
Save rlindooren/df53abb3ea047caa7f725b4929dc3341 to your computer and use it in GitHub Desktop.
Passing MDC data to VAVR Future threads
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io.vavr.control.Option; | |
import lombok.extern.slf4j.Slf4j; | |
import org.slf4j.MDC; | |
import javax.annotation.Nonnull; | |
import java.util.*; | |
import java.util.concurrent.*; | |
import java.util.stream.Collectors; | |
/** | |
* Passes the logging mapped diagnostics context (MDC) to executor threads. | |
*/ | |
@Slf4j | |
public class ContextAwareExecutorService implements ExecutorService { | |
private final ExecutorService executorService; | |
private static ContextAwareExecutorService INSTANCE; | |
private ContextAwareExecutorService() { | |
// Using a cached executor service, the same that VAVR uses by default. | |
executorService = Executors.newCachedThreadPool(); | |
} | |
static synchronized ContextAwareExecutorService getInstance() { | |
if (INSTANCE == null) { | |
INSTANCE = new ContextAwareExecutorService(); | |
} | |
return INSTANCE; | |
} | |
/** | |
* Adds context to the executor thread before running the given runnable. | |
*/ | |
private static class ContextAwareRunnable implements Runnable { | |
final Runnable runnable; | |
final Option<Map<String, String>> loggingContextFromOriginalThread; | |
ContextAwareRunnable(final Runnable runnable) { | |
this.runnable = Objects.requireNonNull(runnable); | |
loggingContextFromOriginalThread = Option.of(MDC.getCopyOfContextMap()); | |
} | |
@Override | |
public void run() { | |
final Option<Map<String, String>> initialLoggingContextForExecutorThread = | |
addAdditionalDataToMdcOfExecutorThread(loggingContextFromOriginalThread); | |
try { | |
runnable.run(); | |
} finally { | |
restoreMdcOfExecutorThread(initialLoggingContextForExecutorThread); | |
} | |
} | |
} | |
/** | |
* Adds context to the executor thread before calling the given callable. | |
*/ | |
private static class ContextAwareCallable<V> implements Callable<V>{ | |
final Callable<V> callable; | |
final Option<Map<String, String>> loggingContextFromOriginalThread; | |
ContextAwareCallable(final Callable<V> callable) { | |
this.callable = Objects.requireNonNull(callable); | |
loggingContextFromOriginalThread = Option.of(MDC.getCopyOfContextMap()); | |
} | |
@Override | |
public V call() throws Exception { | |
final Option<Map<String, String>> initialLoggingContextForExecutorThread = | |
addAdditionalDataToMdcOfExecutorThread(loggingContextFromOriginalThread); | |
try { | |
return callable.call(); | |
} finally { | |
restoreMdcOfExecutorThread(initialLoggingContextForExecutorThread); | |
} | |
} | |
} | |
/** | |
* @return the previous MDC state for the executor thread | |
*/ | |
private static Option<Map<String, String>> addAdditionalDataToMdcOfExecutorThread(final Option<Map<String, String>> loggingContextFromOriginalThread) { | |
// Make a backup of the MDC state of the executor thread before adding additional data | |
final var initialLoggingContextForExecutorThread = Option.of(MDC.getCopyOfContextMap()); | |
// Add additional MDC data that exists for the original calling thread | |
loggingContextFromOriginalThread | |
.forEach(mdcFromOriginalThread -> mdcFromOriginalThread.forEach(MDC::put)); | |
return initialLoggingContextForExecutorThread; | |
} | |
private static void restoreMdcOfExecutorThread(final Option<Map<String, String>> initialLoggingContextForExecutorThread) { | |
MDC.clear(); | |
// Add additional MDC data that existed already for the executor thread | |
initialLoggingContextForExecutorThread | |
.forEach(initialMdcForExecutorThread -> initialMdcForExecutorThread.forEach(MDC::put)); | |
} | |
@Override | |
public void execute(@Nonnull Runnable command) { | |
executorService.execute(new ContextAwareRunnable(command)); | |
} | |
@Override | |
public void shutdown() { | |
executorService.shutdown(); | |
} | |
@Override | |
@Nonnull | |
public List<Runnable> shutdownNow() { | |
return executorService.shutdownNow(); | |
} | |
@Override | |
public boolean isShutdown() { | |
return executorService.isShutdown(); | |
} | |
@Override | |
public boolean isTerminated() { | |
return executorService.isTerminated(); | |
} | |
@Override | |
public boolean awaitTermination(long timeout, @Nonnull TimeUnit unit) throws InterruptedException { | |
return executorService.awaitTermination(timeout, unit); | |
} | |
@Override | |
@Nonnull | |
public <T> java.util.concurrent.Future<T> submit(@Nonnull Callable<T> task) { | |
return executorService.submit(new ContextAwareCallable<>(task)); | |
} | |
@Override | |
@Nonnull | |
public <T> java.util.concurrent.Future<T> submit(@Nonnull Runnable task, T result) { | |
return executorService.submit(new ContextAwareRunnable(task), result); | |
} | |
@Override | |
@Nonnull | |
public java.util.concurrent.Future<?> submit(@Nonnull Runnable task) { | |
return executorService.submit(new ContextAwareRunnable(task)); | |
} | |
@Override | |
@Nonnull | |
public <T> List<java.util.concurrent.Future<T>> invokeAll(@Nonnull Collection<? extends Callable<T>> tasks) throws InterruptedException { | |
return executorService.invokeAll(tasks.stream().map(ContextAwareCallable::new).collect(Collectors.toList())); | |
} | |
@Override | |
@Nonnull | |
public <T> List<java.util.concurrent.Future<T>> invokeAll(@Nonnull Collection<? extends Callable<T>> tasks, long timeout, @Nonnull TimeUnit unit) throws InterruptedException { | |
return executorService.invokeAll(tasks.stream().map(ContextAwareCallable::new).collect(Collectors.toList()), timeout, unit); | |
} | |
@Override | |
@Nonnull | |
public <T> T invokeAny(@Nonnull Collection<? extends Callable<T>> tasks) throws ExecutionException, InterruptedException { | |
return executorService.invokeAny(tasks.stream().map(ContextAwareCallable::new).collect(Collectors.toList())); | |
} | |
@Override | |
@Nonnull | |
public <T> T invokeAny(@Nonnull Collection<? extends Callable<T>> tasks, long timeout, @Nonnull TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { | |
return executorService.invokeAny(tasks.stream().map(ContextAwareCallable::new).collect(Collectors.toList()), timeout, unit); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io.vavr.CheckedFunction0; | |
import io.vavr.concurrent.Future; | |
import io.vavr.control.Try; | |
import java.util.concurrent.Callable; | |
import java.util.concurrent.CompletableFuture; | |
import java.util.concurrent.ExecutorService; | |
/** | |
* Uses the {@link ContextAwareExecutorService} as underlying executor service for futures | |
*/ | |
public class ContextAwareFuture { | |
private static final ExecutorService executorService = ContextAwareExecutorService.getInstance(); | |
public static <T> Future<T> fromTry(final Try<? extends T> result) { | |
return Future.fromTry(executorService, result); | |
} | |
public static <T> Future<T> successful(T result) { | |
return Future.successful(executorService, result); | |
} | |
public static <T> Future<T> of(CheckedFunction0<? extends T> computation) { | |
return Future.of(executorService, computation); | |
} | |
public static <T> Future<T> ofCallable(Callable<? extends T> computation) { | |
return Future.of(executorService, computation::call); | |
} | |
public static <T> Future<T> fromCompletableFuture(CompletableFuture<T> future) { | |
// When using this method the MDC will not be passed on the executor thread | |
// most likely because future.handle is used instead of future.handleAsync(...executorService...) | |
//return Future.fromCompletableFuture(executorService, future); | |
// Using `get` feels counter intuitive as it is blocking | |
// but it is executed asynchronously on a separate executor thread | |
// and this approach allows the MDC thread data to be passed on. | |
return fromTry(Try.of(future::get).recoverWith(error -> Try.failure(error.getCause()))); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment