Skip to content

Instantly share code, notes, and snippets.

@Saizan
Created March 20, 2021 10:22
Show Gist options
  • Save Saizan/04087ed2edfdfb9f917e46c82f303ec0 to your computer and use it in GitHub Desktop.
F# trampolining monad
module tramp
/// An untyped tree describing a monadic computation.
/// F# has no convenient existential types, so the intermediate result of a
/// bind is erased to `obj`; the typed `Eval<'a>` wrapper restores the types.
type Tree =
    /// Evaluate the left tree, then feed its (boxed) result to the continuation.
    | Bind of Tree * (obj -> Tree)
    /// A suspended tree; the thunk is only called when an evaluator reaches it.
    | Delay of (unit -> Tree)
    /// A finished computation carrying its boxed result.
    | Leaf of obj
/// An explicit stack of pending continuations, consumed by `eval`.
/// `End k` is the bottom of the stack: `k` converts the final boxed result.
/// `Cons (k, rest)` applies `k` to the current result to get the next tree,
/// and then continues with `rest`.
type FnStack<'a, 'b> =
    | End of ('a -> 'b)
    | Cons of ('a -> Tree) * FnStack<obj, 'b>
/// Evaluates a tree in continuation-passing style; `k` receives the boxed
/// result. Every recursive call, and the final call to `k`, is in tail
/// position, but they go through closures. Stack safety therefore depends on
/// the runtime honouring tail calls (`--tailcalls+`), which Debug
/// configurations commonly turn off.
let rec eval' : Tree -> (obj -> 'b) -> 'b =
    fun tree k ->
        match tree with
        | Delay thunk -> eval' (thunk ()) k
        | Bind (inner, next) -> eval' inner (fun v -> eval' (next v) k)
        | Leaf v -> k v
/// Evaluates a tree using an explicit stack of continuations instead of
/// nested closures. Its only recursive calls are direct, fully applied self
/// tail calls, which the F# compiler normally compiles into a loop. So unlike
/// `eval'`, it does not rely on runtime tail-call support.
/// The continuations on the stack (`next`) only build trees; they do not
/// evaluate them.
let rec eval : Tree -> FnStack<obj, 'b> -> 'b =
    fun tree stack ->
        match tree, stack with
        | Leaf v, End finish -> finish v
        | Leaf v, Cons (next, rest) -> eval (next v) rest
        | Bind (inner, next), _ -> eval inner (Cons (next, stack))
        | Delay thunk, _ -> eval (thunk ()) stack
/// A typed handle on a `Tree` whose eventual result has type `'a`.
/// `'a` is a phantom parameter (the case holds only a `Tree`). It provides
/// type safety only while callers do not build or unwrap `E` values
/// themselves. The constructor is public, so that is a convention rather than
/// something the compiler enforces.
type Eval<'a> =
    | E of Tree
/// Unwraps the underlying untyped tree. This drops the `'a` index, so it is
/// an escape hatch from the type discipline of `Eval`.
let unEval (e : Eval<'a>) : Tree =
    match e with
    | E tree -> tree
/// Lifts a plain value into `Eval` as an already-finished computation.
let ret (a : 'a) : Eval<'a> =
    let boxed = a :> obj
    E (Leaf boxed)
/// Monadic bind. `f` is not called here: it is stored as the continuation of
/// a `Bind` node and only runs when the tree is evaluated.
/// For left-recursive definitions, the left operand should come from `delay`.
/// Otherwise the recursive call producing it runs eagerly, before `>>=` is
/// even applied.
let (>>=) (E t : Eval<'a>) (f : 'a -> Eval<'b>) : Eval<'b> =
    let continueWith (x : obj) : Tree =
        let value = x :?> 'a
        unEval (f value)
    E (Bind (t, continueWith))
/// Suspends construction of a computation. `m` is only called when an
/// evaluator reaches the resulting `Delay` node, so recursive calls inside
/// `m` are not made while the surrounding tree is being built.
let delay (m : unit -> Eval<'a>) : Eval<'a> =
    let force () = m () |> unEval
    force |> Delay |> E
/// Evaluates a computation and returns its result, cast back to `'a`.
/// Uses `eval` (explicit continuation stack) rather than `eval'`. Every
/// recursive call in `eval` is a self tail call, which the F# compiler turns
/// into a loop, so `run` stays stack-safe even when tail calls are disabled
/// (e.g. `--tailcalls-`). `eval'` chains closures and needs runtime tail calls
/// to avoid deep stacks.
let run (E t : Eval<'a>) : 'a =
    eval t (End (fun x -> downcast x))
/// Computation-expression builder for `Eval`, enabling `tramp { ... }`.
/// Because this builder defines `Delay` and `Run`, the F# compiler translates
/// `tramp { body }` into `tramp.Run (tramp.Delay (fun () -> body))`.
/// Recursive calls inside a block therefore yield a suspended `Delay` node
/// instead of recursing while the tree is built.
type TrampolineBuilder() =
    member this.Bind(m : Eval<'a>, f : 'a -> Eval<'b>) : Eval<'b> = m >>= f
    member this.Delay(f : unit -> Eval<'a>) : Eval<'a> = delay f
    // Identity: `Delay` already produced the final `Eval` value.
    member this.Run m = m
    member this.Return(x : 'a) : Eval<'a> = ret x
    /// Enables `return! m`. It returns another computation directly, without
    /// the extra `Bind` node that `let! x = m in return x` would add.
    member this.ReturnFrom(m : Eval<'a>) : Eval<'a> = m
    /// Result of an `if` without `else` (or another empty branch) in a block.
    member this.Zero() : Eval<unit> = ret ()

/// The builder instance used as `tramp { ... }`.
let tramp = TrampolineBuilder()
// ---- Tests ----

/// Left-recursive loop built with `>>=` but without `delay`. The recursive
/// call `leftLoop0 (n - 1)` is an ordinary argument, so it is evaluated before
/// `>>=` is applied, and building the tree itself recurses `n` levels deep.
let rec leftLoop0 n =
    if n > 0 then
        leftLoop0 (n - 1) >>= fun m -> ret (m + 1)
    else
        ret 0

// Still overflows the stack, while building the tree rather than while running it.
let test0 () = run (leftLoop0 1000000)
/// The same loop, but the recursive left operand is wrapped in `delay`. Each
/// step only builds a `Delay` node, and the recursion happens during evaluation.
let rec leftLoop n =
    if n > 0 then
        let rest = delay (fun () -> leftLoop (n - 1))
        rest >>= fun m -> ret (m + 1)
    else
        ret 0

// Runs without a stack overflow.
let test () = run (leftLoop 1000000)
/// The same loop written as a computation expression. The builder defines
/// `Delay`, so the whole body of each `tramp { ... }` block is wrapped in a
/// `Delay` node. `leftLoop1 (n - 1)` therefore returns a suspended tree instead
/// of recursing immediately, which matches the effect of wrapping the call in
/// `delay` by hand.
let rec leftLoop1 n =
    tramp {
        if n > 0 then
            let! m = leftLoop1 (n - 1)
            return m + 1
        else
            return 0
    }

// Also runs without a stack overflow.
let test1 () = run (leftLoop1 1000000)
/// Plain (non-monadic) left recursion, for comparison. `m + 1` is computed
/// after the recursive call returns, so that call is not in tail position.
let rec leftLoopBad n =
    if n > 0 then
        let m = leftLoopBad (n - 1)
        m + 1
    else
        0

// Overflows the stack for large n.
let test2 () = leftLoopBad 1000000
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment