Skip to content

Instantly share code, notes, and snippets.

@grignaak
Created September 20, 2011 17:56
Show Gist options
  • Save grignaak/1229793 to your computer and use it in GitHub Desktop.
Save grignaak/1229793 to your computer and use it in GitHub Desktop.
Queues that share a tail, but have independent heads
package hydra;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
/**
 * A one-shot binary latch: threads calling {@link #await()} block until some thread calls
 * {@link #signal()}; once signalled, the latch stays signalled and every later await
 * returns immediately.
 *
 * <p>Adapted from the {@code BooleanLatch} example in the documentation of
 * {@link AbstractQueuedSynchronizer}.
 */
public class BooleanLatch {

    /** Shared-mode synchronizer: state 0 means "not yet signalled", non-zero means signalled. */
    @SuppressWarnings("serial")
    private static class Sync extends AbstractQueuedSynchronizer {

        boolean isSignalled() {
            final int state = getState();
            return state != 0;
        }

        /** Shared acquisition succeeds (positive result) only after the latch was signalled. */
        @Override
        protected int tryAcquireShared(int unused) {
            if (!isSignalled()) {
                return -1;
            }
            return 1;
        }

        /** Marks the latch signalled; always reports success so queued waiters are released. */
        @Override
        protected boolean tryReleaseShared(int unused) {
            setState(1);
            return true;
        }
    }

    private final Sync sync = new Sync();

    /**
     * @return {@code true} once {@link #signal()} has been called
     */
    public boolean isSignalled() {
        return sync.isSignalled();
    }

    /** Signals the latch, releasing all current waiters. Calling it again has no further effect. */
    public void signal() {
        sync.releaseShared(1);
    }

    /**
     * Blocks until the latch is signalled.
     *
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    /**
     * Blocks until the latch is signalled or the timeout elapses.
     *
     * @param duration maximum time to wait
     * @param unit unit of {@code duration}
     * @return {@code true} if the latch was signalled, {@code false} if the wait timed out
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public boolean await(long duration, TimeUnit unit) throws InterruptedException {
        final long nanos = unit.toNanos(duration);
        return sync.tryAcquireSharedNanos(1, nanos);
    }
}
package hydra;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
/**
 * Multiple {@link BlockingQueue}s that share a tail.
 *
 * <p>Adding to one queue adds to all the queues; but taking from one queue does not affect
 * the others.
 *
 * <p>The queue returned by {@link #tail()} is initially empty.
 *
 * <p>The queue returned by {@link #head()} lags behind the tail by at most
 * <code>bufferSize</code> elements at the time {@link #head()} was called. The size of the
 * queue may grow beyond <code>bufferSize</code> after it is returned from {@link #head()}.
 *
 * <p>The queues returned by {@link #head()} and {@link #tail()} are thread safe.
 *
 * <p>The queues do not accept null values; nor do they support any operation that may remove
 * values from another queue. Specifically, these operations always throw
 * {@link UnsupportedOperationException}:
 * {@link Collection#remove(Object)}, {@link Collection#removeAll(Collection)},
 * {@link Collection#retainAll(Collection)}, and {@link Iterator#remove()}
 * (as returned by {@link Collection#iterator()}).
 * <b>Note:</b> {@link Collection#clear()} <em>is</em> supported because it only
 * affects the one queue.
 *
 * <h3>Implementation Notes</h3>
 *
 * <p>The queues are merely a shared singly linked list with several <em>head</em> pointers.
 *
 * <p>To minimize shared locking, the algorithm minimizes the number of shared variables.
 * The shared variables are sharedTail and each node's next pointer. They are only written
 * in {@code offer()} while holding sharedLock, and they are volatile so that readers that
 * do not take the lock (marked VOLATILE READ) see the latest published values. Atomicity
 * of reference reads (JLS 17.7) alone does not guarantee visibility. Each queue's head
 * pointer is likewise volatile, because the blocking and size operations read it without
 * taking that queue's headLock.
 *
 * <p>Because nodes can be added by other queues, each queue cannot keep a running tally of
 * its size. Therefore size is calculated by the distance to the tail. This has the limitation
 * that the data structure is limited to having #add() called at most 2^64 times. It is assumed
 * you will never reach this. If this assumption is false, make changes to remainingCapacity()
 * and throw an exception in offer() when tailPosition is Long.MAX_VALUE.
 *
 * <p>size() saturates at Integer.MAX_VALUE, as required by {@link Collection#size()}.
 */
public class MultiHeadedQueue<T> {
    // Monotonically increasing position of the newest node, used to compute sizes.
    // Guarded by sharedLock.
    private long tailPosition = Long.MIN_VALUE;
    // The newest node (a value-less sentinel until the first offer). Written only under
    // sharedLock; volatile because tail(), clear() and size() read it without the lock.
    private volatile Node<T> sharedTail = new Node<T>(null, tailPosition);
    // The node that queues returned by head() start from. Guarded by sharedLock.
    private Node<T> laggedHead = sharedTail;
    private final int bufferSize;
    // lock to move the tail, the lagged head, and the position
    private final Lock sharedLock = new ReentrantLock();

    /** Create a queue where the head is in sync with the tail */
    public MultiHeadedQueue() { this(0); }

    /**
     * Create a queue where the head lags behind the tail by <code>bufferSize</code> elements.
     *
     * @param bufferSize how many of the most recently offered elements a queue returned by
     *        {@link #head()} starts with; negative values are treated as 0
     */
    public MultiHeadedQueue(int bufferSize) {
        this.bufferSize = Math.max(bufferSize, 0);
    }

    /**
     * Returns a new queue that starts with (at most) the <code>bufferSize</code> most recently
     * offered elements and then receives every element offered afterwards.
     */
    public BlockingQueue<T> head() {
        // If we don't synchronize, the buffer may be > bufferSize for a few operations.
        // see comments in Queue#offer()
        sharedLock.lock();
        try {
            return new Queue(laggedHead);
        } finally {
            sharedLock.unlock();
        }
    }

    /** Returns a new queue starting at the current tail; it receives every later offer. */
    public BlockingQueue<T> tail() {
        return new Queue(sharedTail); // VOLATILE READ
    }

    /** One link of the shared list. The value and position never change; next is set once. */
    private static class Node<T> {
        private final T value;
        private final long position;
        // Released exactly once, when next is published; wakes threads blocked in waitForNext.
        private final CountDownLatch nextNotifier = new CountDownLatch(1);
        private volatile Node<T> next; // WRITE ONCE, under sharedLock

        public Node(T value, long position) {
            this.value = value;
            this.position = position;
        }

        /** Publishes the successor node and wakes any waiters. Must be called at most once. */
        public void setNext(Node<T> next) {
            // The original asserted on the parameter, which fails on every call under -ea.
            assert this.next == null : "next may only be set once";
            this.next = next;
            nextNotifier.countDown();
        }

        public boolean hasNext() {
            return next != null;
        }

        public Node<T> next() {
            return next;
        }

        public T value() {
            return value;
        }

        /** Number of nodes between this node (exclusive) and {@code later} (inclusive). */
        public long distanceTo(Node<T> later) {
            return (later.position - this.position);
        }

        /** @return true if a successor exists or appeared before the timeout elapsed */
        public boolean waitForNext(long duration, TimeUnit unit) throws InterruptedException {
            return hasNext() || nextNotifier.await(duration, unit);
        }

        /** Blocks until a successor has been published. */
        public void waitForNext() throws InterruptedException {
            if (!hasNext())
                nextNotifier.await();
        }
    }

    /** One independently-consumed view over the shared list. */
    private class Queue extends AbstractQueue<T> implements BlockingQueue<T> {
        // The last consumed node; this queue's contents are head.next onward. Written only
        // under headLock; volatile because take(), poll(timeout), isEmpty(), size() and
        // iterator() read it without the lock.
        private volatile Node<T> head;
        private final Lock headLock = new ReentrantLock();

        public Queue(Node<T> head) {
            this.head = head;
        }

        /**
         * Appends {@code e} to the shared list, making it visible to every queue.
         * The only place that mutates shared state.
         *
         * @throws NullPointerException if {@code e} is null
         */
        @Override
        public boolean offer(T e) {
            if (e == null)
                throw new NullPointerException();
            sharedLock.lock();
            try {
                /* If shared tail is set _after_ oldTail.setNext(), then size() will return
                 * a number less than the actual size for a few operations (size() clamps a
                 * transiently negative value to 0). But on the other hand, if sharedTail is
                 * set _before_ oldTail.setNext(), then size() will return a number greater
                 * than the actual size for a few operations. The first option is the safest
                 * and is consistent with other read operations.
                 */
                Node<T> oldTail = sharedTail;
                Node<T> newTail = new Node<T>(e, ++tailPosition);
                oldTail.setNext(newTail);
                sharedTail = newTail;
                /* Because this is done _after_ updating the tail, access to the head must
                 * be synchronized else head will lag behind by more than bufferSize for
                 * a few operations. Putting this _before_ updating the tail will make the
                 * head null when bufferSize is 0.
                 */
                if (laggedHead.distanceTo(newTail) > bufferSize)
                    laggedHead = laggedHead.next();
                return true;
            } finally {
                sharedLock.unlock();
            }
        }

        // there is no wait to add to the queue, hence no InterruptedException
        @Override
        public boolean offer(T e, long timeout, TimeUnit unit) {
            return offer(e);
        }

        // there is no wait to add to the queue, hence no InterruptedException
        @Override
        public void put(T e) {
            offer(e);
        }

        @Override
        public T peek() {
            headLock.lock();
            try {
                Node<T> first = head.next();
                return first == null ? null : first.value();
            } finally {
                headLock.unlock();
            }
        }

        @Override
        public T poll() {
            headLock.lock();
            try {
                Node<T> first = head.next();
                if (first == null)
                    return null;
                head = first;
                return first.value();
            } finally {
                headLock.unlock();
            }
        }

        /**
         * Waits up to the given time for an element. Re-waits for the remaining time if another
         * consumer of this same queue takes the element first, instead of failing.
         *
         * @return the head element, or null if the timeout elapsed
         */
        @Override
        public T poll(long duration, TimeUnit unit) throws InterruptedException {
            long remaining = unit.toNanos(duration);
            final long deadline = System.nanoTime() + remaining;
            for (;;) {
                T value = poll();
                if (value != null)
                    return value;
                if (remaining <= 0)
                    return null;
                head.waitForNext(remaining, TimeUnit.NANOSECONDS); // VOLATILE READ of head
                remaining = deadline - System.nanoTime();
            }
        }

        /**
         * Blocks until an element is available and takes it. Loops because another consumer
         * of this same queue may take the element between the wake-up and the poll.
         */
        @Override
        public T take() throws InterruptedException {
            for (;;) {
                T value = poll();
                if (value != null)
                    return value;
                head.waitForNext(); // VOLATILE READ of head
            }
        }

        @Override
        public int drainTo(Collection<? super T> c) {
            return drainTo(c, Integer.MAX_VALUE);
        }

        /**
         * Moves up to {@code maxElements} elements into {@code c}, in order.
         *
         * @throws NullPointerException if {@code c} is null
         * @throws IllegalArgumentException if {@code c} is this queue
         */
        @Override
        public int drainTo(Collection<? super T> c, int maxElements) {
            if (c == null) throw new NullPointerException();
            if (c == this) throw new IllegalArgumentException();
            headLock.lock();
            try {
                int count = 0;
                while (count < maxElements && head.hasNext()) {
                    head = head.next();
                    c.add(head.value());
                    count++;
                }
                return count;
            } finally {
                headLock.unlock();
            }
        }

        /** Discards this queue's contents only; other queues are unaffected. */
        @Override
        public void clear() {
            headLock.lock();
            try {
                head = sharedTail; // VOLATILE READ
            } finally {
                headLock.unlock();
            }
        }

        /**
         * Iterates from the current head; elements offered later are also returned.
         * {@link Iterator#remove()} is unsupported.
         */
        @Override
        public Iterator<T> iterator() {
            return new Iterator<T>() {
                private Node<T> current = head; // VOLATILE READ

                @Override
                public boolean hasNext() {
                    return current.hasNext();
                }

                @Override
                public T next() {
                    Node<T> following = current.next();
                    if (following == null)
                        throw new NoSuchElementException();
                    current = following;
                    return current.value();
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException();
                }
            };
        }

        @Override
        public boolean isEmpty() {
            return !head.hasNext();
        }

        @Override
        public int size() {
            long distance = head.distanceTo(sharedTail); // VOLATILE READS
            // A consumer may pass a node before offer() republishes sharedTail, so the
            // distance can briefly be negative; Collection#size also requires saturation.
            if (distance <= 0)
                return 0;
            return distance >= Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) distance;
        }

        @Override
        public int remainingCapacity() {
            return Integer.MAX_VALUE;
        }

        /** Unsupported: removing an arbitrary element would affect the other queues. */
        @Override
        public boolean remove(Object o) {
            throw new UnsupportedOperationException();
        }

        /** Unsupported: removing arbitrary elements would affect the other queues. */
        @Override
        public boolean removeAll(Collection<?> c) {
            throw new UnsupportedOperationException();
        }

        /** Unsupported: removing arbitrary elements would affect the other queues. */
        @Override
        public boolean retainAll(Collection<?> c) {
            throw new UnsupportedOperationException();
        }
    }
}
package hydra;
import static org.junit.Assert.*;
import java.util.*;
import java.util.concurrent.*;
import org.junit.Test;
import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
/**
 * Behavioral tests for {@link MultiHeadedQueue}, grouped by scenario into nested fixtures
 * that JUnit's {@link Enclosed} runner executes independently.
 */
@RunWith(Enclosed.class)
public class MultiHeadedQueueTest {
    private static final int BUFFER_SIZE = 10;
    // Allowed deviation, in milliseconds, for timing assertions.
    private static final int TIME_TOLERANCE_MILLIS = 50;

    /** Operations that would affect other queues must be rejected; nulls are rejected too. */
    public static class UnsupportedOperations {
        private MultiHeadedQueue<Integer> buffer = new MultiHeadedQueue<Integer>(BUFFER_SIZE);
        private BlockingQueue<Integer> queue = buffer.tail();

        @Test(expected=UnsupportedOperationException.class) public void
        removeSpecifiedItem() {
            add1through4(queue);
            queue.remove(1);
        }

        @Test(expected=UnsupportedOperationException.class) public void
        removeAll() {
            add1through4(queue);
            queue.removeAll(Arrays.asList(1));
        }

        @Test(expected=UnsupportedOperationException.class) public void
        retainAll() {
            add1through4(queue);
            queue.retainAll(Arrays.asList(5));
        }

        @Test(expected=UnsupportedOperationException.class) public void
        removeFromIterator() {
            queue.add(1);
            Iterator<Integer> it = queue.iterator();
            it.next();
            it.remove();
        }

        @Test(expected=NullPointerException.class) public void
        addNulls() {
            queue.add(null);
        }
    }

    /** A freshly created tail queue with nothing offered. */
    public static class Emptiness {
        private MultiHeadedQueue<Integer> buffer = new MultiHeadedQueue<Integer>();
        private BlockingQueue<Integer> queue = buffer.tail();

        @Test public void
        shouldHaveZeroSize() {
            assertEquals(0, queue.size());
            assertTrue(queue.isEmpty());
        }

        @Test public void
        shouldHaveNoValues() throws InterruptedException {
            assertNull(queue.peek());
            assertNull(queue.poll());
        }

        @Test(timeout=1000) public void
        shouldTimeoutWhenPolling() throws InterruptedException {
            long start = System.nanoTime();
            Integer value = queue.poll(100, TimeUnit.MILLISECONDS);
            long nanoSeconds = System.nanoTime() - start;
            assertNull(value);
            assertTime(100, TimeUnit.NANOSECONDS.toMillis(nanoSeconds));
        }

        @Test(expected=NoSuchElementException.class) public void
        shouldThrowOnRemove() {
            queue.remove();
        }

        @Test(expected=NoSuchElementException.class) public void
        shouldThrowOnElement() {
            queue.element();
        }

        @Test public void
        shouldDrainNothing() {
            List<Integer> list = new ArrayList<Integer>();
            int count = queue.drainTo(list);
            assertEquals(0, count);
            assertTrue(list.isEmpty());
        }

        @Test public void
        headShouldBeEmpty() {
            assertContents(buffer.head());
        }

        @Test public void
        iteratorShouldBeEmpty() {
            assertFalse(queue.iterator().hasNext());
        }
    }

    /** A single queue exercised through the full BlockingQueue API. */
    public static class SingleNonEmptyQueue {
        private MultiHeadedQueue<Integer> buffer = new MultiHeadedQueue<Integer>(BUFFER_SIZE);
        private BlockingQueue<Integer> queue = buffer.tail();

        @Test public void
        shouldAddAnItem() {
            queue.add(1);
            assertContents(queue, 1);
        }

        @Test public void
        shouldAddMultipleItem() throws InterruptedException {
            queue.add(1);
            queue.offer(2);
            queue.offer(3, 1, TimeUnit.NANOSECONDS);
            queue.put(4);
            assertContents(queue, 1, 2, 3, 4);
        }

        @Test public void
        sizeShouldBeUpdated() {
            queue.add(1);
            queue.add(2);
            assertEquals(2, queue.size());
        }

        @Test public void
        shouldPeek() {
            queue.add(1);
            assertEquals(Integer.valueOf(1), queue.peek());
            assertContents(queue, 1);
            assertEquals(Integer.valueOf(1), queue.element());
            assertContents(queue, 1);
        }

        @Test public void
        shouldRemoveAnItem() {
            add1through4(queue);
            Integer i = queue.remove();
            Integer j = queue.poll();
            assertEquals(Integer.valueOf(1), i);
            assertEquals(Integer.valueOf(2), j);
            assertContents(queue, 3, 4);
        }

        @Test(timeout=1000) public void
        shouldTakeAnItemWithoutWaiting() throws InterruptedException {
            add1through4(queue);
            long start = System.nanoTime();
            Integer i = queue.take();
            Integer j = queue.poll(10, TimeUnit.SECONDS);
            long nanoSeconds = System.nanoTime() - start;
            assertEquals(Integer.valueOf(1), i);
            assertEquals(Integer.valueOf(2), j);
            assertContents(queue, 3, 4);
            assertTime(0, TimeUnit.NANOSECONDS.toMillis(nanoSeconds));
        }

        @Test public void
        shouldDrainToCollection() {
            add1through4(queue);
            List<Number> list = new ArrayList<Number>();
            int actualCount = queue.drainTo(list);
            assertEquals(Arrays.asList(1, 2, 3, 4), list);
            assertEquals("returned wrong size.", list.size(), actualCount);
            assertContents(queue);
        }

        @Test public void
        shouldDrainToCollectionWithMaxSize() {
            add1through4(queue);
            List<Number> list = new ArrayList<Number>();
            int actualCount = queue.drainTo(list, 3);
            assertEquals(Arrays.asList(1, 2, 3), list);
            assertEquals("returned wrong size.", list.size(), actualCount);
            assertContents(queue, 4);
        }

        @Test public void
        shouldClearTheCollection() {
            add1through4(queue);
            queue.clear();
            assertContents(queue);
        }

        @Test
        public void theHeadLagsBehindTheTail() {
            add1through4(queue);
            queue.clear();
            BlockingQueue<Integer> head = buffer.head();
            assertContents(head, 1, 2, 3, 4);
        }

        @Test
        public void theHeadOnlyLagsBehindByTheBufferSize() {
            while (queue.size() <= BUFFER_SIZE)
                add1through4(queue);
            BlockingQueue<Integer> head = buffer.head();
            assertEquals(BUFFER_SIZE, head.size());
        }

        @Test public void
        theHeadIsTheTailWhenBufferSizeIsZero() {
            buffer = new MultiHeadedQueue<Integer>(0);
            queue = buffer.tail();
            add1through4(queue);
            assertContents(buffer.head());
        }

        @Test(timeout=1000) public void
        willWaitToTake() throws InterruptedException, BrokenBarrierException {
            final CyclicBarrier readyToGo = new CyclicBarrier(2);
            addAfter200Milliseconds(readyToGo);
            readyToGo.await();
            long start = System.nanoTime();
            int value = queue.take();
            long nanoSeconds = System.nanoTime() - start;
            assertEquals(1, value);
            assertTime(200, TimeUnit.NANOSECONDS.toMillis(nanoSeconds));
        }

        @Test(timeout=1000) public void
        willWaitToPoll() throws InterruptedException, BrokenBarrierException {
            final CyclicBarrier readyToGo = new CyclicBarrier(2);
            addAfter200Milliseconds(readyToGo);
            readyToGo.await();
            long start = System.nanoTime();
            int value = queue.poll(10, TimeUnit.SECONDS);
            long nanoSeconds = System.nanoTime() - start;
            assertEquals(1, value);
            assertTime(200, TimeUnit.NANOSECONDS.toMillis(nanoSeconds));
        }

        /** Starts a thread that, once both parties reach the barrier, sleeps 200ms then adds 1. */
        private void addAfter200Milliseconds(final CyclicBarrier readyToGo) {
            invoke(new Callable<Void>() {
                public Void call() throws Exception {
                    readyToGo.await();
                    Thread.sleep(200);
                    queue.add(1);
                    return null;
                }
            });
        }
    }

    /** Two tail queues over the same buffer: shared additions, independent removals. */
    public static class MultipleQueues {
        private MultiHeadedQueue<Integer> buffer = new MultiHeadedQueue<Integer>();
        private BlockingQueue<Integer> queue1 = buffer.tail();
        private BlockingQueue<Integer> queue2 = buffer.tail();

        @Test public void
        addingToOneAddsToBoth() {
            add1through4(queue1);
            assertContents(queue1, 1, 2, 3, 4);
            assertContents(queue2, 1, 2, 3, 4);
        }

        @Test public void
        takingFromOneDoesntAffectTheOther() {
            add1through4(queue1);
            queue1.remove();
            assertContents(queue1, 2, 3, 4);
            assertContents(queue2, 1, 2, 3, 4);
        }
    }

    /** Asserts that the queue holds exactly {@code xs}, in order, and that size/isEmpty agree. */
    private static <T> void assertContents(BlockingQueue<? extends T> queue, T... xs) {
        List<T> list = new ArrayList<T>(queue);
        assertEquals(Arrays.asList(xs), list);
        assertEquals(xs.length, queue.size());
        assertEquals(xs.length == 0, queue.isEmpty());
    }

    /**
     * Runs {@code callable} on a new daemon thread, so a stuck helper cannot keep the JVM
     * alive. Exceptions are rethrown unchecked on that thread.
     */
    public static void invoke(final Callable<Void> callable) {
        Thread thread = new Thread(new Runnable() {
            public void run() {
                try {
                    callable.call();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        });
        // The original called isDaemon(), a getter whose result was discarded.
        thread.setDaemon(true);
        thread.start();
    }

    /** Adds 1, 2, 3 and 4 to the queue, in that order. */
    public static void add1through4(BlockingQueue<Integer> queue) {
        queue.addAll(Arrays.asList(1, 2, 3, 4));
    }

    /** Asserts that {@code millis} is within TIME_TOLERANCE_MILLIS of {@code expected}. */
    private static void assertTime(int expected, long millis) {
        assertEquals(expected, millis, TIME_TOLERANCE_MILLIS);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment