Created
January 31, 2022 13:12
-
-
Save ashleyfrieze/5d8cf082fe0dc3b8b6374981111cae3b to your computer and use it in GitHub Desktop.
A concurrency helper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package uk.org.webcompere.concurrency; | |
import org.springframework.beans.factory.annotation.Value; | |
import org.springframework.stereotype.Service; | |
import javax.annotation.PreDestroy; | |
import java.util.ArrayList; | |
import java.util.Arrays; | |
import java.util.List; | |
import java.util.concurrent.*; | |
import java.util.stream.Stream; | |
import static java.util.stream.Collectors.toList; | |
/** | |
* A concurrency helper which farms out additional work to spare worker threads, but uses the calling thread too - | |
* conserves the total number of active threads, and allows occasional use of spare threads | |
*/ | |
@Service | |
public class ExtraThreadsService { | |
private ExecutorService executorService; | |
/** | |
* For functions that return nothing | |
*/ | |
@FunctionalInterface | |
public interface ThrowingRunnable { | |
void run() throws Exception; | |
default Callable<Void> asCallable() { | |
return () -> { | |
this.run(); | |
return null; | |
}; | |
} | |
} | |
public ExtraThreadsService(@Value("${concurrency.extra.threads:2}") int numExtraThreads) { | |
if (numExtraThreads < 1) { | |
throw new IllegalArgumentException("Cannot have fewer than 1 additional thread"); | |
} | |
executorService = Executors.newFixedThreadPool(numExtraThreads); | |
} | |
@PreDestroy | |
public void preDestroy() throws InterruptedException { | |
executorService.shutdown(); | |
executorService.awaitTermination(30, TimeUnit.SECONDS); | |
} | |
/** | |
* Run a series of runnables | |
* @param runnables the runnables to execute using spare threads | |
*/ | |
public void runAll(ThrowingRunnable ... runnables) throws Exception { | |
runAll(Arrays.stream(runnables) | |
.map(ThrowingRunnable::asCallable)); | |
} | |
/** | |
* Run a series of callables | |
* @param callables the callables to run | |
* @param <T> the return type | |
* @return the result of the execution | |
* @throws Exception if any of them throw | |
*/ | |
@SafeVarargs | |
public final <T> List<T> runAll(Callable<T> ... callables) throws Exception { | |
return runAll(Arrays.stream(callables)); | |
} | |
/** | |
* Run a stream of callables | |
* @param callables the callables to run | |
* @param <T> the return type | |
* @return the result of the execution | |
* @throws Exception if any of them throw | |
*/ | |
public <T> List<T> runAll(Stream<Callable<T>> callables) throws Exception { | |
// find the work we need to do | |
var toRun = callables.collect(toList()); | |
if (toRun.isEmpty()) { | |
return List.of(); | |
} | |
List<Future<T>> asyncCallsForRest = scheduleBackgroundTasks(toRun); | |
List<T> result = runFirstTask(toRun); | |
collectBackgroundTasks(asyncCallsForRest, result); | |
return result; | |
} | |
/** | |
* Skip the first task and schedule the rest | |
* @param toRun the tasks to run | |
* @param <T> the type of return value | |
* @return a list of futures | |
*/ | |
private <T> List<Future<T>> scheduleBackgroundTasks(List<Callable<T>> toRun) { | |
return toRun.stream().skip(1) | |
.map(toCall -> executorService.submit(toCall::call)) | |
.collect(toList()); | |
} | |
private <T> List<T> runFirstTask(List<Callable<T>> toRun) throws Exception { | |
List<T> result = new ArrayList<>(); | |
result.add(toRun.get(0).call()); | |
return result; | |
} | |
private <T> void collectBackgroundTasks(List<Future<T>> asyncCallsForRest, List<T> result) throws InterruptedException, ExecutionException { | |
for (Future<T> asyncCall : asyncCallsForRest) { | |
result.add(asyncCall.get()); | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package uk.org.webcompere.concurrency; | |
import org.junit.jupiter.api.AfterEach; | |
import org.junit.jupiter.api.Test; | |
import java.util.Map; | |
import java.util.concurrent.ConcurrentHashMap; | |
import java.util.stream.Stream; | |
import static org.assertj.core.api.Assertions.assertThat; | |
import static org.assertj.core.api.Assertions.assertThatThrownBy; | |
class ExtraThreadsServiceTest { | |
private ExtraThreadsService extraThreadsService = new ExtraThreadsService(3); | |
@AfterEach | |
void afterEach() throws Exception { | |
extraThreadsService.preDestroy(); | |
} | |
@Test | |
void cannotCreateOneWithZeroThreads() { | |
assertThatThrownBy(() -> new ExtraThreadsService(0)) | |
.isInstanceOf(IllegalArgumentException.class); | |
} | |
@Test | |
void submittingNoWorkDoesNothing() throws Exception { | |
extraThreadsService.runAll(Stream.empty()); | |
} | |
@Test | |
void submittingSingleCallableReturnsItsResult() throws Exception { | |
var results = extraThreadsService.runAll(() -> 123); | |
assertThat(results).containsExactly(123); | |
} | |
@Test | |
void submittingMultipleCallablesReturnsTheirResultsInOrder() throws Exception { | |
var results = extraThreadsService.runAll(() -> 1, () -> 2, () -> 3); | |
assertThat(results).containsExactly(1, 2, 3); | |
} | |
@Test | |
void submittingMultipleJobsExecutesThem() throws Exception { | |
Map<String, String> map = new ConcurrentHashMap<>(); | |
extraThreadsService.runAll( | |
() -> { map.put("a", "b"); }, | |
() -> { map.put("c", "d"); }, | |
() -> { map.put("e", "f"); } ); | |
assertThat(map).hasSize(3); | |
} | |
@Test | |
void anyFailingJobFailsTheTasks() throws Exception { | |
Map<String, String> map = new ConcurrentHashMap<>(); | |
assertThatThrownBy(() -> extraThreadsService.runAll( | |
() -> { map.put("a", "b"); }, | |
() -> { throw new Exception("Boom!"); }, | |
() -> { map.put("e", "f"); } )) | |
.hasMessageContaining("Boom!"); | |
assertThatThrownBy(() -> extraThreadsService.runAll( | |
() -> { throw new Exception("Boom!"); }, | |
() -> { map.put("a", "b"); }, | |
() -> { map.put("e", "f"); } )) | |
.hasMessageContaining("Boom!"); | |
assertThatThrownBy(() -> extraThreadsService.runAll( | |
() -> { map.put("a", "b"); }, | |
() -> { map.put("e", "f"); }, | |
() -> { throw new Exception("Boom!"); } )) | |
.hasMessageContaining("Boom!"); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment