Skip to content

Instantly share code, notes, and snippets.

@remibarat
Last active February 13, 2019 04:44
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save remibarat/889293cad02dc45a13d98aa4dbfe716e to your computer and use it in GitHub Desktop.
Save remibarat/889293cad02dc45a13d98aa4dbfe716e to your computer and use it in GitHub Desktop.
Difference in work stealing between jdk1.8 and jdk11. The following examples work well with jdk1.8 but do not finish with jdk11.
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
/**
* This example uses a Parent task (e.g. mapping in a custom Map when the Map is full) that will
* launch a Child task (e.g. to resize the Map, we need to rehash it). Any new Parent task needs
* to wait for the Child task to finish. The Child task will also launch Child tasks.
*
* The idea is that, a thread B in a child task C0 `join`s a child task C2 and manages to steal a
* parent task P1 from thread A. However, P1 `join`s C0, so even though thread A finishes C2, B is
* blocked by itself.
*
* The following explains how this example behaves, on the left when executed in jdk11, and on the
* right when executed with jdk1.8. The code uses a `ForkJoinPool` with 2 threads, and
* `CountDownLatch`es to control work stealing. Note that, although we use a `ForkJoinPool` with
* 2 threads, with jdk1.8 we see a third thread coming to help when both thread A and thread B are
* blocked.
*
*
* ,-------, || ,--------,
* | jdk11 | || | jdk1.8 |
* '-------' || '--------'
* ||
* thread A thread B || thread A thread B thread C
* time -------- -------- || -------- -------- --------
* | starts P0 || starts P0
* | creates C0 || creates C0
* | +-- forks C0 <----- steals --+ || +-- forks C0 <----- steals --+
* | | starts C0 || | starts C0
* | | creates C2 || | creates C2
* | | ,-> forks C2 --+ || | forks C2 --+
* | | | invokes C1 || | ^ invokes C1
* | creates P1 | | || creates P1 | |
* | +-- forks P1 <-|-, | || +-- forks P1 <----------------|--,
* | joins C0 | | | || joins C0 | | |
* | +-- steals ----' | | || | | | '-- compensates --+
* | starts C2 | | || | | | starts P1
* | | | finishes C1 || | | | joins C0
* | | | | || | '---------|---------- steals --+
* | | | joins C2 || | | starts C2
* | | '-- steals --+ || | finishes C1 |
* | | starts P1 || | joins C2 -- nothing |
* | | joins C0 || | | to steal |
* | finishes C2 | || | | from C finishes C2
* | x x || | finishes C0 |
* | cannot proceed cannot proceed || | finishes C1
* | (join C0) even though || finishes P0
* | C2 finished ||
* v ||
*
* Outputs
*
* === Thread 13 is ForkJoinPool-1-worker-3 || === Thread 11 is ForkJoinPool-1-worker-1
* === Thread 14 is ForkJoinPool-1-worker-1 || === Thread 12 is ForkJoinPool-1-worker-0
* ||
* thread 13 | thread 14 || thread 11 | thread 12
* -----------|------------ || -----------|------------
* P0 starts | || P0 starts |
* | C0 starts || | C0 starts
* | C1 starts || | C1 starts
* C2 starts | || | | P1 starts (new Thread: 13)
* | C1 ends || | | C2 starts (new Thread: 13)
* | P1 starts || | | C2 ends (new Thread: 13)
* C2 ends | || | C1 ends
* || | C0 ends
* [Program does not finish] || | | P1 ends (new Thread: 13)
* || P0 ends |
*
*
* The main difference between jdk1.8 and jdk11 seems to be that, when thread A `join`s a task being executed by thread B:
* * In jdk1.8, if A still has a task (e.g. P1) in its work queue, a new thread (e.g. C) pops up and
* begins to perform A's task (C is "compensating" while A is blocked). It is only when the
* joining thread (e.g. C) has an empty work queue that it will steal from the thread it joins (e.g. B).
* * In jdk11, even though A still has a task in its work queue, A itself will steal from B.
* This is documented in `ForkJoinPool.awaitJoin(...)`: "First tries locally helping, then scans
* other queues for a task produced by one of w's stealers; compensating and blocking if none are
* found.".
*/
public class TestForkJoin {
    /** Pool with a parallelism of 2: the threads A and B of the scenario. */
    private static final ForkJoinPool pool = new ForkJoinPool(2);
    /** Released by C1 when it runs, i.e. once C0 has forked C2 and invoked C1. */
    private static final CountDownLatch invokedC1 = new CountDownLatch(1);
    /** Released by C2 when it runs. */
    private static final CountDownLatch invokedC2 = new CountDownLatch(1);
    /** Released by P1 when it runs. */
    private static final CountDownLatch startedP1 = new CountDownLatch(1);

    /** Does not terminate in JDK11, but does in JDK1.8. */
    public static void main(String[] args) {
        pool.invoke(new ParentTask(0));
    }

    /** A "Parent" task: P0 creates and forks C0, then forks P1; every Parent task waits for C0. */
    private static class ParentTask extends RecursiveAction {
        final int id;
        /**
         * Created by P0, it must be performed before proceeding. Volatile because it is written by
         * P0 and read by P1, which may run on a different worker thread.
         */
        static volatile ChildTask subTask;

        ParentTask(final int id) {
            this.id = id;
        }

        @Override
        protected void compute() {
            printId("P" + this.id + " starts");
            switch (this.id) {
                case 0:
                    subTask = new ChildTask(0);
                    subTask.fork();
                    // This thread is blocked on the latch (not in a join), so C0 has to be run by
                    // another worker, which forks C2 and then invokes C1 before releasing us.
                    awaitLatch(invokedC1);
                    // P1 stays in this thread's queue while this thread joins C0 below.
                    new ParentTask(1).fork();
                    break;
                case 1:
                    startedP1.countDown();
                    break;
            }
            // Every Parent task must wait for C0 before finishing.
            subTask.join();
            printId("P" + this.id + " ends ");
        }
    }

    /** A "Child" task: C0 forks C2, invokes C1 and then joins C2. */
    private static class ChildTask extends RecursiveAction {
        final int id;

        ChildTask(final int id) {
            this.id = id;
        }

        @Override
        protected void compute() {
            printId("C" + this.id + " starts");
            switch (this.id) {
                case 0: {
                    ChildTask c2 = new ChildTask(2);
                    c2.fork();
                    new ChildTask(1).invoke();
                    c2.join();
                    break;
                }
                case 1:
                    invokedC1.countDown();
                    // Keep C1 (and thus its thread) busy until C2 has started elsewhere.
                    awaitLatch(invokedC2);
                    break;
                case 2:
                    invokedC2.countDown();
                    // Keep C2 running until P1 has started.
                    awaitLatch(startedP1);
                    break;
            }
            printId("C" + this.id + " ends ");
        }
    }

    /**
     * Waits for {@code latch} to reach zero. On interruption the stack trace is printed and the
     * thread's interrupt status is restored instead of being silently cleared.
     */
    private static void awaitLatch(final CountDownLatch latch) {
        try {
            latch.await();
        } catch (InterruptedException e) {
            e.printStackTrace();
            Thread.currentThread().interrupt();
        }
    }

    /*--------------*/
    /* For printing */
    /*--------------*/
    // Events are printed in two columns, one per worker thread. Lines printed by the first thread
    // before a second thread shows up are buffered in firstLine, because the column header needs
    // both thread ids. All of this state is guarded by the class lock (see printId).
    private static long firstThreadId = -1;
    private static long secondThreadId = -1;
    private static String firstLine = "";

    /** Prints {@code task} in the column of the current thread. */
    private static synchronized void printId(final String task) {
        if (secondThreadId == -1) {
            setIds(task);
        } else {
            System.out.println(getLine(task));
        }
    }

    /** Formats {@code task} for the column of the current thread; unknown threads get a third column. */
    private static String getLine(final String task) {
        long id = Thread.currentThread().getId();
        if (id == firstThreadId) {
            return " " + task + " |";
        } else if (id == secondThreadId) {
            return " | " + task;
        } else {
            return " | | " + task + " (new Thread: " + Thread.currentThread().getId() + ")";
        }
    }

    /**
     * Records the thread ids until two distinct threads have printed. Must be called with the class
     * lock held. Buffers lines from the first thread; on the first line from a second thread,
     * prints the header, the buffered lines and then that line.
     */
    private static void setIds(final String task) {
        final long id = Thread.currentThread().getId();
        if (firstThreadId == -1) {
            firstThreadId = id;
            System.out.println("=== Thread " + firstThreadId + " is " + Thread.currentThread().getName());
            firstLine = getLine(task);
        } else if (firstThreadId == id) {
            // Same thread again: keep buffering rather than mistaking it for the second thread.
            firstLine += "\n" + getLine(task);
        } else {
            secondThreadId = id;
            System.out.println("=== Thread " + secondThreadId + " is " + Thread.currentThread().getName());
            System.out.println();
            System.out.println(" thread " + firstThreadId + " | thread " + secondThreadId);
            System.out.println("-----------|------------");
            System.out.println(firstLine);
            System.out.println(getLine(task));
        }
    }
}
/*
* (C) Quartet FS 2007-2019
* ALL RIGHTS RESERVED. This material is the CONFIDENTIAL and PROPRIETARY
* property of Quartet Financial Systems Limited. Any unauthorized use,
* reproduction or transfer of this material is strictly prohibited
*/
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.TimeUnit;
/**
 * This example is a simplified version of how a real application works. A {@link CopyAction} will
 * create {@link WriteAction}s (e.g. W0, W1, W2) that concurrently write in a custom dictionary.
 * When the dictionary is full, it must be resized and rehashed, which is performed concurrently by
 * a {@link RehashAction} (R0 created by W0). The subsequent {@link WriteAction}s (e.g. W1, W2) that
 * try to write in the dictionary need to wait for the {@link RehashAction} to finish, so they
 * {@code join} it. This is where work stealing occurs.
 *
 * ,-------, || ,--------,
 * | jdk11 | || | jdk1.8 |
 * '-------' || '--------'
 * ||
 * thread A thread B || thread A thread B thread C
 * time -------- -------- || -------- -------- --------
 * | creates W0 | || creates W0 |
 * | +-- forks W0 <------ steals --+ || +-- forks W0 <------ steals --+
 * | | starts W0 || | starts W0
 * | | starts R0 || | starts R0
 * | | creates R1 || | creates R1
 * | | ,-> forks R1 --+ || | forks R1 --+
 * | | | invokes R2 || | invokes R2
 * | creates W1 | | || creates W1 |
 * | +-- forks W1 <-|-, | || +-- forks W1 <----------------|--,
 * | invokes W2 | | | || invokes W2 | |
 * | joins R0 | | | || joins R0 | |
 * | +-- steals ----' | | || | | '--- compensates --+
 * | starts R1 | | || | | starts W1
 * | | | finishes R2 || | | joins R0 -- does not
 * | | | joins R1 || | finishes R2 | steal
 * | | '-- steals --+ || | joins R1 --- nothing | from B (?)
 * | | starts W1 || | starts R1 to steal |
 * | | joins R0 || | finishes R1 from C |
 * | finishes R1 | || | finishes R0 |
 * | x x || | finishes W0 finishes W1
 * | both tasks cannot proceed || finishes W2
 * | (join R0) even though ||
 * | R1 finished ||
 * v ||
 *
 * Outputs
 *
 * === Thread 13 is ForkJoinPool-1-worker-3 || === Thread 11 is ForkJoinPool-1-worker-1
 * === Thread 14 is ForkJoinPool-1-worker-1 || === Thread 12 is ForkJoinPool-1-worker-0
 * ||
 * thread 13 | thread 14 || thread 11 | thread 12
 * -----------|------------ || -----------|------------
 * W2 starts | || W2 starts |
 * | W0 starts || | W0 starts
 * | R0 starts || | R0 starts
 * | R2 starts || | R2 starts
 * R1 starts | || | | W1 starts (new Thread: 13)
 * | R2 ends || | R2 ends
 * R1 ends | || | R1 starts
 * | W1 starts || | R1 ends
 * || | R0 ends
 * || | W0 ends
 * || | | W1 ends (new Thread: 13)
 * || W2 ends |
 *
 * Here, with jdk1.8, surprisingly, thread C does not steal from B when it could, although it stole
 * C2 in the other example ({@code TestForkJoin}).
 */
public class TestForkJoin2 {
    /** Pool with a parallelism of 2: the threads A and B of the scenario. */
    private static final ForkJoinPool pool = new ForkJoinPool(2);
    /** Released by R1 when it runs; R2 waits on it (for at most one second). */
    private static final CountDownLatch startedW1 = new CountDownLatch(1);
    /** Released by R2 when it runs; W2 waits on it before joining R0. */
    private static final CountDownLatch invokedR2 = new CountDownLatch(1);
    /** Released by R0 after R2 has completed and right before R0 joins R1; R1 waits on it. */
    private static final CountDownLatch joiningR1 = new CountDownLatch(1);

    /** Does not terminate in JDK11, but does in JDK1.8. */
    public static void main(String[] args) {
        pool.invoke(new CopyAction());
    }

    /** Root task: forks W0 and W1, then runs W2 in the current thread. */
    private static class CopyAction extends RecursiveAction {
        @Override
        protected void compute() {
            new WriteAction(0).fork();
            new WriteAction(1).fork();
            new WriteAction(2).invoke();
        }
    }

    /** W0 creates and runs the rehash R0; W1 and W2 join it. */
    private static class WriteAction extends RecursiveAction {
        final int id;
        /**
         * Created by W0, it must be performed before proceeding. NOTE(review): W1 joins it without
         * waiting, so this relies on W0 having run first, as in the scenario above.
         */
        static volatile RehashAction rehashAction;

        WriteAction(final int id) {
            this.id = id;
        }

        @Override
        protected void compute() {
            printId("W" + this.id + " starts");
            switch (this.id) {
                case 0:
                    rehashAction = new RehashAction(0);
                    rehashAction.invoke();
                    break;
                case 1:
                    rehashAction.join();
                    break;
                case 2:
                    // Only join once R0 exists and has invoked R2, so R1 sits in another queue.
                    awaitLatch(invokedR2);
                    rehashAction.join();
                    break;
            }
            printId("W" + this.id + " ends ");
        }
    }

    /** R0 forks R1, invokes R2, then joins R1. */
    private static class RehashAction extends RecursiveAction {
        final int id;

        RehashAction(final int id) {
            this.id = id;
        }

        @Override
        protected void compute() {
            printId("R" + this.id + " starts");
            switch (this.id) {
                case 0: {
                    RehashAction r1 = new RehashAction(1);
                    r1.fork();
                    new RehashAction(2).invoke();
                    joiningR1.countDown();
                    r1.join();
                    break;
                }
                case 1:
                    startedW1.countDown();
                    // Keep R1 running until R0 is about to join it.
                    awaitLatch(joiningR1);
                    break;
                case 2:
                    invokedR2.countDown();
                    awaitLatch(startedW1, 1, TimeUnit.SECONDS);
                    break;
            }
            printId("R" + this.id + " ends ");
        }
    }

    /**
     * Waits for {@code latch} to reach zero. On interruption the stack trace is printed and the
     * thread's interrupt status is restored instead of being silently cleared.
     */
    private static void awaitLatch(final CountDownLatch latch) {
        try {
            latch.await();
        } catch (InterruptedException e) {
            e.printStackTrace();
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Waits at most {@code timeout} for {@code latch} to reach zero. The boolean result is ignored:
     * on timeout the caller simply proceeds. Interruption is handled as in {@link #awaitLatch}.
     */
    private static void awaitLatch(final CountDownLatch latch, final long timeout, final TimeUnit unit) {
        try {
            latch.await(timeout, unit);
        } catch (InterruptedException e) {
            e.printStackTrace();
            Thread.currentThread().interrupt();
        }
    }

    /*--------------*/
    /* For printing */
    /*--------------*/
    // Events are printed in two columns, one per worker thread. Lines printed by the first thread
    // before a second thread shows up are buffered in firstLine, because the column header needs
    // both thread ids. All of this state is guarded by the class lock (see printId).
    private static long firstThreadId = -1;
    private static long secondThreadId = -1;
    private static String firstLine = "";

    /** Prints {@code task} in the column of the current thread. */
    private synchronized static void printId(final String task) {
        if (secondThreadId == -1) {
            setIds(task);
        } else {
            System.out.println(getLine(task));
        }
    }

    /** Formats {@code task} for the column of the current thread; unknown threads get a third column. */
    private static String getLine(final String task) {
        long id = Thread.currentThread().getId();
        if (id == firstThreadId) {
            return " " + task + " |";
        } else if (id == secondThreadId) {
            return " | " + task;
        } else {
            return " | | " + task + " (new Thread: " + Thread.currentThread().getId() + ")";
        }
    }

    /**
     * Records the thread ids until two distinct threads have printed. Must be called with the class
     * lock held. Buffers lines from the first thread; on the first line from a second thread,
     * prints the header, the buffered lines and then that line (previously that line was lost).
     */
    private static void setIds(final String task) {
        final long id = Thread.currentThread().getId();
        if (firstThreadId == -1) {
            firstThreadId = id;
            System.out.println("=== Thread " + firstThreadId + " is " + Thread.currentThread().getName());
            firstLine = getLine(task);
        } else if (firstThreadId == id) {
            firstLine += "\n" + getLine(task);
        } else {
            secondThreadId = id;
            System.out.println("=== Thread " + secondThreadId + " is " + Thread.currentThread().getName());
            System.out.println();
            System.out.println(" thread " + firstThreadId + " | thread " + secondThreadId);
            System.out.println("-----------|------------");
            System.out.println(firstLine);
            System.out.println(getLine(task));
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment