Skip to content

Instantly share code, notes, and snippets.

@saka1029
Last active December 7, 2016 09:11
Show Gist options
  • Save saka1029/8e10eaa3a10d9753064545e3df4a0a9a to your computer and use it in GitHub Desktop.
Save saka1029/8e10eaa3a10d9753064545e3df4a0a9a to your computer and use it in GitHub Desktop.
関数をメモ化する関数。カリー化すると引数が2個以上ある関数も処理できる。最後のテストはキャッシュを外部に取り出す例。
package stackoverflow;
import static org.junit.Assert.*;
import java.math.BigInteger;
import static java.math.BigInteger.*;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.function.Function;
import org.junit.Test;
public class TestMemoize3 {
/**
 * Builds a memoized function from a recursion template.
 *
 * <p>{@code f} receives the memoized function itself ("self") and returns the
 * function body; the body should recurse through that self reference so that
 * recursive calls are served from the cache too.
 *
 * <p>Results are stored in a plain {@link HashMap}, so keys need consistent
 * {@code equals}/{@code hashCode}. A {@code null} result is returned to the
 * caller but is not stored, so it is recomputed on the next call.
 *
 * @param f   template mapping the self reference to the function body
 * @param <T> argument type
 * @param <U> result type
 * @return a function that computes each distinct argument's non-null result at most once
 */
public static <T, U> Function<T, U> memoize(Function<Function<T, U>, Function<T, U>> f) {
    return new Function<T, U>() {
        final Map<T, U> cache = new HashMap<>();
        // The cache field must be initialised before this line: f receives 'this'.
        final Function<T, U> body = f.apply(this);

        @Override
        public U apply(T t) {
            U hit = cache.get(t);
            if (hit != null) {
                return hit;
            }
            // Deliberately not cache.computeIfAbsent(t, body): body re-enters apply()
            // and writes to this same map, which HashMap.computeIfAbsent forbids
            // (ConcurrentModificationException on Java 9+, possible silent
            // corruption of the table on Java 8).
            U computed = body.apply(t);
            if (computed != null) {
                cache.put(t, computed);
            }
            return computed;
        }
    };
}
/**
 * Plain (non-memoized) recursive factorial.
 *
 * @param n the operand; zero and negative values yield one
 * @return {@code n!} for positive {@code n}, otherwise one
 */
static BigInteger fact(BigInteger n) {
    if (n.signum() <= 0) {
        return ONE;
    }
    BigInteger predecessor = n.subtract(ONE);
    return n.multiply(fact(predecessor));
}
/**
 * Factorial built with {@code memoize}: the recursive step goes through the
 * self reference handed to the template rather than calling itself directly.
 * Zero and negative arguments yield one.
 */
static Function<BigInteger, BigInteger> fact =
    memoize(recur -> k -> {
        if (k.signum() <= 0) {
            return ONE;
        }
        return k.multiply(recur.apply(k.subtract(ONE)));
    });
/**
 * Prints 0! through 999! using the memoized {@code fact} function.
 * Arguments are visited in ascending order, so each call's recursive step
 * asks for a value that an earlier iteration already produced.
 */
@Test
public void testFact() {
    for (long n = 0; n < 1000; n++) {
        BigInteger value = fact.apply(BigInteger.valueOf(n));
        System.out.println(value);
    }
    // To print from the plain recursive method instead, call fact(BigInteger.valueOf(n)).
}
/** The constant two, used for the {@code n - 2} recursive step. */
static final BigInteger TWO = BigInteger.valueOf(2);

/**
 * Plain (non-memoized) doubly recursive Fibonacci.
 *
 * <p>Only zero and one are base cases; a negative argument never reaches
 * either of them, so the recursion does not terminate normally for it.
 *
 * @param n index into the sequence (0 -> 0, 1 -> 1)
 * @return the {@code n}-th Fibonacci number
 */
static BigInteger fibonacci(BigInteger n) {
    if (n.equals(ZERO)) {
        return ZERO;
    }
    if (n.equals(ONE)) {
        return ONE;
    }
    BigInteger previous = fibonacci(n.subtract(ONE));
    BigInteger beforePrevious = fibonacci(n.subtract(TWO));
    return previous.add(beforePrevious);
}
/**
 * Fibonacci built with {@code memoize}: both recursive steps go through the
 * self reference handed to the template. Base cases are zero and one.
 */
static Function<BigInteger, BigInteger> fibonacci =
    memoize(recur -> k -> {
        if (k.equals(ZERO)) {
            return ZERO;
        }
        if (k.equals(ONE)) {
            return ONE;
        }
        // Same order as a left-to-right sum: (k - 1) is requested before (k - 2).
        BigInteger previous = recur.apply(k.subtract(ONE));
        BigInteger beforePrevious = recur.apply(k.subtract(TWO));
        return previous.add(beforePrevious);
    });
/**
 * Prints the first 1000 Fibonacci numbers using the memoized
 * {@code fibonacci} function, visiting the indices in ascending order.
 */
@Test
public void testFibonacci() {
    for (long n = 0; n < 1000; n++) {
        BigInteger value = fibonacci.apply(BigInteger.valueOf(n));
        System.out.println(value);
    }
    // To print from the plain recursive method instead, call fibonacci(BigInteger.valueOf(n)).
}
/**
 * Plain (non-memoized) tarai function: returns {@code y} when {@code x <= y},
 * otherwise recurses on three rotated and decremented argument triples and
 * then on their results.
 *
 * @param x first argument
 * @param y second argument
 * @param z third argument
 * @return the value of the tarai recursion for {@code (x, y, z)}
 */
static int tarai(int x, int y, int z) {
    if (x <= y) {
        return y;
    }
    int first = tarai(x - 1, y, z);
    int second = tarai(y - 1, z, x);
    int third = tarai(z - 1, x, y);
    return tarai(first, second, third);
}
/**
 * Curried, memoized tarai: one {@code memoize} layer per argument, so the
 * three-argument function is expressed as nested one-argument functions.
 * Only the outermost self reference is used for recursion; the self
 * references of the two inner layers are unused.
 */
static Function<Integer, Function<Integer, Function<Integer, Integer>>> tarai =
    memoize(self -> x ->
        memoize(unusedSelfY -> y ->
            memoize(unusedSelfZ -> z -> {
                if (x <= y) {
                    return y;
                }
                // Each argument is computed just before the layer that consumes it
                // is applied, keeping the original left-to-right evaluation order.
                Integer first = self.apply(x - 1).apply(y).apply(z);
                Function<Integer, Function<Integer, Integer>> afterFirst = self.apply(first);
                Integer second = self.apply(y - 1).apply(z).apply(x);
                Function<Integer, Integer> afterSecond = afterFirst.apply(second);
                Integer third = self.apply(z - 1).apply(x).apply(y);
                return afterSecond.apply(third);
            })));
/**
 * Checks the curried, memoized {@code tarai} on five argument triples; in
 * every case the expected result equals the first argument.
 */
@Test
public void testTarai() {
    int[][] inputs = {
        {12, 6, 0},
        {13, 7, 0},
        {14, 8, 0},
        {15, 5, 0},
        {20, 10, 0},
    };
    for (int[] in : inputs) {
        int actual = tarai.apply(in[0]).apply(in[1]).apply(in[2]);
        assertEquals(in[0], actual);
    }
}
/**
 * A {@link Function} that additionally exposes the results it has recorded.
 *
 * @param <T> argument type
 * @param <U> result type
 */
interface CachedFunction<T, U> extends Function<T, U> {
    /** Returns the argument-to-result entries recorded so far. */
    Map<T, U> cache();
}

/**
 * Builds a memoized function whose cache can be inspected from outside.
 *
 * <p>{@code f} receives the memoized function itself ("self") and returns the
 * function body; the body should recurse through that self reference so that
 * recursive calls are served from the cache too. A {@code null} result is
 * returned to the caller but is not stored.
 *
 * <p>{@link CachedFunction#cache()} returns an unmodifiable view of the live
 * cache: it cannot be written through, but it reflects later computations.
 *
 * @param f   template mapping the self reference to the function body
 * @param <T> argument type
 * @param <U> result type
 * @return a memoized function with a read-only view of its cache
 */
public static <T, U> CachedFunction<T, U> cached(Function<Function<T, U>, Function<T, U>> f) {
    return new CachedFunction<T, U>() {
        final Map<T, U> cache = new HashMap<>();
        // The cache field must be initialised before this line: f receives 'this'.
        final Function<T, U> body = f.apply(this);

        @Override
        public U apply(T t) {
            U hit = cache.get(t);
            if (hit != null) {
                return hit;
            }
            // Deliberately not cache.computeIfAbsent(t, body): body re-enters apply()
            // and writes to this same map, which HashMap.computeIfAbsent forbids
            // (ConcurrentModificationException on Java 9+, possible silent
            // corruption of the table on Java 8).
            U computed = body.apply(t);
            if (computed != null) {
                cache.put(t, computed);
            }
            return computed;
        }

        @Override
        public Map<T, U> cache() {
            return Collections.unmodifiableMap(cache);
        }
    };
}
/**
 * Curried tarai built from three nested {@code cached} layers, so the cache
 * of every layer can be read back afterwards. Only the outermost self
 * reference is used for recursion; the self references of the two inner
 * layers are unused.
 */
static CachedFunction<Integer, CachedFunction<Integer, CachedFunction<Integer, Integer>>> cachedTarai =
    cached(self -> x ->
        cached(unusedSelfY -> y ->
            cached(unusedSelfZ -> z -> {
                if (x <= y) {
                    return y;
                }
                // Each argument is computed just before the layer that consumes it
                // is applied, keeping the original left-to-right evaluation order.
                Integer first = self.apply(x - 1).apply(y).apply(z);
                CachedFunction<Integer, CachedFunction<Integer, Integer>> afterFirst = self.apply(first);
                Integer second = self.apply(y - 1).apply(z).apply(x);
                CachedFunction<Integer, Integer> afterSecond = afterFirst.apply(second);
                Integer third = self.apply(z - 1).apply(x).apply(y);
                return afterSecond.apply(third);
            })));
/**
 * Checks {@code cachedTarai} on five argument triples (the expected result
 * equals the first argument each time), then prints the contents of the
 * three nested caches, one level per argument.
 */
@Test
public void testCachedTarai() {
    int[][] inputs = {
        {12, 6, 0},
        {13, 7, 0},
        {14, 8, 0},
        {15, 5, 0},
        {20, 10, 0},
    };
    for (int[] in : inputs) {
        int actual = cachedTarai.apply(in[0]).apply(in[1]).apply(in[2]);
        assertEquals(in[0], actual);
    }

    Map<Integer, CachedFunction<Integer, CachedFunction<Integer, Integer>>> byX = cachedTarai.cache();
    for (Entry<Integer, CachedFunction<Integer, CachedFunction<Integer, Integer>>> xEntry : byX.entrySet()) {
        System.out.printf("x = %s%n", xEntry.getKey());
        Map<Integer, CachedFunction<Integer, Integer>> byY = xEntry.getValue().cache();
        for (Entry<Integer, CachedFunction<Integer, Integer>> yEntry : byY.entrySet()) {
            System.out.printf(" y = %s%n", yEntry.getKey());
            Map<Integer, Integer> byZ = yEntry.getValue().cache();
            for (Entry<Integer, Integer> zEntry : byZ.entrySet()) {
                System.out.printf(" z = %s -> %s%n", zEntry.getKey(), zEntry.getValue());
            }
        }
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment