@Stupremee
Created March 30, 2019 20:58
A "talk" from GrandPanda

@Dannie well yes but actually no. The problem is still that when the body of a lambda is executed, a stack frame is allocated. The point of trampolining is to avoid that. Essentially, you have some run method on Trampoline which is iterative or tail recursive, and by suspending all of your functions in that Trampoline type (rather than executing them in place), you can make the execution of the original function take only one stack frame. It's all very roundabout. Something I didn't really think about is that you can write run by hand to be iterative, without relying on tail call optimization. That means it's possible in Java as well @tterrag

yeah here we go

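import java.util.function.Supplier;

// The `static` classes below imply that all of this is nested in one enclosing class.
interface Trampoline<T> {
    // Walk down the chain of More steps in a loop: one stack frame, however deep the chain.
    default T run() {
        Trampoline<T> trampoline = this;
        while (!(trampoline instanceof Done)) {
            trampoline = ((More<T>) trampoline).f.get();
        }
        return ((Done<T>) trampoline).result;
    }
}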
static class Done<T> implements Trampoline<T> {
    final T result;

    Done(T result) {
        this.result = result;
    }
}
static class More<T> implements Trampoline<T> {
    final Supplier<Trampoline<T>> f;

    More(Supplier<Trampoline<T>> f) {
        this.f = f;
    }
}

Mostly boilerplate: we have a discriminated union Trampoline with the cases Done and More. run() is the interesting part. We start with this and keep walking down until we reach a value. Here's a bit of a contrived example, but it shows well how it works:

static boolean even(int n) {
    if (n == 0) return true;
    return odd(n - 1);
}

static boolean odd(int n) {
    if (n == 0) return false;
    return even(n - 1);
}

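A bad way to determine whether a number is even or odd, but pretty straightforward. These functions are mutually recursive, so even with tail call optimization that only handles self-recursion (like Scala's @tailrec), we couldn't directly unroll them into a loop. And Java doesn't eliminate tail calls at all.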

but the important part is this:

odd(100_000)
Exception in thread "main" java.lang.StackOverflowError
    ...

Asking for 100000 stack frames isn't reasonable. But asking to allocate 100000 objects on the heap? Java's pretty good at that. So we transform them to use a trampoline:

static boolean even(int n) {
    return evenTramp(n).run();
}

static boolean odd(int n) {
    return oddTramp(n).run();
}

private static Trampoline<Boolean> evenTramp(int n) {
    if (n == 0) return new Done<>(true);
    return new More<>(() -> oddTramp(n - 1));
}

private static Trampoline<Boolean> oddTramp(int n) {
    if (n == 0) return new Done<>(false);
    return new More<>(() -> evenTramp(n - 1));
}

We introduced two helper functions so we can keep the original signatures of even and odd. Those two helpers are the trampolined functions. In both, there isn't any recursion going on: evenTramp and oddTramp never call each other directly. Both of them immediately return either Done or More.

Before, calling even gave us a call stack like even(odd(even(odd(... 100000 frames deep. Now, calling evenTramp just gets us a single More instance holding a function that describes how to get the next step. That function can be executed later (execution is "deferred" or "suspended"). In fact, it's executed in run(), which clearly only requires one stack frame since it's just a loop. So the stack is saved!

We can build on this technique to create a stack safe flatMap for any monad. But... that gets pretty complicated, because the "best" way to do it is to realize that Trampoline is actually a Free monad over () => A, and any trampolined monad is really just that monad transformed with Trampoline. Anyway, the point is that Mono does none of this, so it's not stack safe 😄
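To see the deferral concretely, here is a sketch that combines the Trampoline types and the trampolined even/odd into one compilable file. The wrapper class EvenOddTrampoline and the main method are made up for illustration. It checks that evenTramp returns a More right away and that run() then gets through 100000 steps without a StackOverflowError.

```java
import java.util.function.Supplier;

class EvenOddTrampoline {
    interface Trampoline<T> {
        default T run() {
            Trampoline<T> trampoline = this;
            // Unwrap one More per iteration; the Java call stack never grows.
            while (!(trampoline instanceof Done)) {
                trampoline = ((More<T>) trampoline).f.get();
            }
            return ((Done<T>) trampoline).result;
        }
    }

    static class Done<T> implements Trampoline<T> {
        final T result;
        Done(T result) { this.result = result; }
    }

    static class More<T> implements Trampoline<T> {
        final Supplier<Trampoline<T>> f;
        More(Supplier<Trampoline<T>> f) { this.f = f; }
    }

    // Neither helper calls the other directly: each returns a Done, or a More
    // whose supplier describes the next step.
    static Trampoline<Boolean> evenTramp(int n) {
        if (n == 0) return new Done<>(true);
        return new More<>(() -> oddTramp(n - 1));
    }

    static Trampoline<Boolean> oddTramp(int n) {
        if (n == 0) return new Done<>(false);
        return new More<>(() -> evenTramp(n - 1));
    }

    static boolean even(int n) { return evenTramp(n).run(); }
    static boolean odd(int n) { return oddTramp(n).run(); }

    public static void main(String[] args) {
        // Building the first step allocates a single More and returns immediately.
        System.out.println(evenTramp(100_000) instanceof More); // true
        // run() then performs all 100000 steps inside one loop.
        System.out.println(odd(100_000));  // false
        System.out.println(even(100_000)); // true
    }
}
```

The cost has moved from the stack to the heap: each step allocates one More (plus its lambda), which becomes garbage as soon as the loop moves on.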

I wasn't satisfied with that bit about flatMap and took a dive (http://blog.higher-order.com/assets/trampolines.pdf)...

We introduce a new case in Trampoline which captures what flatMap does in an object:

// A suspended flatMap: evaluate `subroutine` first, then pass its result to `continuation`.
static class FlatMap<T, R> implements Trampoline<R> {
    final Trampoline<T> subroutine;
    final Function<T, Trampoline<R>> continuation;

    FlatMap(Trampoline<T> subroutine, Function<T, Trampoline<R>> continuation) {
        this.subroutine = subroutine;
        this.continuation = continuation;
    }
}

The naming comes from the relationship with continuation passing style (it's more related to coroutines than I originally understood, @dannie). The result of the deferred subroutine is passed to the continuation. Notice that this is exactly how flatMap works. That means Trampoline#flatMap() is simply:

default <R> Trampoline<R> flatMap(Function<T, Trampoline<R>> f) {
    return new FlatMap<>(this, f);
}

However, Trampoline#run() becomes much more complicated. We essentially implement our own system of stack frames. But remember that this is all still happening on the heap, and that's all we care about.

// Needs java.util.ArrayDeque, java.util.Deque and java.util.function.Function.
@SuppressWarnings("unchecked")
default T run() {
    Trampoline<?> trampoline = this;
    // Our own call stack of pending continuations, kept on the heap.
    Deque<Function<Object, Trampoline<?>>> stack = new ArrayDeque<>();

    // Exit only when a Done meets an empty stack, so a null result is handled too.
    while (true) {
        if (trampoline instanceof Done) {
            Object value = ((Done<?>) trampoline).result;
            if (stack.isEmpty()) {
                return (T) value;
            }
            trampoline = stack.pop().apply(value);
        } else if (trampoline instanceof More) {
            trampoline = ((More<?>) trampoline).f.get();
        } else if (trampoline instanceof FlatMap) {
            FlatMap<?, ?> node = (FlatMap<?, ?>) trampoline;
            stack.push((Function<Object, Trampoline<?>>) (Function<?, ?>) node.continuation);
            trampoline = node.subroutine;
        }
    }
}

We can think of this function as an interpreter that walks a binary tree of Trampoline objects depth first. Done is a leaf, More is a node with exactly one direct child, and FlatMap is a node with exactly two direct children. The stack stores the continuations encountered so far.

When a Done is encountered and continuations remain on the stack, we pass its result to the most recent one. If the stack is empty, the result we have is the final result of the function. When a More is encountered, we simply continue down its branch. When a FlatMap is encountered, we traverse its left branch (the subroutine), and when that branch eventually terminates we pass its result to the continuation, which is the right branch (exactly the Done behavior just described).

With all of that in place, we can implement a stack safe version of flatMap for any monad by deferring to Trampoline#flatMap(). Do note that I previously misunderstood and thought we could make the existing flatMap stack safe; I don't think this is the case. Let's do it for a standard definition of Option:
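Here is a self-contained sketch of the three cases plus the interpreter, exercised with a hypothetical sumTo function that is deliberately not tail recursive: the "+ n" can only happen after the recursive result is known, so it becomes a continuation on the heap stack. The class name FlatMapInterpreter, sumTo, and the use of ArrayDeque instead of java.util.Stack are choices made for this sketch, not part of the original.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.function.Function;
import java.util.function.Supplier;

class FlatMapInterpreter {
    interface Trampoline<T> {
        default <R> Trampoline<R> flatMap(Function<T, Trampoline<R>> f) {
            return new FlatMap<>(this, f);
        }

        @SuppressWarnings("unchecked")
        default T run() {
            Trampoline<?> current = this;
            // Continuations still waiting for a value: a call stack living on the heap.
            Deque<Function<Object, Trampoline<?>>> stack = new ArrayDeque<>();
            while (true) {
                if (current instanceof Done) {
                    Object value = ((Done<?>) current).result;
                    if (stack.isEmpty()) {
                        return (T) value; // leaf with nothing pending: the final result
                    }
                    current = stack.pop().apply(value); // feed the leaf's value to the latest right branch
                } else if (current instanceof More) {
                    current = ((More<?>) current).f.get(); // one child: keep walking down
                } else {
                    FlatMap<?, ?> node = (FlatMap<?, ?>) current;
                    // Two children: remember the right branch, descend into the left one first.
                    stack.push((Function<Object, Trampoline<?>>) (Function<?, ?>) node.continuation);
                    current = node.subroutine;
                }
            }
        }
    }

    static class Done<T> implements Trampoline<T> {
        final T result;
        Done(T result) { this.result = result; }
    }

    static class More<T> implements Trampoline<T> {
        final Supplier<Trampoline<T>> f;
        More(Supplier<Trampoline<T>> f) { this.f = f; }
    }

    static class FlatMap<T, R> implements Trampoline<R> {
        final Trampoline<T> subroutine;
        final Function<T, Trampoline<R>> continuation;
        FlatMap(Trampoline<T> subroutine, Function<T, Trampoline<R>> continuation) {
            this.subroutine = subroutine;
            this.continuation = continuation;
        }
    }

    // Hypothetical example: 1 + 2 + ... + n, written NOT tail recursively.
    static Trampoline<Long> sumTo(long n) {
        if (n == 0) return new Done<>(0L);
        Trampoline<Long> rest = new More<>(() -> sumTo(n - 1)); // suspended recursive call
        return rest.flatMap(s -> new Done<>(s + n));             // "+ n" runs after rest finishes
    }

    public static void main(String[] args) {
        System.out.println(sumTo(100_000).run()); // 5000050000
    }
}
```

Here the continuation stack really does grow to 100000 entries before the first Done is reached, which is exactly the work the JVM stack would otherwise have to hold.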

interface Option<T> {
    static <T> Option<T> pure(T value) {
        return new Some<>(value);
    }

    static <T> Option<T> none() {
        return (Option<T>) None.INSTANCE;
    }

    default <R> Option<R> flatMap(Function<T, Option<R>> f) {
        if (this instanceof Some) {
            return f.apply(((Some<T>) this).value);
        }
        return none();
    }
}
static class Some<T> implements Option<T> {
    final T value;

    private Some(T value) {
        this.value = value;
    }
}
static class None<T> implements Option<T> {
    static final None<?> INSTANCE = new None<>();
    private None() {}
}

Again, mostly boilerplate. Notice that Option#flatMap() is exactly what you would expect, and it is not stack safe. We can demonstrate that with a silly example:

static Option<Integer> foo(Option<Integer> opt) {
    return opt.flatMap(x -> {
        if (x == 0) return Option.pure(0);
        return foo(Option.pure(x - 1));
    });
}

foo(Option.pure(100_000))
Exception in thread "main" java.lang.StackOverflowError
    ...

We introduce a different version which is stack safe:

default <R> Trampoline<Option<R>> trampFlatMap(Function<T, Trampoline<Option<R>>> f) {
    return new Done<>(this).flatMap(opt -> {
        if (opt instanceof Some) return f.apply(((Some<T>) opt).value);
        return new Done<>(none());
    });
}

We can visualize this as a node in the tree whose left branch terminates immediately (it just results in this, the Option) and whose right branch is either none() or a branch determined arbitrarily by the function f. This hands control over to the caller of trampFlatMap. It's also important that all this function can do is build that tree; it doesn't make any recursive calls itself.

We can now implement a stack safe version of the silly example:

static Trampoline<Option<Integer>> fooTramp(Trampoline<Option<Integer>> trampoline) {
    return trampoline.flatMap(opt -> {
        return opt.trampFlatMap(x -> {
            if (x == 0) return new Done<>(Option.pure(0));
            return fooTramp(new Done<>(Option.pure(x - 1)));
        });
    });
}

static Option<Integer> foo(Option<Integer> opt) {
    return fooTramp(new Done<>(opt)).run();
}

The function we passed to opt.trampFlatMap only returns Done or FlatMap (FlatMap indirectly: calling fooTramp calls Trampoline#flatMap(), which produces a FlatMap), so we can visualize the tree we're building like so:

[Graph: the tree built by fooTramp]

Our interpreter will walk this tree until it reaches the final leaf at 0.
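Putting the pieces together, here is a sketch that compiles as a single file: the heap-stack interpreter, Option with trampFlatMap, and fooTramp, all nested in a made-up StackSafeOption class. Its run() returns as soon as a Done meets an empty stack, and Some's constructor is package-private here rather than private. Both are small liberties taken for the sketch.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.function.Function;
import java.util.function.Supplier;

class StackSafeOption {
    interface Trampoline<T> {
        default <R> Trampoline<R> flatMap(Function<T, Trampoline<R>> f) {
            return new FlatMap<>(this, f);
        }

        @SuppressWarnings("unchecked")
        default T run() {
            Trampoline<?> current = this;
            Deque<Function<Object, Trampoline<?>>> stack = new ArrayDeque<>(); // heap-allocated "call stack"
            while (true) {
                if (current instanceof Done) {
                    Object value = ((Done<?>) current).result;
                    if (stack.isEmpty()) return (T) value;
                    current = stack.pop().apply(value);
                } else if (current instanceof More) {
                    current = ((More<?>) current).f.get();
                } else {
                    FlatMap<?, ?> node = (FlatMap<?, ?>) current;
                    stack.push((Function<Object, Trampoline<?>>) (Function<?, ?>) node.continuation);
                    current = node.subroutine;
                }
            }
        }
    }

    static class Done<T> implements Trampoline<T> {
        final T result;
        Done(T result) { this.result = result; }
    }

    static class More<T> implements Trampoline<T> {
        final Supplier<Trampoline<T>> f;
        More(Supplier<Trampoline<T>> f) { this.f = f; }
    }

    static class FlatMap<T, R> implements Trampoline<R> {
        final Trampoline<T> subroutine;
        final Function<T, Trampoline<R>> continuation;
        FlatMap(Trampoline<T> subroutine, Function<T, Trampoline<R>> continuation) {
            this.subroutine = subroutine;
            this.continuation = continuation;
        }
    }

    interface Option<T> {
        static <T> Option<T> pure(T value) { return new Some<>(value); }

        @SuppressWarnings("unchecked")
        static <T> Option<T> none() { return (Option<T>) None.INSTANCE; }

        // Only builds a FlatMap node; f is not called until run() reaches it.
        default <R> Trampoline<Option<R>> trampFlatMap(Function<T, Trampoline<Option<R>>> f) {
            return new Done<>(this).flatMap(opt -> {
                if (opt instanceof Some) return f.apply(((Some<T>) opt).value);
                return new Done<>(none());
            });
        }
    }

    static class Some<T> implements Option<T> {
        final T value;
        Some(T value) { this.value = value; }
    }

    static class None<T> implements Option<T> {
        static final None<?> INSTANCE = new None<>();
    }

    static Trampoline<Option<Integer>> fooTramp(Trampoline<Option<Integer>> trampoline) {
        return trampoline.flatMap(opt -> opt.trampFlatMap(x -> {
            if (x == 0) return new Done<>(Option.pure(0));
            return fooTramp(new Done<>(Option.pure(x - 1)));
        }));
    }

    static Option<Integer> foo(Option<Integer> opt) {
        return fooTramp(new Done<>(opt)).run();
    }

    public static void main(String[] args) {
        Option<Integer> result = foo(Option.pure(100_000));
        System.out.println(((Some<Integer>) result).value);      // 0
        System.out.println(foo(Option.none()) instanceof None); // true
    }
}
```

Unlike the non-tail-recursive case, the continuation stack here never holds more than one entry: each continuation is popped before the next FlatMap is pushed, so the walk is effectively a loop.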

And as a final note, as I mentioned earlier, we actually reimplemented some stuff we could've gotten for free in a language like Scala (with a library such as Cats). The whole thing (Trampoline) is actually a specialization of the Free monad, and the trampFlatMap we implemented is really the flatMap of the Option monad transformer over Trampoline. In other words,

type Trampoline[A] = Free[Function0, A]

and we get a stack safe Option for free just with OptionT[Trampoline, A]

And only now do we get to the point: Mono does none of this, so it's not stack safe 😄 😄 :D Thank you for coming to my TED talk
