Skip to content

Instantly share code, notes, and snippets.

@ryantenney
Created February 9, 2015 18:25
Show Gist options
  • Save ryantenney/b01282970735c6364851 to your computer and use it in GitHub Desktop.
Save ryantenney/b01282970735c6364851 to your computer and use it in GitHub Desktop.
package com.enernoc.cost.common.util.concurrent;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * A {@link ThreadPoolExecutor} that bounds how many tasks each submitting thread may have
 * outstanding at once.
 *
 * <p>Every thread that submits work owns a private {@link Semaphore} with
 * {@code leasesPerThread} permits. Creating a task ({@code submit}, {@code invokeAll},
 * {@code invokeAny} or {@code execute}) takes one permit ("lease") from the caller's semaphore,
 * blocking the caller while it already has {@code leasesPerThread} tasks outstanding. The lease is
 * returned when the task completes or is cancelled, or when the executor rejects it by throwing.
 * No single producer can therefore have more than {@code leasesPerThread} tasks queued or running.
 *
 * <p>Caveats:
 * <ul>
 *   <li>Plain {@link Runnable}s passed to {@link #execute} are wrapped in a {@link FutureTask}.
 *       Exceptions they throw are captured by that task. They do not reach {@link #afterExecute}
 *       or the worker's uncaught-exception handler.</li>
 *   <li>A task silently dropped by a non-throwing {@link RejectedExecutionHandler} (for example
 *       {@code DiscardPolicy} with a saturated bounded queue) keeps its lease.</li>
 *   <li>A task running on a pool thread that submits to this executor uses that pool thread's
 *       leases, and may block that pool thread.</li>
 * </ul>
 */
public final class FairExecutorService extends ThreadPoolExecutor {

    /** Lease semaphore of each submitting thread, created on that thread's first submission. */
    private final ThreadLocal<Semaphore> semaphoreThreadLocal;

    /**
     * Returns a new, single-use {@link Builder}.
     *
     * @return a fresh builder
     */
    public static Builder builder() {
        return new Builder();
    }

    private FairExecutorService(
            int corePoolSize,
            int maxPoolSize,
            int leasesPerThread,
            long keepAliveTime,
            TimeUnit timeUnit,
            BlockingQueue<Runnable> workQueue,
            ThreadFactory threadFactory,
            RejectedExecutionHandler rejectedExecutionHandler) {
        super(corePoolSize, maxPoolSize, keepAliveTime, timeUnit, workQueue, threadFactory,
                rejectedExecutionHandler);
        // Pool configuration is owned by ThreadPoolExecutor (getCorePoolSize() etc.). Private copies
        // would go stale as soon as one of the executor's setters was called.
        this.semaphoreThreadLocal =
                ThreadLocal.withInitial(() -> new Semaphore(leasesPerThread, true));
        init();
    }

    /**
     * Starts all core threads so they idly wait for work.
     *
     * <p>The constructor calls this method. Calling it again only starts core threads that are not
     * already running.
     */
    public void init() {
        prestartAllCoreThreads();
    }

    /**
     * Wraps {@code callable} in a task holding one lease of the calling thread. Blocks until a
     * lease is available.
     *
     * <p>If the executor is already shut down, the task is created without a lease. {@code execute}
     * will reject it, and some handlers ({@code CallerRunsPolicy}, the discard policies) then drop
     * it silently, which would otherwise leak the lease. A shutdown racing with this check can still
     * leak one lease under such a handler.
     *
     * @throws RejectedExecutionException if the calling thread is interrupted while waiting for a
     *     lease; the thread's interrupt status is restored
     * @throws NullPointerException if {@code callable} is null
     */
    private <T> FutureTaskWithSemaphorePermit<T> newLeasedTask(Callable<T> callable) {
        Semaphore semaphore = semaphoreThreadLocal.get();
        if (isShutdown()) {
            return new FutureTaskWithSemaphorePermit<>(callable, new SemaphorePermit(semaphore, 0));
        }
        // Build the task first so a null callable fails before a permit is taken. The task is not
        // visible to anyone until the permit is actually held.
        FutureTaskWithSemaphorePermit<T> task =
                new FutureTaskWithSemaphorePermit<>(callable, new SemaphorePermit(semaphore));
        try {
            semaphore.acquire();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new RejectedExecutionException("Interrupted while waiting for a task lease", ex);
        }
        return task;
    }

    @Override
    protected <T> RunnableFuture<T> newTaskFor(Callable<T> callable) {
        return newLeasedTask(callable);
    }

    @Override
    protected <T> RunnableFuture<T> newTaskFor(Runnable runnable, T value) {
        return newLeasedTask(Executors.callable(runnable, value));
    }

    /**
     * Executes {@code command}. A lease is taken from the calling thread unless {@code command}
     * already holds one (as the tasks created by {@code submit} do).
     *
     * @throws RejectedExecutionException if the task cannot be accepted; its lease is released
     *     first. Also thrown if the caller is interrupted while waiting for a lease.
     * @throws NullPointerException if {@code command} is null
     */
    @Override
    public void execute(Runnable command) {
        FutureTaskWithSemaphorePermit<?> task;
        if (command instanceof FutureTaskWithSemaphorePermit) {
            task = (FutureTaskWithSemaphorePermit<?>) command;
        } else {
            task = newLeasedTask(Executors.callable(command, null));
        }
        try {
            super.execute(task);
        } catch (RejectedExecutionException ex) {
            // The task will never run, so done() would never fire. Cancelling completes it, which
            // returns the lease and terminates any Future the caller might hold.
            task.cancel(false);
            throw ex;
        }
    }

    /**
     * Stops the executor as {@link ThreadPoolExecutor#shutdownNow()} does. Also returns the leases
     * held by the drained tasks, since they will never run on this executor.
     *
     * <p>Releasing a lease is idempotent. A caller that later runs one of the returned tasks does
     * not release its lease a second time.
     *
     * @return the tasks that never commenced execution
     */
    @Override
    public List<Runnable> shutdownNow() {
        List<Runnable> drained = super.shutdownNow();
        for (Runnable task : drained) {
            if (task instanceof FutureTaskWithSemaphorePermit) {
                ((FutureTaskWithSemaphorePermit<?>) task).semaphorePermit.release();
            }
        }
        return drained;
    }

    /** The permits one task holds, returned to their semaphore at most once. */
    private static final class SemaphorePermit {
        private final Semaphore semaphore;
        private final AtomicInteger permitsAcquired;

        SemaphorePermit(Semaphore semaphore) {
            this(semaphore, 1);
        }

        SemaphorePermit(Semaphore semaphore, int permitsAcquired) {
            this.semaphore = semaphore;
            this.permitsAcquired = new AtomicInteger(permitsAcquired);
        }

        /** Releases any permits still held. Later calls are no-ops, from any thread. */
        void release() {
            int permits = permitsAcquired.getAndSet(0);
            if (permits > 0) {
                semaphore.release(permits);
            }
        }
    }

    /** A {@link FutureTask} that returns its lease when it completes or is cancelled. */
    private static final class FutureTaskWithSemaphorePermit<T> extends FutureTask<T> {
        private final SemaphorePermit semaphorePermit;

        FutureTaskWithSemaphorePermit(Callable<T> callable, SemaphorePermit semaphorePermit) {
            super(callable);
            this.semaphorePermit = semaphorePermit;
        }

        @Override
        protected void done() {
            semaphorePermit.release();
        }
    }

    /**
     * Single-use builder for {@link FairExecutorService}.
     *
     * <p>{@code corePoolSize} and {@code leasesPerThread} must be set. Defaults for the other
     * settings:
     * <ul>
     *   <li>maximum pool size: equal to the core pool size;</li>
     *   <li>work queue: an unbounded {@link LinkedBlockingQueue}, with which the pool never grows
     *       past its core size;</li>
     *   <li>thread factory: {@link Executors#defaultThreadFactory()};</li>
     *   <li>keep-alive: {@code Long.MAX_VALUE} nanoseconds;</li>
     *   <li>core-thread timeout: disabled;</li>
     *   <li>rejection handler: {@code CallerRunsPolicy}.</li>
     * </ul>
     */
    public static final class Builder {
        private int corePoolSize = -1;
        /** {@code -1} means "same as corePoolSize", resolved in {@link #build()}. */
        private int maxPoolSize = -1;
        private int leasesPerThread = -1;
        private BlockingQueue<Runnable> workQueue;
        private ThreadFactory threadFactory = Executors.defaultThreadFactory();
        private long keepAliveTime = Long.MAX_VALUE;
        private TimeUnit timeUnit = TimeUnit.NANOSECONDS;
        private boolean allowCoreThreadTimeOut = false;
        private RejectedExecutionHandler rejectedExecutionHandler = new CallerRunsPolicy();
        private final AtomicBoolean used = new AtomicBoolean(false);

        private Builder() {
        }

        /** Sets the number of threads kept in the pool. Required; must be {@code >= 0}. */
        public void setCorePoolSize(int corePoolSize) {
            this.corePoolSize = corePoolSize;
        }

        public Builder withCorePoolSize(int corePoolSize) {
            setCorePoolSize(corePoolSize);
            return this;
        }

        /**
         * Sets the maximum number of threads. Defaults to the core pool size. Threads beyond the
         * core size are only created when the work queue is full.
         */
        public void setMaxPoolSize(int maxPoolSize) {
            this.maxPoolSize = maxPoolSize;
        }

        public Builder withMaxPoolSize(int maxPoolSize) {
            setMaxPoolSize(maxPoolSize);
            return this;
        }

        /**
         * Sets how many tasks each submitting thread may have outstanding. Required; must be
         * {@code > 0}.
         */
        public void setLeasesPerThread(int leasesPerThread) {
            this.leasesPerThread = leasesPerThread;
        }

        public Builder withLeasesPerThread(int leasesPerThread) {
            setLeasesPerThread(leasesPerThread);
            return this;
        }

        /**
         * Sets the idle keep-alive time, expressed in the currently configured time unit.
         * Must be {@code >= 0}.
         */
        public void setKeepAliveTime(long keepAliveTime) {
            this.keepAliveTime = keepAliveTime;
        }

        /** Sets the idle keep-alive time and its unit. */
        public void setKeepAliveTime(long keepAliveTime, TimeUnit unit) {
            this.keepAliveTime = keepAliveTime;
            this.timeUnit = unit;
        }

        public Builder withKeepAliveTime(long keepAliveTime) {
            setKeepAliveTime(keepAliveTime);
            return this;
        }

        public Builder withKeepAliveTime(long keepAliveTime, TimeUnit unit) {
            setKeepAliveTime(keepAliveTime, unit);
            return this;
        }

        /** Sets the unit of the keep-alive time. */
        public void setTimeUnit(TimeUnit timeUnit) {
            this.timeUnit = timeUnit;
        }

        public Builder withTimeUnit(TimeUnit timeUnit) {
            setTimeUnit(timeUnit);
            return this;
        }

        /** Lets idle core threads time out; requires a positive keep-alive time. */
        public void setAllowCoreThreadTimeOut(boolean allowCoreThreadTimeOut) {
            this.allowCoreThreadTimeOut = allowCoreThreadTimeOut;
        }

        public Builder withAllowCoreThreadTimeOut(boolean allowCoreThreadTimeOut) {
            setAllowCoreThreadTimeOut(allowCoreThreadTimeOut);
            return this;
        }

        /** Sets the work queue. Defaults to an unbounded {@link LinkedBlockingQueue}. */
        public void setWorkQueue(BlockingQueue<Runnable> workQueue) {
            this.workQueue = workQueue;
        }

        public Builder withWorkQueue(BlockingQueue<Runnable> workQueue) {
            setWorkQueue(workQueue);
            return this;
        }

        /** Sets the factory used to create pool threads. */
        public void setThreadFactory(ThreadFactory threadFactory) {
            this.threadFactory = threadFactory;
        }

        public Builder withThreadFactory(ThreadFactory threadFactory) {
            setThreadFactory(threadFactory);
            return this;
        }

        /** Sets the handler for tasks the pool cannot accept. */
        public void setRejectedExecutionHandler(RejectedExecutionHandler rejectedExecutionHandler) {
            this.rejectedExecutionHandler = rejectedExecutionHandler;
        }

        public Builder withRejectedExecutionHandler(
                RejectedExecutionHandler rejectedExecutionHandler) {
            setRejectedExecutionHandler(rejectedExecutionHandler);
            return this;
        }

        /**
         * Creates the executor and prestarts its core threads.
         *
         * <p>The configuration is validated before anything is started. A rejected configuration
         * therefore leaks no threads and leaves this builder usable.
         *
         * @return the new executor
         * @throws IllegalArgumentException if a setting is missing or out of range
         * @throws IllegalStateException if this builder has already built an executor
         */
        public FairExecutorService build() {
            if (leasesPerThread <= 0) {
                throw new IllegalArgumentException("leasesPerThread");
            }
            int resolvedMaxPoolSize = maxPoolSize == -1 ? corePoolSize : maxPoolSize;
            if (corePoolSize < 0) {
                throw new IllegalArgumentException(
                        "corePoolSize must be set and >= 0, was " + corePoolSize);
            }
            if (resolvedMaxPoolSize <= 0 || resolvedMaxPoolSize < corePoolSize) {
                throw new IllegalArgumentException("maxPoolSize must be > 0 and >= corePoolSize ("
                        + corePoolSize + "), was " + resolvedMaxPoolSize);
            }
            if (keepAliveTime < 0) {
                throw new IllegalArgumentException("keepAliveTime must be >= 0, was " + keepAliveTime);
            }
            if (allowCoreThreadTimeOut && keepAliveTime == 0) {
                // ThreadPoolExecutor would only reject this after the core threads were prestarted,
                // leaving them running with no handle to shut them down.
                throw new IllegalArgumentException(
                        "allowCoreThreadTimeOut requires a positive keepAliveTime");
            }
            if (!used.compareAndSet(false, true)) {
                throw new IllegalStateException("used");
            }
            BlockingQueue<Runnable> queue = workQueue != null ? workQueue : new LinkedBlockingQueue<>();
            FairExecutorService fairExecutorService = new FairExecutorService(corePoolSize,
                    resolvedMaxPoolSize, leasesPerThread, keepAliveTime, timeUnit, queue,
                    threadFactory, rejectedExecutionHandler);
            fairExecutorService.allowCoreThreadTimeOut(allowCoreThreadTimeOut);
            return fairExecutorService;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment