Skip to content

Instantly share code, notes, and snippets.

@jhump
Created March 2, 2017 01:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jhump/2dc67844dc57ba510d305c56aa640440 to your computer and use it in GitHub Desktop.
Save jhump/2dc67844dc57ba510d305c56aa640440 to your computer and use it in GitHub Desktop.
PriorityBlockingQueue with O(n) implementation of removeIf
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collection;
import java.util.Comparator;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.locks.Lock;
import java.util.function.Predicate;
/**
 * A {@link PriorityBlockingQueue} whose {@link #removeIf} runs in O(n) time.
 *
 * <p>Where the superclass does not override {@code removeIf}, it inherits
 * {@link Collection#removeIf}'s default, which removes each match through
 * {@code Iterator.remove}. That is a linear search per removed element, so O(n^2) overall.
 * This subclass instead compacts the backing heap array in place while holding the queue's
 * own lock, then re-heapifies it once.
 *
 * <p>The fast path needs reflective access to private members of
 * {@code PriorityBlockingQueue}: {@code queue}, {@code size}, {@code lock} and
 * {@code heapify()}. Access can be unavailable, for example on Java 16+ without
 * {@code --add-opens java.base/java.util.concurrent=ALL-UNNAMED}, or on a JDK that renamed
 * those members. In that case {@link #removeIf} delegates to the inherited implementation,
 * which recent JDKs already implement as a linear-time bulk removal.
 */
public class PriorityBlockingQueue2<E> extends PriorityBlockingQueue<E> {
    // Reflective handles to PriorityBlockingQueue internals. Either all four are non-null
    // (fast path available) or all four are null (removeIf delegates to the superclass).
    private static final Method heapifyMethod;
    private static final Field queueField;
    private static final Field lockField;
    private static final Field sizeField;

    static {
        Method foundHeapify = null;
        Field foundQueue = null;
        Field foundLock = null;
        Field foundSize = null;
        try {
            Method h = PriorityBlockingQueue.class.getDeclaredMethod("heapify");
            Field q = PriorityBlockingQueue.class.getDeclaredField("queue");
            Field l = PriorityBlockingQueue.class.getDeclaredField("lock");
            Field s = PriorityBlockingQueue.class.getDeclaredField("size");
            h.setAccessible(true);
            q.setAccessible(true);
            l.setAccessible(true);
            s.setAccessible(true);
            // Publish only once every handle is usable, so the fast path is all-or-nothing.
            foundHeapify = h;
            foundQueue = q;
            foundLock = l;
            foundSize = s;
        } catch (ReflectiveOperationException | RuntimeException ignored) {
            // ReflectiveOperationException: this JDK renamed or removed one of the members.
            // RuntimeException: setAccessible was refused (InaccessibleObjectException under
            // Java 9+ module encapsulation, or a SecurityException). In both cases the
            // handles stay null and removeIf uses the inherited implementation.
        }
        heapifyMethod = foundHeapify;
        queueField = foundQueue;
        lockField = foundLock;
        sizeField = foundSize;
    }

    // The superclass's own lock. Holding it makes removeIf atomic with respect to every
    // other queue operation. It is null when the reflective fast path is unavailable.
    // This initializer runs after super(...) in every constructor.
    private final Lock lock = lockField == null ? null : (Lock) getField(lockField);

    /** Creates an empty queue using the superclass's default capacity and natural ordering. */
    public PriorityBlockingQueue2() {
        super();
    }

    /** Creates an empty queue with the given initial capacity and natural ordering. */
    public PriorityBlockingQueue2(int initialCapacity) {
        super(initialCapacity);
    }

    /** Creates an empty queue with the given initial capacity, ordered by {@code comp}. */
    public PriorityBlockingQueue2(int initialCapacity, Comparator<? super E> comp) {
        super(initialCapacity, comp);
    }

    /** Creates a queue containing the elements of {@code coll}. */
    public PriorityBlockingQueue2(Collection<? extends E> coll) {
        super(coll);
    }

    /** Reads a reflected superclass field from this instance. */
    private Object getField(Field f) {
        try {
            return f.get(this);
        } catch (IllegalAccessException e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Removes every element that satisfies {@code filter}, in O(n) plus one heapify.
     *
     * <p>The predicate is evaluated for all elements before the queue is modified. If it
     * throws, the queue is left unchanged. The predicate may also read this queue
     * reentrantly, because the lock is reentrant and the array is untouched during
     * evaluation.
     *
     * @param filter predicate returning {@code true} for elements to remove
     * @return {@code true} if any element was removed
     * @throws NullPointerException if {@code filter} is null
     */
    @Override public boolean removeIf(Predicate<? super E> filter) {
        Objects.requireNonNull(filter, "filter");
        if (heapifyMethod == null) {
            return super.removeIf(filter);
        }
        lock.lock();
        try {
            final Object[] array = (Object[]) getField(queueField);
            final int sz = sizeField.getInt(this);

            // Pass 1: decide what to remove without mutating anything.
            final BitSet doomed = new BitSet(sz);
            for (int i = 0; i < sz; i++) {
                @SuppressWarnings("unchecked")
                E elem = (E) array[i];
                if (filter.test(elem)) {
                    doomed.set(i);
                }
            }
            if (doomed.isEmpty()) {
                return false;
            }

            // Pass 2: shift survivors down over the holes left by removed items.
            int j = 0;
            for (int i = 0; i < sz; i++) {
                if (!doomed.get(i)) {
                    array[j++] = array[i];
                }
            }
            // Don't leave dangling references in the tail of the backing array.
            Arrays.fill(array, j, sz, null);
            sizeField.setInt(this, j);
            // Compaction preserves the elements but not heap order, so restore it.
            heapifyMethod.invoke(this);
            return true;
        } catch (IllegalAccessException e) {
            // Access was granted in the static initializer, so this should be impossible.
            throw new AssertionError("reflective access to PriorityBlockingQueue failed", e);
        } catch (InvocationTargetException e) {
            // heapify itself threw (e.g. ClassCastException from compareTo or the comparator).
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new AssertionError(cause);
        } finally {
            lock.unlock();
        }
    }

    /**
     * Benchmarks this class against a plain {@link PriorityBlockingQueue} and cross-checks
     * the results. Runs until the process is killed.
     *
     * <p>The author's measurements on a MacBook: pruning ~50% of a 1,000,000-element queue
     * took about a minute with the inherited JRE {@code removeIf}, which was O(n^2) on the JDK
     * used. This class took under 30 ms, or up to ~50 ms before JIT warm-up.
     */
    public static void main(String[] args) {
        PriorityBlockingQueue<Long> pbq = new PriorityBlockingQueue<>();
        PriorityBlockingQueue2<Long> pbq2 = new PriorityBlockingQueue2<>();
        while (true) {
            for (int i = 0; i < 1_000_000; i++) {
                long l = ThreadLocalRandom.current().nextLong();
                pbq.add(l);
                pbq2.add(l);
            }
            long start = System.nanoTime();
            boolean b = pbq.removeIf(l -> l > 0);
            long end = System.nanoTime();
            System.out.printf("pbq: %fms%n", (end - start) / 1_000_000.0);
            start = System.nanoTime();
            boolean b2 = pbq2.removeIf(l -> l > 0);
            end = System.nanoTime();
            System.out.printf("pbq2: %fms%n", (end - start) / 1_000_000.0);
            if (b != b2) {
                throw new AssertionError("pbq2 returned " + b2 + " but should have returned " + b);
            }
            if (pbq.size() != pbq2.size()) {
                throw new AssertionError("pbq2 has " + pbq2.size() + " items but should have " + pbq.size());
            }
            while (!pbq.isEmpty()) {
                long l = pbq.remove();
                long l2 = pbq2.remove();
                if (l != l2) {
                    throw new AssertionError("pbq2 popped " + l2 + " but should have been " + l);
                }
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment