Skip to content

Instantly share code, notes, and snippets.

@rsrini7
Forked from smillies/FibCached.java
Created January 21, 2019 16:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rsrini7/42a9408616daf7b6d216125a5a1333af to your computer and use it in GitHub Desktop.
Caching recursive functions
package java8.concurrent;
import java.math.BigInteger;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.function.Function;
/**
 * Demonstrates ways of caching recursive functions.
 *
 * <p>Triggered by <a href=
 * "http://concurrency.markmail.org/search/?q=#query:%20list%3Aedu.oswego.cs.concurrency-interest+page:3+mid:tf7xddfa6i6ow6d3+state:results">
 * this discussion</a> on concurrency-interest.
 *
 * @author Andrew Haley, Viktor Klang, Sebastian Millies
 */
public class FibCached {

    /**
     * Wraps a function so that its results are looked up in, and stored into, a caller-supplied map.
     *
     * <p>Lookup and store are two separate map operations rather than one {@code computeIfAbsent}:
     * the wrapped function recurses back into the memoized function, so a {@code computeIfAbsent}
     * mapping function would update the map re-entrantly, which {@code HashMap} and
     * {@code ConcurrentHashMap} do not support. Consequently, with a concurrent map two threads may
     * compute the same key; the later {@code put} simply overwrites the earlier one.
     *
     * @param <T> argument type
     * @param <R> result type; a {@code null} result is never treated as cached
     */
    private static class Memoizer<T, R> {
        private final Map<T, R> memo;

        public Memoizer(Map<T, R> memo) {
            this.memo = memo;
        }

        /** Returns a function that consults the cache first and otherwise delegates to {@code f}. */
        public Function<T, R> memoize(Function<T, R> f) {
            return t -> {
                R r = memo.get(t);
                if (r == null) {
                    r = f.apply(t);
                    memo.put(t, r);
                }
                return r;
            };
        }
    }

    /**
     * Recursive Fibonacci whose recursive calls go through a {@link Memoizer}.
     * Safe for concurrent use only if the supplied map is.
     */
    public static class FibonacciSimple {
        /** Memoized view of {@link #fib(int)}, built once rather than on every recursive step. */
        private final Function<Integer, BigInteger> memoFib;

        /** @param cache map that stores computed values, keyed by index */
        public FibonacciSimple(Map<Integer, BigInteger> cache) {
            memoFib = new Memoizer<Integer, BigInteger>(cache).memoize(this::fib);
        }

        /**
         * Returns the n-th Fibonacci number with {@code fib(1) == fib(2) == 1}.
         * Every {@code n <= 2}, including zero and negative values, yields 1.
         */
        public BigInteger fib(int n) {
            if (n <= 2) return BigInteger.ONE;
            return memoFib.apply(n - 1).add(memoFib.apply(n - 2));
        }
    }

    /**
     * Fibonacci that caches {@link CompletionStage}s instead of values. The first caller for an index
     * publishes an incomplete future via {@code putIfAbsent}; with a concurrent map, callers racing on
     * the same index receive that same future instead of recomputing it.
     *
     * <p>The cached future is always completed, normally or exceptionally, so waiters never hang.
     * A failed entry is evicted from the cache so that later calls retry the computation.
     */
    public static class FibonacciCF {
        private final Map<Integer, CompletionStage<BigInteger>> cache;

        public FibonacciCF(Map<Integer, CompletionStage<BigInteger>> cache) {
            this.cache = cache;
        }

        /**
         * Returns a stage holding the n-th Fibonacci number ({@code n <= 2} yields 1). Continuations
         * run on whichever thread completes the stage they depend on.
         */
        public CompletionStage<BigInteger> fib(int n) {
            if (n <= 2) return CompletableFuture.completedFuture(BigInteger.ONE);
            CompletionStage<BigInteger> cached = cache.get(n);
            if (cached != null) return cached;
            final CompletableFuture<BigInteger> compute = new CompletableFuture<>();
            CompletionStage<BigInteger> raced = cache.putIfAbsent(n, compute);
            if (raced != null) return raced; // another caller published first
            fib(n - 1)
                    .thenCompose(x -> fib(n - 2).thenApply(x::add))
                    .whenComplete((value, failure) -> settle(n, compute, value, failure));
            return compute;
        }

        /**
         * Async variant of {@link #fib(int)}: only the addition of the two predecessors is handed to
         * {@code e}. The first composition deliberately stays synchronous, since the addition of x and
         * y is the only step that might be "expensive" (for large values).
         *
         * @param n index of the requested Fibonacci number
         * @param e executor that performs the additions
         */
        public CompletionStage<BigInteger> fib(int n, Executor e) {
            if (n <= 2) return CompletableFuture.completedFuture(BigInteger.ONE);
            CompletionStage<BigInteger> cached = cache.get(n);
            if (cached != null) return cached;
            final CompletableFuture<BigInteger> compute = new CompletableFuture<>();
            CompletionStage<BigInteger> raced = cache.putIfAbsent(n, compute);
            if (raced != null) return raced; // another caller published first
            fib(n - 1, e)
                    .thenCompose(x -> fib(n - 2, e).thenApplyAsync(x::add, e))
                    .whenComplete((value, failure) -> settle(n, compute, value, failure));
            return compute;
        }

        /**
         * Completes the cached future for {@code n} with the outcome of its computation. On failure
         * the entry is removed first, so later callers retry instead of receiving a cached failure.
         */
        private void settle(int n, CompletableFuture<BigInteger> compute,
                            BigInteger value, Throwable failure) {
            if (failure == null) {
                compute.complete(value);
            } else {
                cache.remove(n, compute);
                compute.completeExceptionally(failure);
            }
        }
    }
}
package java8.concurrent;
import java.math.BigInteger;
import java.util.HashMap;
import java.util.Locale;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.VerboseMode;
import java8.concurrent.FibCached.FibonacciCF;
import java8.concurrent.FibCached.FibonacciSimple;
import java8.concurrent.eventdriven.util.ThreadPools;
/**
 * JMH benchmark comparing single-threaded computation of fib(2000) using the memoizer-based
 * {@link FibonacciSimple} and the CompletionStage-based {@link FibonacciCF}. Every benchmark
 * invocation builds a fresh calculator with an empty cache.
 */
@BenchmarkMode({ Mode.AverageTime })
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 10, time = 1000, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 20, time = 1000, timeUnit = TimeUnit.MILLISECONDS)
@Fork(1)
public class FibCachedBenchmark {

    /** Threads in the executor used by {@link #cfAsync2000()}; a larger pool reportedly performs worse. */
    private static final int POOL_SIZE = 1;

    /** Executor for the asynchronous variant; lives for the whole trial. */
    private ExecutorService pool;

    /** Entry point: runs every benchmark in this class. */
    public static void main(String[] args) throws RunnerException {
        Locale.setDefault(Locale.ENGLISH);
        String pattern = ".*" + FibCachedBenchmark.class.getSimpleName() + ".*";
        Options options = new OptionsBuilder()
                .verbosity(VerboseMode.NORMAL)
                .include(pattern)
                .build();
        new Runner(options).run();
    }

    /** Creates the daemon-thread pool once per trial and reports its size. */
    @Setup(Level.Trial)
    public void setup() throws InterruptedException {
        String banner = "POOL_SIZE = " + POOL_SIZE;
        System.out.println(banner);
        final boolean daemonThreads = true;
        pool = Executors.newFixedThreadPool(POOL_SIZE, ThreadPools.newThreadFactory("fibonacci-%d", daemonThreads));
    }

    /** Stops the pool, waiting at most one second for its threads to finish. */
    @TearDown(Level.Trial)
    public void tearDown() throws InterruptedException {
        ExecutorService executor = pool;
        executor.shutdownNow();
        executor.awaitTermination(1, TimeUnit.SECONDS);
    }

    /** fib(2000) via the Memoizer-backed recursion. */
    @Benchmark
    public BigInteger simple2000() {
        FibonacciSimple fibonacci = new FibonacciSimple(new HashMap<>());
        return fibonacci.fib(2000);
    }

    /** fib(2000) via cached CompletionStages, completed on the calling thread. */
    @Benchmark
    public BigInteger cf2000() {
        FibonacciCF fibonacci = new FibonacciCF(new HashMap<>());
        return fibonacci.fib(2000).toCompletableFuture().join();
    }

    /** fib(2000) via cached CompletionStages with the additions submitted to {@link #pool}. */
    @Benchmark
    public BigInteger cfAsync2000() {
        FibonacciCF fibonacci = new FibonacciCF(new ConcurrentHashMap<>());
        return fibonacci.fib(2000, pool).toCompletableFuture().join();
    }
}
package java8.concurrent;
import java.math.BigInteger;
import java.util.Locale;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Threads;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.VerboseMode;
import java8.concurrent.FibCached.FibonacciCF;
import java8.concurrent.eventdriven.util.ThreadPools;
/**
 * JMH benchmark in which eight threads concurrently compute fib(3500) against one shared
 * {@link FibonacciCF} instance, once synchronously and once with additions on a small executor.
 *
 * <p>NOTE(review): in {@code SingleShotTime} mode each iteration is a single invocation per thread,
 * so the {@code time} settings of {@code @Warmup}/{@code @Measurement} presumably have no effect —
 * confirm against the JMH documentation.
 */
@BenchmarkMode({ Mode.SingleShotTime })
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Warmup(iterations = 10, time = 2000, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 20, time = 2000, timeUnit = TimeUnit.MILLISECONDS)
@Fork(1)
public class FibCachedConcurrentBenchmark {

    /** Holds the executor handed to the asynchronous variant; created once per trial. */
    @State(Scope.Benchmark)
    public static class PoolHolder {
        private static final int POOL_SIZE = 2;
        ExecutorService pool;

        /** Creates the daemon-thread pool and reports its size. */
        @Setup(Level.Trial)
        public void createPool() throws InterruptedException {
            String banner = "POOL_SIZE = " + POOL_SIZE;
            System.out.println(banner);
            final boolean daemonThreads = true;
            pool = Executors.newFixedThreadPool(POOL_SIZE, ThreadPools.newThreadFactory("fibonacci-%d", daemonThreads));
        }

        /** Stops the pool, waiting at most one second for its threads to finish. */
        @TearDown(Level.Trial)
        public void shutdown() throws InterruptedException {
            ExecutorService executor = pool;
            executor.shutdownNow();
            executor.awaitTermination(1, TimeUnit.SECONDS);
        }
    }

    /**
     * Holds the calculator under test. Before every iteration it is replaced by one with an empty
     * {@link ConcurrentHashMap} cache; being benchmark-scoped, it is shared by all benchmark threads.
     */
    @State(Scope.Benchmark)
    public static class FibHolder {
        FibonacciCF fibCF;

        @Setup(Level.Iteration)
        public void createFib() throws InterruptedException {
            fibCF = new FibonacciCF(new ConcurrentHashMap<>());
        }
    }

    /**
     * fib(3500) with continuations run on the completing thread. {@code poolHolder} is not used here;
     * presumably it is requested so both benchmarks share the same state set-up — confirm.
     */
    @Benchmark
    @Threads(8)
    public BigInteger cf3500(FibHolder fibHolder, PoolHolder poolHolder) {
        FibonacciCF fibonacci = fibHolder.fibCF;
        return fibonacci.fib(3500).toCompletableFuture().join();
    }

    /** fib(3500) with the additions submitted to the shared pool. */
    @Benchmark
    @Threads(8)
    public BigInteger cfAsync3500(FibHolder fibHolder, PoolHolder poolHolder) {
        FibonacciCF fibonacci = fibHolder.fibCF;
        ExecutorService executor = poolHolder.pool;
        return fibonacci.fib(3500, executor).toCompletableFuture().join();
    }

    /** Entry point: runs every benchmark in this class. */
    public static void main(String[] args) throws RunnerException {
        Locale.setDefault(Locale.ENGLISH);
        String pattern = ".*" + FibCachedConcurrentBenchmark.class.getSimpleName() + ".*";
        Options options = new OptionsBuilder()
                .verbosity(VerboseMode.NORMAL)
                .include(pattern)
                .build();
        new Runner(options).run();
    }
}
package java8.concurrent.eventdriven.util;
import java.util.concurrent.ThreadFactory;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
/** Static helpers for configuring thread pools. */
public abstract class ThreadPools {

    /**
     * Creates a thread factory that names its threads from {@code nameFormat} and sets their daemon
     * flag to {@code createDaemons}.
     *
     * @param nameFormat format for thread names (for example {@code "fibonacci-%d"}); presumably the
     *                   placeholder receives a per-factory counter from Guava's builder — confirm
     * @param createDaemons whether the created threads are daemon threads
     * @return a newly built thread factory
     */
    public static ThreadFactory newThreadFactory(String nameFormat, boolean createDaemons) {
        ThreadFactoryBuilder builder = new ThreadFactoryBuilder();
        builder.setNameFormat(nameFormat);
        builder.setDaemon(createDaemons);
        return builder.build();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment