Created
July 16, 2017 20:54
-
-
Save amanteaux/64c54a913c1ae34ad7b86db109cbc0bf to your computer and use it in GitHub Desktop.
A thread-pool executor (not a java.util.concurrent.ExecutorService implementation) that interrupts each task once its own timeout has elapsed
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * Runs {@link Runnable} tasks on a fixed pool of threads and interrupts every task that is
 * still running once its own timeout has elapsed.
 *
 * <p>One thread of the underlying scheduled pool is always left free so that the scheduled
 * cancellation tasks can run even when every other thread is busy. At most
 * {@code corePoolSize - 1} user tasks therefore run at the same time. Further tasks wait in
 * a FIFO queue until a slot is released.
 *
 * <p>A task's timeout starts when the task is handed to the underlying pool (i.e. once a
 * slot is available), not when {@link #execute} is called. Timed-out tasks are interrupted,
 * so a task that ignores interruption keeps its thread (and its slot) until it returns.
 * Exceptions thrown by tasks are not reported.
 */
public class TimeoutTaskThreadPoolExecutor {
    private final Queue<Task> awaitingTasks;
    private final ScheduledExecutorService executor;
    private final int corePoolSize;
    /** Number of slots currently held by submitted user tasks; guarded by {@code this}. */
    private int poolSize;

    /**
     * Creates a new {@code TimeoutTaskThreadPoolExecutor} backed by {@code corePoolSize}
     * threads.
     *
     * <p>One thread is reserved for running the cancellation tasks. The pool size must
     * therefore be at least 2; with a smaller pool no task could ever be started.
     *
     * @param corePoolSize the number of threads of the underlying scheduled pool
     * @throws IllegalArgumentException if {@code corePoolSize < 2}
     */
    public TimeoutTaskThreadPoolExecutor(int corePoolSize) {
        if (corePoolSize < 2) {
            throw new IllegalArgumentException(
                    "corePoolSize must be >= 2 (one thread is reserved for timeouts), got "
                            + corePoolSize);
        }
        this.awaitingTasks = new LinkedBlockingQueue<>();
        this.executor = Executors.newScheduledThreadPool(corePoolSize);
        this.corePoolSize = corePoolSize;
        this.poolSize = 0;
    }

    /**
     * Queues {@code task} for execution. The task is interrupted if it is still running
     * {@code delayTimeout} {@code unit}s after being handed to the pool.
     *
     * <p>Tasks queued after {@link #shutdown()} are never run.
     *
     * @param task the work to run
     * @param delayTimeout the maximum run time, in {@code unit}; values {@code <= 0} time
     *     out immediately
     * @param unit the unit of {@code delayTimeout}
     * @throws NullPointerException if {@code task} or {@code unit} is null
     */
    public void execute(Runnable task, long delayTimeout, TimeUnit unit) {
        // Fail fast on the caller's thread instead of later on a pool thread.
        Objects.requireNonNull(task, "task");
        Objects.requireNonNull(unit, "unit");
        awaitingTasks.offer(new Task(task, delayTimeout, unit));
        executeWaitingTask();
    }

    /**
     * Shuts the underlying pool down and discards every task still waiting in the queue.
     * Tasks that were already handed to the pool are not interrupted by this call.
     */
    public synchronized void shutdown() {
        executor.shutdown();
        awaitingTasks.clear();
    }

    /**
     * Returns {@code true} once the underlying pool reports that all its tasks have completed
     * following {@link #shutdown()}.
     */
    public boolean isTerminated() {
        return executor.isTerminated();
    }

    /**
     * Hands queued tasks to the pool for as long as a slot is free. One pool thread is always
     * left for the cancellation tasks. Called whenever a task is queued or a slot is released.
     */
    private synchronized void executeWaitingTask() {
        if (executor.isShutdown()) {
            return;
        }
        while ((corePoolSize - poolSize) > 1) {
            final Task nextTask = awaitingTasks.poll();
            if (nextTask == null) {
                return;
            }
            poolSize++;
            start(nextTask);
        }
    }

    /**
     * Submits one task that already holds a slot, and schedules its interruption.
     * Must be called while holding the lock on {@code this}.
     */
    private void start(final Task nextTask) {
        // Exactly one party claims the slot: either the worker (to run the task) or the
        // timeout (to abandon a task that never started). Whoever claims it releases it, so the
        // slot is returned even if the task is cancelled before the pool gets to run it.
        final AtomicBoolean claimed = new AtomicBoolean(false);
        final Future<?> taskHandler = executor.submit(new Runnable() {
            @Override
            public void run() {
                if (!claimed.compareAndSet(false, true)) {
                    // Timed out before starting; the timeout already released the slot.
                    return;
                }
                try {
                    nextTask.task.run();
                } finally {
                    releaseSlot();
                }
            }
        });
        executor.schedule(
                new Runnable() {
                    @Override
                    public void run() {
                        if (claimed.compareAndSet(false, true)) {
                            // The task never started: drop it and give its slot back.
                            taskHandler.cancel(false);
                            releaseSlot();
                        } else {
                            // Running (or finished): interrupt it; its worker releases the slot.
                            taskHandler.cancel(true);
                        }
                    }
                },
                nextTask.delayTimeout,
                nextTask.unit
        );
    }

    /**
     * Gives one slot back and starts the next queued task(s). Holds the same lock as the
     * increment, so no updates are lost.
     */
    private synchronized void releaseSlot() {
        poolSize--;
        executeWaitingTask();
    }

    /** A queued task together with its timeout. */
    private static final class Task {
        final Runnable task;
        final long delayTimeout;
        final TimeUnit unit;

        public Task(Runnable task, long delayTimeout, TimeUnit unit) {
            this.task = task;
            this.delayTimeout = delayTimeout;
            this.unit = unit;
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import static org.assertj.core.api.Assertions.assertThat; | |
import java.util.concurrent.Callable; | |
import java.util.concurrent.TimeUnit; | |
import java.util.concurrent.atomic.AtomicBoolean; | |
import org.junit.Test; | |
public class TimeoutTaskThreadPoolExecutorTest { | |
@Test | |
public void checkThatQuickTaskFullyExecutes() throws InterruptedException { | |
TimeoutTaskThreadPoolExecutor executor = new TimeoutTaskThreadPoolExecutor(2); | |
final AtomicBoolean isDone = new AtomicBoolean(false); | |
final Runnable job = new Runnable() { | |
@Override | |
public void run() { | |
isDone.set(true); | |
notifyAll(); | |
} | |
}; | |
executor.execute(job, 100, TimeUnit.MILLISECONDS); | |
Thread thread = waitingThread(job, isDone, 1000); | |
thread.start(); | |
thread.join(0); | |
assertThat(isDone.get()).isTrue(); | |
executor.shutdown(); | |
} | |
@Test | |
public void checkThatTooLongTaskAreActuallyCancelled() throws InterruptedException { | |
TimeoutTaskThreadPoolExecutor executor = new TimeoutTaskThreadPoolExecutor(2); | |
final AtomicBoolean isDone = new AtomicBoolean(false); | |
final AtomicBoolean isExecuted = new AtomicBoolean(false); | |
final Runnable job = new Runnable() { | |
@Override | |
public void run() { | |
try { | |
Thread.sleep(10000); | |
isExecuted.set(true); | |
} catch (InterruptedException e) { | |
// as expected | |
} | |
isDone.set(true); | |
notifyAll(); | |
} | |
}; | |
executor.execute(job, 10, TimeUnit.MILLISECONDS); | |
Thread thread = waitingThread(job, isDone, 1000); | |
thread.start(); | |
thread.join(0); | |
assertThat(isDone.get()).isTrue(); | |
assertThat(isExecuted.get()).isFalse(); | |
executor.shutdown(); | |
} | |
@Test | |
public void multiTasksScenario() throws InterruptedException { | |
// in this scenario what should append: | |
// - job1 is created and executed in the pool in 10ms (timeout 1000ms) | |
// - job2 is created and queued in the pool (since it is full: one thread on job1 & one thread for timeout handling) | |
// - job3 is created and queued in the pool | |
// - job1 finished executing | |
// - job2 execution starts in the pool, it will takes up 10000ms to execute (timeout 100ms) | |
// - job2 is interrupted due to the timeout | |
// - job3 execution starts in the pool, it should last bellow 1ms to execute (timeout 50ms) | |
// - job3 finished executing | |
TimeoutTaskThreadPoolExecutor executor = new TimeoutTaskThreadPoolExecutor(2); | |
final AtomicBoolean isDone1 = new AtomicBoolean(false); | |
final AtomicBoolean isError1 = new AtomicBoolean(false); | |
final AtomicBoolean isDone2 = new AtomicBoolean(false); | |
final AtomicBoolean isExecuted2 = new AtomicBoolean(false); | |
final AtomicBoolean isDone3 = new AtomicBoolean(false); | |
// job 1 should be fully executed though it takes 10ms to be executed | |
final Runnable job1 = new Runnable() { | |
@Override | |
public void run() { | |
try { | |
Thread.sleep(10); | |
isDone1.set(true); | |
notifyAll(); | |
} catch (InterruptedException e) { | |
isError1.set(true); | |
} | |
} | |
}; | |
executor.execute(job1, 1000, TimeUnit.MILLISECONDS); | |
// job 2 should be interrupted | |
final Runnable job2 = new Runnable() { | |
@Override | |
public void run() { | |
try { | |
Thread.sleep(10000); | |
isExecuted2.set(true); | |
} catch (InterruptedException e) { | |
// as expected | |
} | |
isDone2.set(true); | |
notifyAll(); | |
} | |
}; | |
executor.execute(job2, 100, TimeUnit.MILLISECONDS); | |
// job 3 should be fully executed | |
final Runnable job3 = new Runnable() { | |
@Override | |
public void run() { | |
isDone3.set(true); | |
notifyAll(); | |
} | |
}; | |
executor.execute(job3, 50, TimeUnit.MILLISECONDS); | |
Thread thread1 = waitingThread(job1, isDone1, 1000); | |
thread1.start(); | |
Thread thread2 = waitingThread(job2, isDone2, 1000); | |
thread2.start(); | |
Thread thread3 = waitingThread(job3, isDone3, 1000); | |
thread3.start(); | |
thread1.join(0); | |
thread2.join(0); | |
thread3.join(0); | |
assertThat(isDone1.get()).isTrue(); | |
assertThat(isError1.get()).isFalse(); | |
assertThat(isDone2.get()).isTrue(); | |
assertThat(isExecuted2.get()).isFalse(); | |
assertThat(isDone3.get()).isTrue(); | |
executor.shutdown(); | |
} | |
// utils | |
private static Thread waitingThread(final Object lockOn, final AtomicBoolean condition, final long maxWait) { | |
return new Thread(new Runnable() { | |
@Override | |
public void run() { | |
waitOn(lockOn, new Callable<Boolean>() { | |
@Override | |
public Boolean call() { | |
return condition.get(); | |
} | |
}, maxWait); | |
} | |
}); | |
} | |
private static void waitOn(Object lockOn, Callable<Boolean> condition, long maxWait) { | |
long currentTime = System.currentTimeMillis(); | |
long waitUntil = currentTime + maxWait; | |
try { | |
while(!condition.call() && waitUntil > currentTime) { | |
synchronized (lockOn) { | |
try { | |
lockOn.wait(5); | |
} catch (InterruptedException e) { | |
} | |
} | |
currentTime = System.currentTimeMillis(); | |
} | |
} catch (Exception e) { | |
e.printStackTrace(); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment