Skip to content

Instantly share code, notes, and snippets.

@bbakerman
Created September 24, 2019 02:51
Show Gist options
  • Save bbakerman/bdf1b91aeda554d8ec260ef4c36965e1 to your computer and use it in GitHub Desktop.
Save bbakerman/bdf1b91aeda554d8ec260ef4c36965e1 to your computer and use it in GitHub Desktop.
import graphql.ExecutionInput;
import graphql.ExecutionResult;
import graphql.GraphQL;
import graphql.StarWarsData;
import graphql.StarWarsSchema;
import graphql.schema.DataFetcher;
import graphql.schema.GraphQLSchema;
import graphql.schema.TypeResolver;
import graphql.schema.idl.RuntimeWiring;
import graphql.schema.idl.SchemaGenerator;
import graphql.schema.idl.SchemaParser;
import graphql.schema.idl.SchemaPrinter;
import graphql.schema.idl.TypeDefinitionRegistry;
import org.dataloader.BatchLoader;
import org.dataloader.DataLoader;
import org.dataloader.DataLoaderOptions;
import org.dataloader.DataLoaderRegistry;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import static graphql.schema.idl.TypeRuntimeWiring.newTypeWiring;
@SuppressWarnings("UnnecessaryLocalVariable")
public class RequestConcurrencyPOC {

    /**
     * Runs asynchronous work on a shared {@link ExecutorService} while capping how many jobs each
     * request may have in flight, as governed by that request's {@link PerRequestWorkPermits}.
     *
     * <p>Every {@link #execute} call queues one job on a shared deque and submits one
     * {@link #serviceQ()} task to poll it. A job whose request has no free permit is put back on
     * the queue and a fresh service task is submitted. A saturated request therefore keeps
     * re-polling the executor until one of its in-flight jobs completes and frees a permit.
     */
    static class ConcurrentWorkManager {

        /** A queued job: the permits it counts against, the work to run, and the future to complete. */
        static class WorkJob {
            final PerRequestWorkPermits workPermits;
            final Supplier<CompletableFuture<?>> codeToRun;
            final CompletableFuture<?> cfToComplete;

            public WorkJob(PerRequestWorkPermits workPermits, Supplier<CompletableFuture<?>> codeToRun, CompletableFuture<?> cfToComplete) {
                this.workPermits = workPermits;
                this.codeToRun = codeToRun;
                this.cfToComplete = cfToComplete;
            }
        }

        private final ExecutorService executorService;
        private final Deque<WorkJob> workJobs;

        public ConcurrentWorkManager(ExecutorService executorService) {
            this.executorService = executorService;
            this.workJobs = new ConcurrentLinkedDeque<>();
        }

        /**
         * Queues {@code codeToRun} to run on the executor while holding one of
         * {@code requestWorkPermits}' permits.
         *
         * @param requestWorkPermits the per-request limit this job counts against
         * @param codeToRun          supplies the asynchronous work; invoked later on an executor thread
         * @param <T>                the caller's expected result type (unchecked: the value is whatever
         *                           the supplied future produces)
         * @return a future completed with the outcome of the future {@code codeToRun} supplies, or
         *         completed exceptionally if {@code codeToRun} throws or supplies {@code null}
         */
        public <T> CompletableFuture<T> execute(PerRequestWorkPermits requestWorkPermits, Supplier<CompletableFuture<?>> codeToRun) {
            CompletableFuture<T> cfToBeDone = new CompletableFuture<>();
            this.workJobs.offer(new WorkJob(requestWorkPermits, codeToRun, cfToBeDone));
            executorService.submit(serviceQ());
            return cfToBeDone;
        }

        /**
         * Returns a task that polls one queued job. The task runs the job if its request has a
         * free permit; otherwise it re-queues the job and schedules another attempt.
         */
        public Runnable serviceQ() {
            return () -> {
                WorkJob workJob = workJobs.poll();
                if (workJob == null) {
                    return;
                }
                if (!workJob.workPermits.tryAcquire(1)) {
                    // we could not get a permit - the request's concurrency limit is reached -
                    // so re-queue the job and try later. We must return here: the job must not run,
                    // and we must not release a permit we never acquired.
                    System.out.printf("===== No permits available for %s\n", workJob.workPermits);
                    workJobs.offer(workJob);
                    executorService.submit(serviceQ());
                    return;
                }
                runHoldingPermit(workJob);
            };
        }

        /**
         * Runs a job whose permit has already been acquired. The permit is held until the job's
         * asynchronous result completes, and the outcome is relayed to the caller's future.
         */
        private void runHoldingPermit(WorkJob workJob) {
            // unchecked: execute() hands this future out as CompletableFuture<T>, so completing it
            // with the supplied future's value trusts the caller's choice of T
            @SuppressWarnings("unchecked")
            CompletableFuture<Object> cfToComplete = (CompletableFuture<Object>) workJob.cfToComplete;
            boolean permitHandedOff = false;
            try {
                CompletableFuture<?> resultCF = workJob.codeToRun.get();
                if (resultCF == null) {
                    throw new NullPointerException("codeToRun supplied a null CompletableFuture");
                }
                permitHandedOff = true;
                resultCF.whenComplete((data, throwable) -> {
                    // release only once the async work finishes, so the permit bounds real in-flight work
                    workJob.workPermits.release(1);
                    if (throwable != null) {
                        cfToComplete.completeExceptionally(throwable);
                    } else {
                        cfToComplete.complete(data);
                    }
                });
            } catch (RuntimeException e) {
                // without this the caller's future would never complete
                cfToComplete.completeExceptionally(e);
            } finally {
                if (!permitHandedOff) {
                    workJob.workPermits.release(1);
                }
            }
        }
    }

    /**
     * A named, fair {@link Semaphore} that limits how many jobs a single request may run at once.
     */
    static class PerRequestWorkPermits {
        private final String name;
        private final Semaphore semaphore;

        public PerRequestWorkPermits(String name, int permits) {
            this.name = name;
            semaphore = new Semaphore(permits, true);
        }

        public boolean tryAcquire(int permits) {
            return semaphore.tryAcquire(permits);
        }

        public void release(int permits) {
            semaphore.release(permits);
        }

        public int availablePermits() {
            return semaphore.availablePermits();
        }

        @Override
        public String toString() {
            return name + ":" + semaphore.availablePermits();
        }
    }

    /**
     * Fires 100 concurrent queries, each limited to 3 concurrent jobs, through a shared
     * non-caching DataLoader. It waits for all of them, then shuts the executor down so the JVM
     * can exit.
     */
    public static void main(String[] args) {
        ExecutorService oneHundredThreads = Executors.newFixedThreadPool(100);
        ConcurrentWorkManager workManager = new ConcurrentWorkManager(oneHundredThreads);

        // the batch loader is async
        BatchLoader<String, Object> characterBatchLoader = keys -> CompletableFuture.supplyAsync(() -> {
            // this batch load simulates some delay - e.g. work being done
            System.out.println("snoozing");
            randomSnoozeMs(200, 5000);
            List<Object> characters = keys.stream().map(StarWarsData::getCharacter).collect(Collectors.toList());
            return characters;
        });

        // we don't cache, to ensure this DataLoader / BatchLoader gets some real work!
        DataLoaderOptions noCaching = DataLoaderOptions.newOptions().setCachingEnabled(false);
        DataLoader<String, Object> characterDL = DataLoader.newDataLoader(characterBatchLoader, noCaching);
        DataLoaderRegistry dataLoaderRegistry = new DataLoaderRegistry();
        dataLoaderRegistry.register("characters", characterDL);

        DataFetcher<CompletableFuture<Object>> humanDF = env -> {
            PerRequestWorkPermits perRequestWorkPermits = env.getContext();
            // NOTE(review): dl.load(...) runs later on an executor thread, after this fetcher has
            // returned. Confirm that the DataLoader dispatch mechanism in use still dispatches
            // loads queued that late; otherwise a request could wait forever.
            CompletableFuture<Object> cf = workManager.execute(perRequestWorkPermits, () -> {
                DataLoader<String, Object> dl = env.getDataLoader("characters");
                return dl.load(env.getArgument("id"));
            });
            return cf;
        };
        TypeResolver characterTR = env -> env.getSchema().getObjectType("Human");

        RuntimeWiring runtimeWiring = RuntimeWiring.newRuntimeWiring()
                .type(newTypeWiring("QueryType").dataFetcher("human", humanDF))
                .type(newTypeWiring("Character").typeResolver(characterTR))
                .build();
        String sdl = new SchemaPrinter().print(StarWarsSchema.starWarsSchema);
        TypeDefinitionRegistry typeRegistry = new SchemaParser().parse(sdl);
        GraphQLSchema graphQLSchema = new SchemaGenerator().makeExecutableSchema(typeRegistry, runtimeWiring);
        GraphQL graphQL = GraphQL.newGraphQL(graphQLSchema).build();

        // now execute
        List<CompletableFuture<ExecutionResult>> results = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            int clientNum = i;
            PerRequestWorkPermits perRequestWorkPermits = new PerRequestWorkPermits("client" + clientNum, 3);
            String query = "{ human(id:\"1000\") { name } }"; // luke
            ExecutionInput ei = ExecutionInput.newExecutionInput()
                    .context(perRequestWorkPermits)
                    .dataLoaderRegistry(dataLoaderRegistry)
                    .query(query)
                    .build();
            CompletableFuture<ExecutionResult> cfResult = graphQL.executeAsync(ei);
            results.add(cfResult.whenComplete((data, throwable) -> {
                // data is null when the execution itself failed
                if (throwable != null) {
                    System.out.println(String.format("Completed %d - failed with %s", clientNum, throwable));
                } else {
                    System.out.println(String.format("Completed %d - data %s - errors %s", clientNum, data.getData(), data.getErrors()));
                }
            }));
        }

        // wait for every request (successful or not), then stop the pool so the JVM can exit
        try {
            CompletableFuture.allOf(results.toArray(new CompletableFuture<?>[0]))
                    .exceptionally(throwable -> null)
                    .join();
        } finally {
            oneHundredThreads.shutdown();
        }
    }

    /**
     * Sleeps the current thread for a random duration in {@code [minMs, maxMs]} milliseconds.
     * If interrupted, it returns early with the thread's interrupt flag restored.
     *
     * @throws IllegalArgumentException if {@code minMs >= maxMs}
     */
    public static void randomSnoozeMs(int minMs, int maxMs) {
        long snoozeMs = getRandomNumberInRange(minMs, maxMs);
        try {
            Thread.sleep(snoozeMs);
        } catch (InterruptedException e) {
            // restore the interrupt status so the owning thread/pool can still observe it
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Returns a uniformly random int in the inclusive range {@code [min, max]}.
     *
     * @throws IllegalArgumentException if {@code min >= max}
     */
    private static int getRandomNumberInRange(int min, int max) {
        if (min >= max) {
            throw new IllegalArgumentException("max must be greater than min");
        }
        // long arithmetic so ranges wider than Integer.MAX_VALUE (or max == Integer.MAX_VALUE) don't overflow
        return (int) ThreadLocalRandom.current().nextLong(min, (long) max + 1);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment