Skip to content

Instantly share code, notes, and snippets.

@janodev
Created May 1, 2014 10:53
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save janodev/591fb77ca39369821aa4 to your computer and use it in GitHub Desktop.
Save janodev/591fb77ca39369821aa4 to your computer and use it in GitHub Desktop.
Factorial with tail-call optimization in Java 8
// Code from “Functional Programming in Java, Chapter 7”.
import java.util.function.Function;
import java.util.stream.Stream;
import java.math.BigInteger;
/**
 * A single step of a computation that is evaluated step by step.
 *
 * <p>A step is either intermediate, in which case {@link #apply()} yields the
 * following step, or terminal, in which case {@link #isComplete()} returns
 * {@code true} and {@link #result()} yields the final value. Because only
 * {@code apply()} is abstract, an intermediate step can be written as a lambda.
 *
 * @param <T> type of the value produced by the terminal step
 */
@FunctionalInterface
interface TailCall<T> {

    /**
     * Performs this step.
     *
     * @return the step that follows this one
     */
    TailCall<T> apply();

    /**
     * Tells whether this step is terminal.
     *
     * @return {@code false} unless overridden
     */
    default boolean isComplete() {
        return false;
    }

    /**
     * Returns the final value held by a terminal step.
     *
     * @return the computed value, when overridden by a terminal step
     * @throws Error always, unless overridden
     */
    default T result() {
        throw new Error("not implemented");
    }

    /**
     * Runs the chain of steps that starts at this one until a terminal step is
     * found, then returns that step's result. The steps are executed one after
     * another in a loop, so each {@code apply()} returns before the next one is
     * called.
     *
     * @return the result of the first terminal step reached
     */
    default T invoke() {
        TailCall<T> step = this;
        while (!step.isComplete()) {
            step = step.apply();
        }
        return step.result();
    }
}
/** Factory for terminal {@link TailCall} steps. */
class TailCalls {

    /**
     * Wraps a finished value in a terminal step.
     *
     * @param value the final value of the computation
     * @param <T>   type of the value
     * @return a step whose {@code isComplete()} is {@code true} and whose
     *         {@code result()} is {@code value}
     */
    public static <T> TailCall<T> done(final T value) {
        return new Finished<>(value);
    }

    /** Terminal step: carries the result and has no successor. */
    private static final class Finished<T> implements TailCall<T> {

        private final T value;

        Finished(final T value) {
            this.value = value;
        }

        /** A terminal step has nothing further to run, so this always throws. */
        @Override
        public TailCall<T> apply() {
            throw new Error("not implemented");
        }

        @Override
        public boolean isComplete() {
            return true;
        }

        @Override
        public T result() {
            return value;
        }
    }
}
/**
 * Computes factorials of arbitrarily large non-negative integers.
 *
 * <p>The recursion is expressed as a chain of {@link TailCall} steps: each
 * step returns the next one instead of calling it, and
 * {@link TailCall#invoke()} drives the chain to its terminal step.
 */
public class BigFactorial {

    final static BigInteger ONE = BigInteger.ONE;
    final static BigInteger TEN = BigInteger.TEN;
    final static BigInteger TWENTYK = new BigInteger("20000");

    /**
     * Returns {@code number - 1}.
     *
     * @param number the value to decrement
     * @return {@code number} minus one
     */
    public static BigInteger decrement(final BigInteger number) {
        return number.subtract(BigInteger.ONE);
    }

    /**
     * Returns {@code first * second}.
     *
     * @param first  the left factor
     * @param second the right factor
     * @return the product of the two factors
     */
    public static BigInteger multiply(final BigInteger first, final BigInteger second) {
        return first.multiply(second);
    }

    /**
     * Builds one step of the factorial computation.
     *
     * @param factorial the product accumulated so far
     * @param number    the next factor still to be multiplied in
     * @return a terminal step holding {@code factorial} once {@code number}
     *         is at most one, otherwise a step that multiplies {@code number}
     *         in and continues with {@code number - 1}
     */
    private static TailCall<BigInteger> factorialTailRec(final BigInteger factorial, final BigInteger number) {
        // Stop at "<= 1" rather than "== 1": counting down from 0 never hits 1,
        // so an equality test would keep producing steps forever for 0.
        if (number.compareTo(BigInteger.ONE) <= 0) {
            return TailCalls.done(factorial);
        }
        return () -> factorialTailRec(multiply(factorial, number), decrement(number));
    }

    /**
     * Computes {@code number!}.
     *
     * @param number a non-negative integer; {@code 0} and {@code 1} both yield 1
     * @return the factorial of {@code number}
     * @throws IllegalArgumentException if {@code number} is negative
     * @throws NullPointerException     if {@code number} is {@code null}
     */
    public static BigInteger factorial(final BigInteger number) {
        if (number.signum() < 0) {
            throw new IllegalArgumentException(
                    "factorial is undefined for negative numbers: " + number);
        }
        return factorialTailRec(BigInteger.ONE, number).invoke();
    }

    /**
     * Demo entry point: prints 10! and the first ten digits of 20000!.
     *
     * @param args ignored
     */
    public static void main(final String[] args) {
        System.out.println(factorial(TEN));
        // The ".10" precision truncates the formatted number to its first ten characters.
        System.out.println(String.format("%.10s...(first ten digits)", factorial(TWENTYK)));
        // Printing factorial(TWENTYK) directly would emit the entire number
        // (tens of thousands of digits).
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment