Skip to content

Instantly share code, notes, and snippets.

@shenfeng
Created February 26, 2022 10:38
Show Gist options
  • Save shenfeng/b08b69b705ac0dca44bcabb09b8460c7 to your computer and use it in GitHub Desktop.
Save shenfeng/b08b69b705ac0dca44bcabb09b8460c7 to your computer and use it in GitHub Desktop.
Lock-free concurrent queue in Java
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
/**
 * An unbounded FIFO queue that coordinates concurrent producers and consumers with
 * compare-and-set operations instead of locks.
 *
 * <p>The queue is a singly linked list that always starts with a sentinel node: {@code head}
 * points at the sentinel and the first real element, if any, is {@code head.next}. {@code tail}
 * points at the last node or, transiently, at the node just before it while an {@link #offer}
 * is in flight.
 *
 * <p>{@code null} elements are rejected because {@link #poll()} uses {@code null} to report an
 * empty queue.
 *
 * @param <E> the element type
 */
public class LockFreeQueue<E> {
    private static final VarHandle NEXT;
    private static final VarHandle HEAD;
    private static final VarHandle TAIL;

    static {
        try {
            MethodHandles.Lookup l = MethodHandles.lookup();
            NEXT = l.findVarHandle(Node.class, "next", Node.class);
            HEAD = l.findVarHandle(LockFreeQueue.class, "head", Node.class);
            TAIL = l.findVarHandle(LockFreeQueue.class, "tail", Node.class);
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    /** A list node; the node referenced by {@code head} acts as a sentinel whose item is unused. */
    static final class Node<E> {
        // Written before the node is published through a volatile CAS on `next`;
        // cleared by the consumer that dequeues it.
        E item;
        volatile Node<E> next;

        Node() {
        }

        Node(E item) {
            this.item = item;
        }

        /** Atomically sets {@code next} to {@code n} if it currently equals {@code o}. */
        boolean casNext(Node<E> o, Node<E> n) {
            return NEXT.compareAndSet(this, o, n);
        }
    }

    private volatile Node<E> head;
    private volatile Node<E> tail;

    /** Creates an empty queue consisting of a single sentinel node. */
    public LockFreeQueue() {
        head = tail = new Node<>();
    }

    /**
     * Appends {@code e} to the end of the queue.
     *
     * @param e the element to add; must not be {@code null}
     * @throws NullPointerException if {@code e} is {@code null}
     */
    public void offer(E e) {
        if (e == null) {
            // poll() returns null for "empty", so a null element would be indistinguishable.
            throw new NullPointerException("null elements are not permitted");
        }
        Node<E> node = new Node<>(e);
        while (true) {
            Node<E> t = tail;
            Node<E> n = t.next;
            if (n == null) {
                // t is the last node: try to link the new node behind it.
                if (t.casNext(null, node)) {
                    // Best effort: if this CAS fails, another thread already moved tail forward.
                    TAIL.compareAndSet(this, t, node);
                    return;
                }
            } else {
                // tail is lagging behind the last node; help advance it, then retry.
                TAIL.compareAndSet(this, t, n);
            }
        }
    }

    /**
     * Removes and returns the element at the front of the queue.
     *
     * @return the oldest element, or {@code null} if {@code head} and {@code tail} refer to the
     *     same node (the queue is empty, or a concurrent {@link #offer} has not yet advanced
     *     {@code tail})
     */
    public E poll() {
        while (true) {
            Node<E> h = head;
            Node<E> t = tail; // must be read after head so that h != t implies h.next != null
            if (h == t) {
                return null;
            }
            Node<E> n = h.next;
            E e = n.item;
            if (HEAD.compareAndSet(this, h, n)) {
                // n is the new sentinel. Drop its reference so the dequeued element does not
                // stay reachable from the queue. Only threads still holding the old head read
                // this field, and they will lose the CAS above and discard what they read.
                n.item = null;
                return e;
            }
        }
    }
}
import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicIntegerArray;
/**
 * JUnit 4 tests for {@link LockFreeQueue}: a single-threaded FIFO check, a
 * one-producer/one-consumer ordering check, and a multi-producer/multi-consumer check that
 * every offered value is consumed exactly the expected number of times.
 */
public class LockFreeQueueTest {
    /** Number of elements each producer offers (and each consumer takes) in the bulk tests. */
    static final int N = 200000;

    /** Interleaves offer/poll on one thread and checks the queue is empty after each poll. */
    @Test
    public void testSingleThread() {
        LockFreeQueue<Integer> q = new LockFreeQueue<>();
        for (int i = 0; i < N; i++) {
            q.offer(i);
            Assert.assertEquals(i, q.poll().intValue());
            Assert.assertNull(q.poll());
        }
    }

    /** Sanity check of the JDK reference queue: elements come back in insertion order. */
    @Test
    public void TestJDK() {
        ConcurrentLinkedQueue<Integer> q = new ConcurrentLinkedQueue<>();
        for (int i = 1; i <= 4; i++) {
            q.offer(i);
        }
        for (int i = 1; i <= 4; i++) {
            Assert.assertEquals(i, q.poll().intValue());
        }
        Assert.assertNull(q.poll());
    }

    /**
     * Runs ten rounds of one producer and one consumer sharing a queue; the consumer must see
     * the values 0..count-1 in order.
     */
    @Test
    public void testTwoThread() throws InterruptedException {
        // Substitute java.util.concurrent.ConcurrentLinkedQueue here to compare with the JDK.
        LockFreeQueue<Integer> q = new LockFreeQueue<>();
        final int count = 1000000;
        for (int w = 0; w < 10; w++) {
            // An AssertionError thrown on a worker thread does not fail the test by itself,
            // so the consumer records its failure here. Thread.join() makes the write visible.
            final Throwable[] failure = new Throwable[1];
            Thread consumer = new Thread(() -> {
                try {
                    for (int i = 0; i < count; i++) {
                        Integer v = q.poll();
                        while (v == null) { // spin until the producer catches up
                            v = q.poll();
                        }
                        Assert.assertEquals(i, v.intValue());
                    }
                } catch (Throwable e) {
                    failure[0] = e;
                }
            });
            Thread producer = new Thread(() -> {
                for (int i = 0; i < count; i++) {
                    q.offer(i);
                }
            });
            consumer.start();
            producer.start();
            consumer.join();
            producer.join();
            if (failure[0] != null) {
                throw new AssertionError("consumer failed in round " + w, failure[0]);
            }
        }
    }

    /** Runs the multi-producer/multi-consumer check with 1 to 8 producer/consumer pairs. */
    @Test
    public void testNThread() throws InterruptedException {
        AtomicIntegerArray array = new AtomicIntegerArray(N);
        for (int i = 0; i < 8; i++) {
            testMPMC(i + 1, array);
        }
    }

    /**
     * Starts {@code n} producers, each offering 0..N-1, and {@code n} consumers, each taking N
     * elements, then checks that every value was consumed exactly {@code n} times.
     *
     * @param n number of producer/consumer pairs
     * @param array per-value consumption counters of length N; reset at the start of the call
     */
    private void testMPMC(int n, AtomicIntegerArray array) throws InterruptedException {
        for (int i = 0; i < N; i++) {
            array.set(i, 0);
        }
        List<Thread> threads = new ArrayList<>();
        // Substitute java.util.concurrent.ConcurrentLinkedQueue here to compare with the JDK.
        LockFreeQueue<Integer> q = new LockFreeQueue<>();
        for (int i = 0; i < n; i++) {
            Thread consumer = new Thread(() -> {
                for (int j = 0; j < N; j++) {
                    Integer v = q.poll();
                    while (v == null) { // spin until an element is available
                        v = q.poll();
                    }
                    array.incrementAndGet(v);
                }
            });
            threads.add(consumer);
            Thread producer = new Thread(() -> {
                for (int j = 0; j < N; j++) {
                    q.offer(j);
                }
            });
            threads.add(producer);
            producer.start();
            consumer.start();
        }
        for (Thread t : threads) {
            t.join();
        }
        // All workers have been joined, so plain reads observe their final counter values.
        for (int i = 0; i < N; i++) {
            int seen = array.getPlain(i);
            if (n != seen) {
                Assert.fail("idx: " + i + ", expect: " + n + ", get: " + seen);
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment