Created
July 7, 2022 15:51
-
-
Save dhet/f8a66d5b1000f667312a59c0862caf2c to your computer and use it in GitHub Desktop.
A debouncer written in Java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import lombok.RequiredArgsConstructor;
import lombok.val;

import java.time.Duration;
import java.time.Instant;
import java.util.Objects;
import java.util.concurrent.*;

import static java.util.concurrent.TimeUnit.MILLISECONDS;
/**
 * A utility for debouncing function calls.
 *
 * <p>Tasks are grouped by an identifier, compared with {@link Object#equals(Object)}. The first submission
 * for an identifier opens a window that closes {@code debounceDuration} later. Further submissions with an
 * equal identifier inside that window replace the pending task; they do not extend the window. When the
 * window closes, only the most recently submitted task is executed. Every caller that submitted during the
 * window receives the same future, completed with that task's outcome.
 *
 * <p>Tasks run on the {@link ScheduledExecutorService} passed to the constructor; this class never shuts
 * that executor down.
 */
public class Debouncer<T> {
    private final ScheduledExecutorService executorService;
    /** Open windows keyed by identifier; an entry is removed right before its task is executed. */
    private final ConcurrentHashMap<Object, DebouncedTask<T>> buffer = new ConcurrentHashMap<>();
    private final Duration debounceDuration;

    /**
     * Creates a debouncer.
     *
     * @param executorService executor on which debounced tasks are scheduled and run
     * @param debounceDuration length of the window that starts with the first submission for an identifier
     * @throws NullPointerException if either argument is {@code null}
     */
    public Debouncer(ScheduledExecutorService executorService, Duration debounceDuration) {
        this.executorService = Objects.requireNonNull(executorService, "executorService");
        this.debounceDuration = Objects.requireNonNull(debounceDuration, "debounceDuration");
    }

    /**
     * Submit a task to the Debouncer. A task is identified by the {@code identifier} parameter's
     * {@link Object#equals(Object)} method. If a second task with an equal identifier is submitted before
     * {@code debounceDuration} is over, the previous task will be canceled and replaced by the new one.
     * Effectively, the last callback "wins". As a consequence of this, it cannot be guaranteed that a
     * submitted task will be executed.
     *
     * @param identifier the identifier of the task by which duplication is detected. Identity is defined by
     *                   the object's {@code equals(Object)} method.
     * @param task       the task to debounce
     * @return a future shared by all submissions in the same window; it is completed with the outcome of the
     *         last task submitted in that window
     * @throws NullPointerException if {@code identifier} or {@code task} is {@code null}
     */
    public Future<?> submit(Object identifier, Runnable task) {
        Objects.requireNonNull(task, "task");
        return submit(identifier, () -> {
            task.run();
            return null;
        });
    }

    /**
     * Submit a value-producing task to the Debouncer; see {@link Debouncer#submit(Object, Runnable)} for
     * the debouncing rules.
     *
     * @param identifier the identifier of the task by which duplication is detected
     * @param task       the task to debounce
     * @return a future shared by all submissions in the same window; it is completed with the result, or
     *         the failure, of the last task submitted in that window
     * @throws NullPointerException if {@code identifier} or {@code task} is {@code null}
     * @throws RejectedExecutionException if the executor refuses the task; any previously pending task for
     *         the identifier stays scheduled in that case
     */
    public Future<T> submit(Object identifier, Callable<T> task) {
        Objects.requireNonNull(identifier, "identifier");
        Objects.requireNonNull(task, "task");
        return buffer.compute(identifier, (key, pending) -> {
            final DebouncedTask<T> window = pending != null
                    ? pending
                    : new DebouncedTask<T>(Instant.now().plus(debounceDuration));
            // Schedule before touching anything else. If the executor rejects the task, the exception leaves
            // compute() and the mapping, including the previously scheduled run, stays as it was.
            Future<?> next = executorService.schedule(() -> runAndCleanUp(key, window),
                    millisUntil(window.deadline), MILLISECONDS);
            if (pending != null) {
                // Best effort only. A run that has already started is harmless, because runAndCleanUp
                // lets exactly one run per window through, and that run executes latestTask.
                pending.scheduledFuture.cancel(false);
            }
            window.latestTask = task;
            window.scheduledFuture = next;
            return window;
        }).finalFuture;
    }

    /**
     * Executes the latest task of {@code window} if this is the first scheduled run to claim it.
     *
     * <p>{@code remove(key, window)} succeeds for exactly one caller, because a removed window is never
     * re-inserted: a new submission after removal creates a fresh window. Any other run that fires for the
     * same window, for example a superseded run whose cancellation came too late, is a no-op.
     */
    private void runAndCleanUp(Object key, DebouncedTask<T> window) {
        if (!buffer.remove(key, window)) {
            return;
        }
        try {
            window.finalFuture.complete(window.latestTask.call());
        } catch (Throwable failure) {
            // Catch Throwable, not just Exception, so an Error never leaves callers blocked on the future.
            if (failure instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            window.finalFuture.completeExceptionally(failure);
        }
    }

    /** Milliseconds from now until {@code target}, clamped to zero for instants in the past. */
    private long millisUntil(Instant target) {
        return Math.max(0, Duration.between(Instant.now(), target).toMillis());
    }

    /** State of one open debounce window. */
    private static final class DebouncedTask<T> {
        /** Shared by every submission in this window; completed with the outcome of the last task. */
        private final CompletableFuture<T> finalFuture = new CompletableFuture<>();
        /** Moment the window closes; fixed by the first submission. */
        private final Instant deadline;
        /** Most recently submitted task; the only one of this window that gets executed. */
        private volatile Callable<T> latestTask;
        /** Handle of the most recent scheduling, kept so that a newer submission can cancel it. */
        private volatile Future<?> scheduledFuture;

        private DebouncedTask(Instant deadline) {
            this.deadline = deadline;
        }
    }

    // Package private getter for testing
    ConcurrentHashMap<Object, DebouncedTask<T>> getBuffer() {
        return buffer;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import lombok.SneakyThrows; | |
import lombok.val; | |
import org.junit.Test; | |
import org.junit.jupiter.api.function.Executable; | |
import java.time.Duration; | |
import java.util.concurrent.ExecutionException; | |
import java.util.concurrent.Executors; | |
import java.util.concurrent.Future; | |
import static org.junit.Assert.assertEquals; | |
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; | |
import static org.junit.jupiter.api.Assertions.assertThrows; | |
import static org.mockito.Mockito.*; | |
public class DebouncerTest { | |
@Test | |
public void debouncerShouldPartitionTasksByIdentifier() { | |
// GIVEN | |
val debounceDuration = Duration.ofMillis(10); | |
val debouncer = new Debouncer<String>(Executors.newSingleThreadScheduledExecutor(), debounceDuration); | |
val mockRunnable1 = mock(Runnable.class); | |
val mockRunnable2 = mock(Runnable.class); | |
// WHEN | |
debouncer.submit("1", mockRunnable1); | |
debouncer.submit("1", mockRunnable1); | |
debouncer.submit("2", mockRunnable2); | |
debouncer.submit("1", mockRunnable1); | |
debouncer.submit("2", mockRunnable2); | |
// THEN | |
verify(mockRunnable1, after(100).times(1)).run(); | |
verify(mockRunnable2, after(100).times(1)).run(); | |
} | |
@Test | |
public void debouncerShouldOverrideExistingTasks() { | |
// GIVEN | |
val debounceDuration = Duration.ofMillis(10); | |
val debouncer = new Debouncer<String>(Executors.newSingleThreadScheduledExecutor(), debounceDuration); | |
val mockRunnable1 = mock(Runnable.class); | |
val mockRunnable2 = mock(Runnable.class); | |
// WHEN | |
debouncer.submit("1", mockRunnable1); | |
debouncer.submit("1", mockRunnable2); | |
// THEN | |
assertEquals(1, debouncer.getBuffer().size()); | |
verify(mockRunnable1, after(100).never()).run(); | |
verify(mockRunnable2, after(100).times(1)).run(); | |
} | |
@Test | |
public void debouncerShouldNotDelayTaskLongerThanDebouncingDuration() throws InterruptedException { | |
// GIVEN | |
val debounceDuration = Duration.ofMillis(100); | |
val debouncer = new Debouncer<String>(Executors.newSingleThreadScheduledExecutor(), debounceDuration); | |
val mockRunnable = mock(Runnable.class); | |
// WHEN | |
// Simulating 3 evenly spread calls over a duration of 120 ms with debounce duration of 100 ms | |
// should yield 2 executions. | |
debouncer.submit("1", mockRunnable); // 0 ms | |
Thread.sleep(60); | |
debouncer.submit("1", mockRunnable); // 60 ms | |
// First execution after 100 ms | |
Thread.sleep(60); | |
debouncer.submit("1", mockRunnable); // 120 ms | |
// Second execution after 200 ms | |
// THEN | |
verify(mockRunnable, after(400).times(2)).run(); | |
} | |
@Test | |
public void debouncerShouldCleanUpInternalStateWhenTasksHaveFinished() { | |
// GIVEN | |
val debounceDuration = Duration.ofMillis(10); | |
val debouncer = new Debouncer<String>(Executors.newSingleThreadScheduledExecutor(), debounceDuration); | |
val mockRunnable = mock(Runnable.class); | |
// WHEN | |
debouncer.submit("1", mockRunnable); | |
// THEN | |
assertEquals(1, debouncer.getBuffer().size()); | |
verify(mockRunnable, after(100).times(1)).run(); | |
assertEquals(0, debouncer.getBuffer().size()); | |
} | |
@Test | |
public void debouncerShouldCleanUpInternalStateWhenTaskThrows() { | |
// GIVEN | |
val debounceDuration = Duration.ofMillis(10); | |
val debouncer = new Debouncer<String>(Executors.newSingleThreadScheduledExecutor(), debounceDuration); | |
Runnable exceptionalRunnable = () -> { | |
throw new RuntimeException("MOCKED"); | |
}; | |
// WHEN | |
Future<?> result = debouncer.submit("1", exceptionalRunnable); | |
// THEN | |
assertEquals(1, debouncer.getBuffer().size()); | |
val executionException = assertThrows(ExecutionException.class, result::get); | |
assertEquals("MOCKED", executionException.getCause().getMessage()); | |
assertEquals(0, debouncer.getBuffer().size()); | |
} | |
@Test | |
@SneakyThrows | |
public void debouncerShouldCompleteWithTheLastResult() { | |
// GIVEN | |
val debounceDuration = Duration.ofMillis(10); | |
val debouncer = new Debouncer<String>(Executors.newSingleThreadScheduledExecutor(), debounceDuration); | |
// WHEN | |
val result1 = debouncer.submit("", () -> "1"); | |
val result2 = debouncer.submit("", () -> "2"); | |
val result3 = debouncer.submit("", () -> "3"); | |
// THEN | |
assertEquals("3", result1.get()); | |
assertEquals("3", result2.get()); | |
assertEquals("3", result3.get()); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment