Created
March 2, 2017 01:49
-
-
Save jhump/2dc67844dc57ba510d305c56aa640440 to your computer and use it in GitHub Desktop.
PriorityBlockingQueue with O(n) implementation of removeIf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collection;
import java.util.Comparator;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.locks.Lock;
import java.util.function.Predicate;
/**
 * A {@link PriorityBlockingQueue} whose {@link #removeIf} runs in a single linear pass
 * plus one re-heapify, instead of removing matching elements one at a time through the
 * iterator.
 *
 * <p>The speed-up is obtained by reaching into the superclass's private state via
 * reflection: the {@code queue} array, the {@code size} field, the {@code lock} field
 * and the {@code heapify()} method. NOTE(review): these names match the JDK 8
 * implementation; other JDK versions may differ, and newer JDKs may already provide a
 * linear-time {@code removeIf} — confirm before relying on this class there. On Java 9+
 * the reflective access needs {@code --add-opens java.base/java.util.concurrent=ALL-UNNAMED};
 * without it, class initialization fails with an {@link AssertionError}.
 *
 * @param <E> the type of elements held in this queue
 */
public class PriorityBlockingQueue2<E> extends PriorityBlockingQueue<E> {
    // Handles to the superclass's private members, resolved once at class-init time.
    private static final Method heapifyMethod;
    private static final Field queueField;
    private static final Field lockField;
    private static final Field sizeField;

    static {
        try {
            heapifyMethod = PriorityBlockingQueue.class.getDeclaredMethod("heapify");
            heapifyMethod.setAccessible(true);
            queueField = PriorityBlockingQueue.class.getDeclaredField("queue");
            queueField.setAccessible(true);
            lockField = PriorityBlockingQueue.class.getDeclaredField("lock");
            lockField.setAccessible(true);
            sizeField = PriorityBlockingQueue.class.getDeclaredField("size");
            sizeField.setAccessible(true);
        } catch (ReflectiveOperationException | RuntimeException e) {
            // RuntimeException covers SecurityException and, on Java 9+, the
            // InaccessibleObjectException thrown when the package is not opened to us.
            throw new AssertionError(
                    "cannot access PriorityBlockingQueue internals; on Java 9+ run with"
                            + " --add-opens java.base/java.util.concurrent=ALL-UNNAMED",
                    e);
        }
    }

    /** The superclass's private {@code lock}, fetched once so {@link #removeIf} can hold it. */
    private final Lock lock;

    /** Creates an empty queue, as {@link PriorityBlockingQueue#PriorityBlockingQueue()}. */
    public PriorityBlockingQueue2() {
        super();
        lock = (Lock) getField(lockField);
    }

    /**
     * Creates an empty queue, as {@link PriorityBlockingQueue#PriorityBlockingQueue(int)}.
     *
     * @param initialCapacity the initial capacity passed to the superclass
     */
    public PriorityBlockingQueue2(int initialCapacity) {
        super(initialCapacity);
        lock = (Lock) getField(lockField);
    }

    /**
     * Creates an empty queue, as
     * {@link PriorityBlockingQueue#PriorityBlockingQueue(int, Comparator)}.
     *
     * @param initialCapacity the initial capacity passed to the superclass
     * @param comp the ordering comparator, or {@code null} for natural ordering
     */
    public PriorityBlockingQueue2(int initialCapacity, Comparator<? super E> comp) {
        super(initialCapacity, comp);
        lock = (Lock) getField(lockField);
    }

    /**
     * Creates a queue holding {@code coll}'s elements, as
     * {@link PriorityBlockingQueue#PriorityBlockingQueue(Collection)}.
     *
     * @param coll the elements to place in the queue initially
     */
    public PriorityBlockingQueue2(Collection<? extends E> coll) {
        super(coll);
        lock = (Lock) getField(lockField);
    }

    /** Reads reflective field {@code f} on this instance; access was granted at class init. */
    private Object getField(Field f) {
        try {
            return f.get(this);
        } catch (IllegalAccessException e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Removes every element matching {@code filter}, holding the superclass's lock for
     * the whole operation.
     *
     * <p>The predicate is evaluated for every element before anything is modified, so if
     * {@code filter} throws, the exception propagates and the queue is left unchanged.
     * Survivors are then compacted to the front of the backing array, the vacated tail is
     * cleared, and the heap is rebuilt once via the superclass's {@code heapify()}.
     *
     * @param filter predicate returning {@code true} for elements to remove
     * @return {@code true} if any element was removed
     * @throws NullPointerException if {@code filter} is null
     */
    @Override
    public boolean removeIf(Predicate<? super E> filter) {
        Objects.requireNonNull(filter, "filter");
        lock.lock();
        try {
            final Object[] array = (Object[]) getField(queueField);
            final int sz = size();

            // Phase 1: test only. No mutation yet, so a throwing filter cannot corrupt state.
            final BitSet doomed = new BitSet(sz);
            for (int i = 0; i < sz; i++) {
                @SuppressWarnings("unchecked")
                E elem = (E) array[i];
                if (filter.test(elem)) {
                    doomed.set(i);
                }
            }
            if (doomed.isEmpty()) {
                return false;
            }

            // Phase 2: shift survivors left over the removed slots. Everything before
            // the first removed index is already in place.
            int j = doomed.nextSetBit(0);
            for (int i = j + 1; i < sz; i++) {
                if (!doomed.get(i)) {
                    array[j++] = array[i];
                }
            }
            // Don't leave dangling references in the tail of the array.
            Arrays.fill(array, j, sz, null);
            // Publish the new size, then restore the heap invariant in one pass.
            sizeField.setInt(this, j);
            heapifyMethod.invoke(this);
            return true;
        } catch (IllegalAccessException e) {
            // Access was granted in the static initializer, so this should be unreachable.
            throw new AssertionError(e);
        } catch (InvocationTargetException e) {
            // heapify() itself threw (e.g. from the comparator): surface the real cause.
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new AssertionError(cause);
        } finally {
            lock.unlock();
        }
    }

    // On my Macbook, the JRE version, which has O(n^2) runtime, takes
    // about a minute to prune ~50% of the queue (which starts off at
    // 1 million elements). This "fixed" version takes less than 30
    // milliseconds (up to ~50 before it's warmed up). So orders of
    // magnitude faster (as expected).

    /**
     * Benchmarks {@code removeIf} against the stock JDK queue and cross-checks the results
     * (return value, size, and poll order), throwing {@link AssertionError} on any mismatch.
     *
     * @param args optional: number of rounds to run; runs forever when omitted
     */
    public static void main(String[] args) {
        final int rounds = args.length > 0 ? Integer.parseInt(args[0]) : -1;
        PriorityBlockingQueue<Long> pbq = new PriorityBlockingQueue<>();
        PriorityBlockingQueue2<Long> pbq2 = new PriorityBlockingQueue2<>();
        for (int round = 0; rounds < 0 || round < rounds; round++) {
            for (int i = 0; i < 1_000_000; i++) {
                long l = ThreadLocalRandom.current().nextLong();
                pbq.add(l);
                pbq2.add(l);
            }
            long start = System.nanoTime();
            boolean b = pbq.removeIf(l -> l > 0);
            long end = System.nanoTime();
            System.out.printf("pbq: %fms%n", (end - start) / 1_000_000.0);
            start = System.nanoTime();
            boolean b2 = pbq2.removeIf(l -> l > 0);
            end = System.nanoTime();
            System.out.printf("pbq2: %fms%n", (end - start) / 1_000_000.0);
            if (b != b2) {
                throw new AssertionError("pbq2 returned " + b2 + " but should have returned " + b);
            }
            if (pbq.size() != pbq2.size()) {
                throw new AssertionError("pbq2 has " + pbq2.size() + " items but should have " + pbq.size());
            }
            while (!pbq.isEmpty()) {
                long l = pbq.remove();
                long l2 = pbq2.remove();
                if (l != l2) {
                    throw new AssertionError("pbq2 popped " + l2 + " but should have been " + l);
                }
            }
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment