Skip to content

Instantly share code, notes, and snippets.

@smillies
Last active January 21, 2019 16:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save smillies/0cceb17501f74c4f53bf4930eba61889 to your computer and use it in GitHub Desktop.
Save smillies/0cceb17501f74c4f53bf4930eba61889 to your computer and use it in GitHub Desktop.
Caching recursive functions
package java8.concurrent;
import java.math.BigInteger;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.function.Function;
/**
 * Demonstrates ways of caching recursive functions, using the Fibonacci numbers as the example.
 * <p>
 * Triggered by <a href=
 * "http://concurrency.markmail.org/search/?q=#query:%20list%3Aedu.oswego.cs.concurrency-interest+page:3+mid:tf7xddfa6i6ow6d3+state:results">
 * this discussion</a> on concurrency-interest.
 *
 * @author Andrew Haley, Viktor Klang, Sebastian Millies
 */
public class FibCached {

    /**
     * Wraps a function so that its results are stored in, and served from, a caller-supplied map.
     * <p>
     * Lookup and store are separate {@code get}/{@code put} calls, so the check-then-act is not atomic
     * even when the map is a concurrent one: two callers may both compute the same key. A {@code null}
     * result is never treated as cached and is therefore recomputed on every call.
     *
     * @param <T> argument type
     * @param <R> result type
     */
    private static class Memoizer<T, R> {
        private final Map<T, R> memo;

        public Memoizer(Map<T, R> memo) {
            this.memo = memo;
        }

        /**
         * Returns a function that consults the memo map before delegating to {@code f}.
         * NOTE(review): plain get/put is presumably used instead of computeIfAbsent because {@code f}
         * re-enters the memoizer recursively, which the JDK's computeIfAbsent does not permit.
         */
        public Function<T, R> memoize(Function<T, R> f) {
            return t -> {
                R r = memo.get(t);
                if (r == null) {
                    r = f.apply(t);
                    memo.put(t, r);
                }
                return r;
            };
        }
    }

    /**
     * Synchronous memoized Fibonacci. {@code fib(n)} is {@code ONE} for every {@code n <= 2}
     * (including 0 and negative arguments). Thread-safety depends entirely on the supplied map and on
     * {@link Memoizer}'s non-atomic get/put.
     */
    public static class FibonacciSimple {
        private final Memoizer<Integer, BigInteger> m;

        public FibonacciSimple(Map<Integer, BigInteger> cache) {
            m = new Memoizer<>(cache);
        }

        public BigInteger fib(int n) {
            if (n <= 2) return BigInteger.ONE;
            return m.memoize(this::fib).apply(n - 1).add(
                    m.memoize(this::fib).apply(n - 2));
        }
    }

    /**
     * Memoized Fibonacci that caches {@link CompletionStage}s, so concurrent callers asking for the same
     * {@code n} share one computation. {@code fib(n)} is {@code ONE} for every {@code n <= 2}.
     * <p>
     * Exactly one caller owns the computation for a key only if the supplied map implements
     * {@code putIfAbsent} atomically (e.g. {@code ConcurrentHashMap}). If a computation fails, its cached
     * future is completed exceptionally (so waiters see the failure instead of blocking forever) and the
     * entry is evicted so that a later call can retry.
     * <p>
     * The recursion on {@code n - 1} is synchronous, so stack depth grows linearly with {@code n}.
     */
    public static class FibonacciCF {
        private final Map<Integer, CompletionStage<BigInteger>> cache;

        public FibonacciCF(Map<Integer, CompletionStage<BigInteger>> cache) {
            this.cache = cache;
        }

        /**
         * Computes fib(n); the addition runs in whichever thread completes the operands.
         *
         * @param n index into the sequence
         * @return a stage completed with fib(n), or exceptionally if the computation failed
         */
        public CompletionStage<BigInteger> fib(int n) {
            if (n <= 2) return CompletableFuture.completedFuture(BigInteger.ONE);
            CompletionStage<BigInteger> existing = cache.get(n);
            if (existing != null) return existing;
            final CompletableFuture<BigInteger> compute = new CompletableFuture<>();
            existing = cache.putIfAbsent(n, compute);
            if (existing != null) return existing; // another caller owns this computation
            try {
                settle(n, compute, fib(n - 1).thenCompose(x -> fib(n - 2).thenApply(x::add)));
            } catch (RuntimeException | Error failure) {
                // e.g. StackOverflowError from deep recursion: never leave a pending future cached
                fail(n, compute, failure);
                throw failure;
            }
            return compute;
        }

        // async version. It's very much possible and recommended to not make the first thenCompose an async one,
        // as only the addition of x and y might be "expensive" (for large values).
        /**
         * Like {@link #fib(int)}, but performs each addition on the given executor.
         *
         * @param n index into the sequence
         * @param e executor that runs the additions
         * @return a stage completed with fib(n), or exceptionally if the computation (or the executor) failed
         */
        public CompletionStage<BigInteger> fib(int n, Executor e) {
            if (n <= 2) return CompletableFuture.completedFuture(BigInteger.ONE);
            CompletionStage<BigInteger> existing = cache.get(n);
            if (existing != null) return existing;
            final CompletableFuture<BigInteger> compute = new CompletableFuture<>();
            existing = cache.putIfAbsent(n, compute);
            if (existing != null) return existing; // another caller owns this computation
            try {
                settle(n, compute, fib(n - 1, e).thenCompose(x -> fib(n - 2, e).thenApplyAsync(x::add, e)));
            } catch (RuntimeException | Error failure) {
                fail(n, compute, failure);
                throw failure;
            }
            return compute;
        }

        /** Transfers the outcome of {@code result} into the cached {@code slot}, success or failure. */
        private void settle(int n, CompletableFuture<BigInteger> slot, CompletionStage<BigInteger> result) {
            result.whenComplete((value, failure) -> {
                if (failure == null) {
                    slot.complete(value);
                } else {
                    fail(n, slot, failure);
                }
            });
        }

        /** Fails the cached {@code slot} and removes it (only if still mapped) so {@code n} can be recomputed. */
        private void fail(int n, CompletableFuture<BigInteger> slot, Throwable failure) {
            slot.completeExceptionally(failure);
            cache.remove(n, slot);
        }
    }
}
package java8.concurrent;
import java.math.BigInteger;
import java.util.HashMap;
import java.util.Locale;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.VerboseMode;
import java8.concurrent.FibCached.FibonacciCF;
import java8.concurrent.FibCached.FibonacciSimple;
import java8.concurrent.eventdriven.util.ThreadPools;
/**
 * Single-threaded JMH comparison of the memoized Fibonacci variants in {@link FibCached}.
 * Every invocation starts from an empty cache and computes fib(2000).
 */
@BenchmarkMode({ Mode.AverageTime })
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 10, time = 1000, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 20, time = 1000, timeUnit = TimeUnit.MILLISECONDS)
@Fork(1)
public class FibCachedBenchmark {
    /** Worker count of the executor used by {@link #cfAsync2000()}. */
    private static final int POOL_SIZE = 1; // increasing the pool size will make performance worse

    private ExecutorService pool;

    /** Creates the daemon executor once per trial and reports its size. */
    @Setup(Level.Trial)
    public void setup() throws InterruptedException {
        String banner = "POOL_SIZE = " + POOL_SIZE;
        System.out.println(banner);
        pool = Executors.newFixedThreadPool(POOL_SIZE, ThreadPools.newThreadFactory("fibonacci-%d", true));
    }

    /** Stops the executor at the end of the trial, waiting up to one second. */
    @TearDown(Level.Trial)
    public void tearDown() throws InterruptedException {
        pool.shutdownNow();
        pool.awaitTermination(1, TimeUnit.SECONDS);
    }

    /** Baseline: synchronous memoization over a fresh {@link HashMap}. */
    @Benchmark
    public BigInteger simple2000() {
        FibonacciSimple fibonacci = new FibonacciSimple(new HashMap<>());
        return fibonacci.fib(2000);
    }

    /** CompletionStage-based memoization, completed synchronously, over a fresh {@link HashMap}. */
    @Benchmark
    public BigInteger cf2000() {
        FibonacciCF fibonacci = new FibonacciCF(new HashMap<>());
        return fibonacci.fib(2000).toCompletableFuture().join();
    }

    /** CompletionStage-based memoization with additions on {@link #pool}, over a fresh {@link ConcurrentHashMap}. */
    @Benchmark
    public BigInteger cfAsync2000() {
        FibonacciCF fibonacci = new FibonacciCF(new ConcurrentHashMap<>());
        return fibonacci.fib(2000, pool).toCompletableFuture().join();
    }

    /** Runs only the benchmarks of this class. */
    public static void main(String[] args) throws RunnerException {
        Locale.setDefault(Locale.ENGLISH);
        String includePattern = ".*" + FibCachedBenchmark.class.getSimpleName() + ".*";
        Options options = new OptionsBuilder()
                .verbosity(VerboseMode.NORMAL)
                .include(includePattern)
                .build();
        new Runner(options).run();
    }
}
package java8.concurrent;
import java.math.BigInteger;
import java.util.Locale;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Threads;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.VerboseMode;
import java8.concurrent.FibCached.FibonacciCF;
import java8.concurrent.eventdriven.util.ThreadPools;
/**
 * Contention benchmark: eight JMH threads share one {@link FibonacciCF} (and one cache) per iteration,
 * and each measures a single call of fib(3500).
 * NOTE(review): the warmup/measurement {@code time} settings appear to have no effect in
 * SingleShotTime mode — confirm against the JMH documentation.
 */
@BenchmarkMode({ Mode.SingleShotTime })
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Warmup(iterations = 10, time = 2000, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 20, time = 2000, timeUnit = TimeUnit.MILLISECONDS)
@Fork(1)
public class FibCachedConcurrentBenchmark {

    /** Executor shared by all benchmark threads for the whole trial. */
    @State(Scope.Benchmark)
    public static class PoolHolder {
        private static final int POOL_SIZE = 2;
        ExecutorService pool;

        @Setup(Level.Trial)
        public void createPool() throws InterruptedException {
            String banner = "POOL_SIZE = " + POOL_SIZE;
            System.out.println(banner);
            pool = Executors.newFixedThreadPool(POOL_SIZE, ThreadPools.newThreadFactory("fibonacci-%d", true));
        }

        @TearDown(Level.Trial)
        public void shutdown() throws InterruptedException {
            pool.shutdownNow();
            pool.awaitTermination(1, TimeUnit.SECONDS);
        }
    }

    /** One Fibonacci instance shared by all threads, with an empty cache at the start of each iteration. */
    @State(Scope.Benchmark)
    public static class FibHolder {
        FibonacciCF fibCF;

        @Setup(Level.Iteration)
        public void createFib() throws InterruptedException {
            ConcurrentHashMap<Integer, java.util.concurrent.CompletionStage<java.math.BigInteger>> freshCache =
                    new ConcurrentHashMap<>();
            fibCF = new FibonacciCF(freshCache);
        }
    }

    /** Synchronous variant under contention. */
    @Benchmark
    @Threads(8)
    public BigInteger cf3500(FibHolder fibHolder, PoolHolder poolHolder) {
        FibonacciCF shared = fibHolder.fibCF;
        return shared.fib(3500).toCompletableFuture().join();
    }

    /** Variant whose additions run on the shared pool, under contention. */
    @Benchmark
    @Threads(8)
    public BigInteger cfAsync3500(FibHolder fibHolder, PoolHolder poolHolder) {
        FibonacciCF shared = fibHolder.fibCF;
        return shared.fib(3500, poolHolder.pool).toCompletableFuture().join();
    }

    /** Runs only the benchmarks of this class. */
    public static void main(String[] args) throws RunnerException {
        Locale.setDefault(Locale.ENGLISH);
        String includePattern = ".*" + FibCachedConcurrentBenchmark.class.getSimpleName() + ".*";
        Options options = new OptionsBuilder()
                .verbosity(VerboseMode.NORMAL)
                .include(includePattern)
                .build();
        new Runner(options).run();
    }
}
package java8.concurrent.eventdriven.util;
import java.util.concurrent.ThreadFactory;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
/**
 * Helpers for creating the thread factories used by the benchmarks.
 */
public abstract class ThreadPools {
    /**
     * Builds a thread factory with Guava's {@link ThreadFactoryBuilder}.
     *
     * @param nameFormat    name pattern handed to {@code setNameFormat} (callers pass {@code "fibonacci-%d"})
     * @param createDaemons whether the produced threads are daemon threads
     * @return a new thread factory
     */
    public static ThreadFactory newThreadFactory(String nameFormat, boolean createDaemons) {
        ThreadFactoryBuilder builder = new ThreadFactoryBuilder();
        builder.setNameFormat(nameFormat);
        builder.setDaemon(createDaemons);
        return builder.build();
    }
}
@smillies
Copy link
Author

A few measurements with JMH 1.12 and Java 1.8.0_92 on a 4-core Windows 7 Enterprise machine (Intel i7-4800MQ CPU):

# Run complete. Total time: 00:01:32

Benchmark                       Mode  Cnt  Score   Error  Units
FibCachedBenchmark.cf2000       avgt   20  0.201 ± 0.002  ms/op
FibCachedBenchmark.cfAsync2000  avgt   20  0.300 ± 0.002  ms/op
FibCachedBenchmark.simple2000   avgt   20  0.172 ± 0.002  ms/op

# Run complete. Total time: 00:01:31

Benchmark                       Mode  Cnt  Score   Error  Units
FibCachedBenchmark.cf2000       avgt   20  0.205 ± 0.002  ms/op
FibCachedBenchmark.cfAsync2000  avgt   20  0.304 ± 0.002  ms/op
FibCachedBenchmark.simple2000   avgt   20  0.174 ± 0.001  ms/op

@smillies
Copy link
Author

smillies commented Apr 28, 2016

If the recursion runs into a StackOverflowError, be aware that under JMH you may not see the error at all; instead, the benchmark just hangs with all JMH worker threads parked.

@smillies
Copy link
Author

And here are a few measurements with the concurrent benchmark. In that benchmark, all threads use the same fib instance, so we can see how it behaves under contention. The benchmark uses single-shot mode, because averaging multiple calls on the same instance wouldn't make sense: all but the first call are simple map lookups. (We cannot test the simple version this way, because it is not thread-safe.)

I consistently see no advantage of the async version over the synchronous one; in fact, the synchronous version is faster.

1 Thread
Benchmark                                 Mode  Cnt  Score   Error  Units
FibCachedConcurrentBenchmark.cf3500         ss   20  1.969 ± 0.427  ms/op
FibCachedConcurrentBenchmark.cfAsync3500    ss   20  9.009 ± 3.032  ms/op
FibCachedConcurrentBenchmark.cf3500         ss   20  2.110 ± 1.052  ms/op
FibCachedConcurrentBenchmark.cfAsync3500    ss   20  8.582 ± 2.741  ms/op
FibCachedConcurrentBenchmark.cf3500         ss   20  2.152 ± 0.366  ms/op
FibCachedConcurrentBenchmark.cfAsync3500    ss   20  9.750 ± 4.466  ms/op

2 Threads
Benchmark                                 Mode  Cnt  Score     Error  Units
FibCachedConcurrentBenchmark.cf3500         ss   20   4.132 ±  1.421  ms/op
FibCachedConcurrentBenchmark.cfAsync3500    ss   20   9.134 ±  0.862  ms/op
FibCachedConcurrentBenchmark.cf3500         ss   20   2.887 ±  0.571  ms/op
FibCachedConcurrentBenchmark.cfAsync3500    ss   20  10.345 ± 12.954  ms/op
FibCachedConcurrentBenchmark.cf3500         ss   20   3.500 ±  1.291  ms/op
FibCachedConcurrentBenchmark.cfAsync3500    ss   20   8.803 ±  1.679  ms/op

4 Threads
Benchmark                                 Mode  Cnt  Score   Error  Units
FibCachedConcurrentBenchmark.cf3500         ss   20  2.780 ± 0.430  ms/op
FibCachedConcurrentBenchmark.cfAsync3500    ss   20  8.850 ± 1.595  ms/op
FibCachedConcurrentBenchmark.cf3500         ss   20  3.034 ± 0.451  ms/op
FibCachedConcurrentBenchmark.cfAsync3500    ss   20  9.744 ± 1.669  ms/op
FibCachedConcurrentBenchmark.cf3500         ss   20  3.965 ± 1.380  ms/op
FibCachedConcurrentBenchmark.cfAsync3500    ss   20  8.430 ± 2.396  ms/op

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment