Skip to content

Instantly share code, notes, and snippets.

@saka1029
Last active February 11, 2021 13:06
Show Gist options
  • Save saka1029/7e576e3d77d65cf9df42b86643bd22bb to your computer and use it in GitHub Desktop.
Save saka1029/7e576e3d77d65cf9df42b86643bd22bb to your computer and use it in GitHub Desktop.
package test.puzzle.lambda;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import org.junit.jupiter.api.Test;
/**
 * Memoization and tracing built on a fixed-point combinator in Java - Qiita
 * (original title: "Javaで不動点コンビネータを活用してメモ化とトレース機能を実現する").
 * https://qiita.com/saka1029/items/877e7e0d518625e47b23
 */
class TestFixedPointCombinator {
    /*
     * Fixed-point combinator - Wikipedia (Japanese edition)
     * https://ja.wikipedia.org/wiki/%E4%B8%8D%E5%8B%95%E7%82%B9%E3%82%B3%E3%83%B3%E3%83%93%E3%83%8D%E3%83%BC%E3%82%BF#Z%E3%82%B3%E3%83%B3%E3%83%93%E3%83%8D%E3%83%BC%E3%82%BF
     * In JavaScript the Z combinator can be implemented as follows:
     *
     * function Z(f) {
     *   return (function(x) {
     *     return f(function(y) {
     *       return x(x)(y);
     *     });
     *   })(function(x) {
     *     return f(function(y) {
     *       return x(x)(y);
     *     });
     *   });
     * }
     * Z(function(f) { return function(n) { return n == 0 ? 1 : n * f(n - 1); }; })(5) == 120
     */

    /**
     * A function whose argument is a function of its own type. This is what lets
     * {@link #Y} write the self-application {@code w.apply(w)} with a static type.
     *
     * @param <F> the type produced by applying the function to itself
     */
    @FunctionalInterface
    interface RecursiveFunction<F> extends Function<RecursiveFunction<F>, F> {
    }

    /**
     * Y combinator in its strict (Z) form, adapted from Rosetta Code.
     * https://rosettacode.org/wiki/Y_combinator#Java_2
     *
     * <p>The inner {@code x -> w.apply(w).apply(x)} delays the self-application until the
     * returned function is actually called. Without it, {@code r.apply(r)} would recurse
     * forever under Java's eager evaluation.
     *
     * @param f a "step" function: given the recursive function, returns one unrolling of it
     * @return the fixed point of {@code f}, i.e. the recursive function itself
     */
    static <A, B> Function<A, B> Y(Function<Function<A, B>, Function<A, B>> f) {
        RecursiveFunction<Function<A, B>> r = w -> f.apply(x -> w.apply(w).apply(x));
        return r.apply(r);
    }

    /** Fibonacci and factorial defined through {@link #Y} (examples from Rosetta Code). */
    @Test
    void 不動点コンビネータのテスト() {
        Function<Integer, Integer> fib = Y(
            f -> n -> (n <= 2)
                ? 1
                : (f.apply(n - 1) + f.apply(n - 2)));
        Function<Integer, Integer> fac = Y(
            f -> n -> (n <= 1)
                ? 1
                : (n * f.apply(n - 1)));
        System.out.println("fib(10) = " + fib.apply(10));
        System.out.println("fac(10) = " + fac.apply(10));
    }

    /**
     * Simplified fixed-point combinator. An anonymous class is used instead of a lambda
     * because inside it {@code this} refers to the function being returned. That reference
     * is what gets passed to {@code f} as the recursive "self".
     *
     * @param f a step function receiving "self" and returning the function body
     * @return the recursive function
     */
    static <T, R> Function<T, R> fixedPointCombinator(Function<Function<T, R>, Function<T, R>> f) {
        return new Function<T, R>() {
            @Override
            public R apply(T t) {
                return f.apply(this).apply(t);
            }
        };
    }

    /** Factorial as a step function; returns 1 for any {@code n <= 0}. */
    static Function<Function<Integer, Integer>, Function<Integer, Integer>> factorial =
        self -> n ->
            n <= 0
                ? 1
                : n * self.apply(n - 1);

    @Test
    void 単純化した不動点コンビネータのテスト() {
        System.out.println("factorial(10) = " + fixedPointCombinator(factorial).apply(10));
    }

    /**
     * Fixed-point combinator that memoizes every argument/result pair, including those
     * produced by recursive calls made through "self".
     *
     * <p>The returned function's {@code toString()} renders the cache contents. The cache is
     * a plain {@link HashMap}, so the returned function is not safe for concurrent use.
     *
     * @param f a step function receiving "self" and returning the function body
     * @return the memoized recursive function
     */
    static <T, R> Function<T, R> memoize(Function<Function<T, R>, Function<T, R>> f) {
        return new Function<T, R>() {
            final Map<T, R> cache = new HashMap<>();

            @Override
            public R apply(T t) {
                // containsKey (rather than a null check on get) so that a legitimately
                // null result is cached as well, instead of being recomputed every call.
                if (cache.containsKey(t)) {
                    return cache.get(t);
                }
                // Deliberately not cache.computeIfAbsent(t, k -> f.apply(this).apply(k)):
                // the recursive calls would modify the HashMap while computeIfAbsent is
                // still running, which throws ConcurrentModificationException.
                R v = f.apply(this).apply(t);
                cache.put(t, v);
                return v;
            }

            @Override
            public String toString() {
                return cache.toString();
            }
        };
    }

    /** Fibonacci as a step function with fibonacci(0) = 0 and fibonacci(1) = 1. */
    static Function<Function<Integer, Integer>, Function<Integer, Integer>> fibonacci =
        self -> n ->
            n == 0 ? 0 :
            n == 1 ? 1 :
            self.apply(n - 1) + self.apply(n - 2);

    @Test
    void メモ化のテスト() {
        Function<Integer, Integer> memoizedFibonacci = memoize(fibonacci);
        System.out.println("fibonacci(10) = " + memoizedFibonacci.apply(10));
        System.out.println(memoizedFibonacci);
    }

    /** Plain recursive Takeuchi (tarai) function, used as the non-memoized baseline. */
    static int tarai(int x, int y, int z) {
        if (x <= y) {
            return y;
        }
        return tarai(tarai(x - 1, y, z),
                     tarai(y - 1, z, x),
                     tarai(z - 1, x, y));
    }

    /** Bundles tarai's three arguments into one key so that {@link #memoize} can cache them. */
    record Args(int x, int y, int z) {}

    /** Takeuchi function as a step function over a single {@link Args} argument. */
    static Function<Function<Args, Integer>, Function<Args, Integer>> tarai =
        self -> a ->
            a.x() <= a.y()
                ? a.y()
                : self.apply(new Args(self.apply(new Args(a.x() - 1, a.y(), a.z())),
                                      self.apply(new Args(a.y() - 1, a.z(), a.x())),
                                      self.apply(new Args(a.z() - 1, a.x(), a.y()))));

    @Test
    void recordによる複数引数のメモ化() {
        Function<Args, Integer> memoizedTarai = memoize(tarai);
        System.out.println("tarai(3, 2, 1) = " + memoizedTarai.apply(new Args(3, 2, 1)));
        System.out.println("キャッシュの中身: " + memoizedTarai);
    }

    /**
     * Runs {@code s} once and appends the elapsed wall time in milliseconds to its result.
     * Uses {@link System#nanoTime()}, which is intended for measuring elapsed time, rather
     * than the wall clock, which can jump when the system time is adjusted.
     *
     * @param s the computation to time
     * @return {@code s}'s result followed by the elapsed time
     */
    static String 時間測定(Supplier<String> s) {
        long start = System.nanoTime();
        String result = s.get();
        long elapsedMillis = (System.nanoTime() - start) / 1_000_000;
        return result + " : 所要時間 " + elapsedMillis + "ms";
    }

    @Test
    void 通常の関数とrecordによる複数引数のメモ化の性能比較() {
        System.out.println(時間測定(() -> "通常の竹内関数 tarai(15, 7, 1) = " + tarai(15, 7, 1)));
        System.out.println(時間測定(() -> "メモ化竹内関数(record) tarai(15, 7, 1) = " + memoize(tarai).apply(new Args(15, 7, 1))));
    }

    /*
     * Memoization of a multi-argument function through currying:
     *   tarai                            -> Function<Integer, Function<Integer, Function<Integer, Integer>>>
     *   tarai.apply(3)                   -> Function<Integer, Function<Integer, Integer>>
     *   tarai.apply(3).apply(2)          -> Function<Integer, Integer>
     *   tarai.apply(3).apply(2).apply(1) -> Integer
     * Each curried level is memoized separately. The outer cache keeps the memoized inner
     * functions, so the inner caches persist across calls.
     */
    @Test
    void カリー化による複数引数のメモ化() {
        // selfy / selfz are unused: only the outermost "self" is needed for recursion.
        Function<Integer, Function<Integer, Function<Integer, Integer>>> curriedTarai =
            memoize(self -> x ->
                memoize(selfy -> y ->
                    memoize(selfz -> z -> x <= y ? y
                        : self.apply(self.apply(x - 1).apply(y).apply(z))
                              .apply(self.apply(y - 1).apply(z).apply(x))
                              .apply(self.apply(z - 1).apply(x).apply(y)))));
        System.out.println("tarai(3, 2, 1) = " + curriedTarai.apply(3).apply(2).apply(1));
        System.out.println("キャッシュの中身: " + curriedTarai);
        System.out.println(時間測定(() -> "メモ化竹内関数(カリー化) tarai(15, 7, 1) = " + curriedTarai.apply(15).apply(7).apply(1)));
    }

    /**
     * Fixed-point combinator that reports every call and its result to {@code output},
     * indented by recursion depth.
     *
     * @param name   the function name printed in the call lines
     * @param output sink for trace lines
     * @param f      a step function receiving "self" and returning the function body
     * @return the traced recursive function
     */
    static <T, R> Function<T, R> trace(String name, Consumer<String> output, Function<Function<T, R>, Function<T, R>> f) {
        return new Function<T, R>() {
            int nest = 0;

            @Override
            public R apply(T t) {
                String indent = " ".repeat(nest);
                output.accept(indent + name + "(" + t + ")");
                R r;
                ++nest;
                try {
                    r = f.apply(this).apply(t);
                } finally {
                    // Restore the depth even if the wrapped function throws, so that
                    // later calls are not printed with a stale indentation.
                    --nest;
                }
                output.accept(indent + r);
                return r;
            }
        };
    }

    @Test
    void トレースのテスト() {
        System.out.println("fibonacci(6) = " + trace("fibonacci", System.out::println, fibonacci).apply(6));
    }

    /**
     * Combination of {@link #memoize} and {@link #trace}. Every call is traced, and results
     * served from the cache are marked with {@code " (cache)"}. {@code toString()} renders
     * the cache contents.
     *
     * @param name   the function name printed in the call lines
     * @param output sink for trace lines
     * @param f      a step function receiving "self" and returning the function body
     * @return the memoized, traced recursive function
     */
    static <T, R> Function<T, R> memoizeTrace(String name, Consumer<String> output, Function<Function<T, R>, Function<T, R>> f) {
        return new Function<T, R>() {
            final Map<T, R> cache = new HashMap<>();
            int nest = 0;

            @Override
            public R apply(T t) {
                String indent = " ".repeat(nest);
                output.accept(indent + name + "(" + t + ")");
                R result;
                String from;
                // containsKey so that cached null results are recognised as hits.
                if (cache.containsKey(t)) {
                    result = cache.get(t);
                    from = " (cache)";
                } else {
                    ++nest;
                    try {
                        result = f.apply(this).apply(t);
                    } finally {
                        --nest;
                    }
                    cache.put(t, result);
                    from = "";
                }
                output.accept(indent + result + from);
                return result;
            }

            @Override
            public String toString() {
                return cache.toString();
            }
        };
    }

    @Test
    void メモ化トレース() {
        System.out.println("トレース fibonacci(6) = " + trace("fibonacci", System.out::println, fibonacci).apply(6));
        System.out.println("メモ化トレース fibonacci(6) = " + memoizeTrace("fibonacci", System.out::println, fibonacci).apply(6));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment