@VictorTaelin
Last active December 1, 2024 22:32

Optimal Linear Context Passing

In functional algorithms, it is common to pass a context down recursive calls. In Haskell-like languages, this is typically written as:

foo (App fun arg) ctx = App (foo fun ctx) (foo arg ctx)
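For concreteness, here is a minimal runnable sketch of that pattern. The `Term` type, its `Var` and `Lit` constructors, and the use of an association list as the context are illustrative assumptions; only the `App` clause comes from the line above. Here `foo` substitutes variables from a read-only environment, and `ctx` is duplicated at every `App` node.

```haskell
-- Hypothetical term type; only the App constructor comes from the article.
data Term = Var String | Lit Int | App Term Term
  deriving (Show, Eq)

-- Substitute variables using a read-only environment.
-- In the App clause, 'ctx' is used twice: this is the non-linear
-- duplication discussed in the text.
foo :: Term -> [(String, Int)] -> Term
foo (Var name) ctx = maybe (Var name) Lit (lookup name ctx)
foo (Lit n) _ = Lit n
foo (App fun arg) ctx = App (foo fun ctx) (foo arg ctx)
```

For example, `foo (App (Var "x") (Var "y")) [("x", 1)]` yields `App (Lit 1) (Var "y")`.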

This inoffensive-looking line can have harsh performance implications, especially in situations where linearity is desirable, since the 'ctx' variable is used twice. Fortunately, there is a workaround: pass the context "monadically".

foo (App fun arg) ctx0 = case foo fun ctx0 of
  (ctx1, fun') -> case foo arg ctx1 of
    (ctx2, arg') -> (ctx2, App fun' arg')
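A complete version of this style also needs a base case, which the snippet above leaves out. The sketch below assumes a hypothetical task modeled on the `list_to_tree` benchmark in the HVM3 listing: build a perfect tree of depth `n` whose leaves consume the context list from left to right. The `Tree` type and the names `listToTreeSeq` and `leaves` are illustrative, not taken from the original code.

```haskell
-- Hypothetical tree type: leaves hold one element taken from the context.
data Tree = Leaf Int | Node Tree Tree
  deriving (Show, Eq)

-- Thread the context linearly, destructing each result with 'case'.
-- Returns the remaining context and the built tree.
listToTreeSeq :: Int -> [Int] -> ([Int], Tree)
listToTreeSeq 0 ctx0 = case ctx0 of
  (x : rest) -> (rest, Leaf x)
  []         -> error "context exhausted"
listToTreeSeq n ctx0 = case listToTreeSeq (n - 1) ctx0 of
  (ctx1, lft) -> case listToTreeSeq (n - 1) ctx1 of
    (ctx2, rgt) -> (ctx2, Node lft rgt)

-- Collect the leaves from left to right (handy for checking results).
leaves :: Tree -> [Int]
leaves (Leaf x) = [x]
leaves (Node l r) = leaves l ++ leaves r
```

Because each `case` on a pair forces its scrutinee, the outer pair only exists after both recursive calls have returned, which is the sequential chaining discussed next.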

This successfully linearizes the context, resulting in major potential speedups on linear/optimal evaluators. Sadly, this implementation loses the laziness we expect from Haskell, since we accidentally sequentialize the entire call. Thus:

main = print (isApp (snd (foo (App a b) ctx)))

Should immediately return "True" in O(1), but doesn't. Instead, it is O(n) in the size of the term we pass to foo: each nested case forces its recursive call before the enclosing pair can be built, so Haskell has to "touch" all the leaves before emitting the head constructor. To solve that, it suffices to destruct in parallel instead:

foo (App fun arg) ctx0 =
  let (ctx1, fun') = foo fun ctx0 in
  let (ctx2, arg') = foo arg ctx1 in
  (ctx2, App fun' arg')
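A runnable sketch of the `let` version, under the same illustrative assumptions as before (a hypothetical `Tree` type whose leaves consume the context list; the name `listToTreeLazy` is also made up). Pattern bindings in a Haskell `let` are lazy, so the root `Node` is returned before either recursive call runs.

```haskell
-- Hypothetical tree type: leaves hold one element taken from the context.
data Tree = Leaf Int | Node Tree Tree
  deriving (Show, Eq)

-- Same recursion, but the results are destructed with lazy 'let' bindings,
-- so the outer pair and the Node are built without forcing the children.
listToTreeLazy :: Int -> [Int] -> ([Int], Tree)
listToTreeLazy 0 ctx0 = case ctx0 of
  (x : rest) -> (rest, Leaf x)
  []         -> error "context exhausted"
listToTreeLazy n ctx0 =
  let (ctx1, lft) = listToTreeLazy (n - 1) ctx0 in
  let (ctx2, rgt) = listToTreeLazy (n - 1) ctx1 in
  (ctx2, Node lft rgt)

isNode :: Tree -> Bool
isNode (Node _ _) = True
isNode (Leaf _) = False
```

For example, `isNode (snd (listToTreeLazy 20 undefined))` evaluates to `True` without ever inspecting the context, whereas a version built from nested `case` expressions would have to reach a leaf first and crash on the `undefined`.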

This will let main normalize in O(1), as expected. Sadly, we can't always write it like that. For example, if we use an actual monad with do-notation, the sequentialization is inevitable. That is the case if the state isn't carried by a simple pair but by, for example, the Maybe monad, IO, some stateful computation, and so on. In these cases, we can't avoid this complexity blowup in Haskell.
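As an illustration of the monadic case, here is a sketch that threads the context through the `Maybe` monad with do-notation, reporting an exhausted context as `Nothing` instead of raising an error. This variant, the `Tree` type and the helper `takeHead` are hypothetical. Since `Just` can only be produced once every leaf has succeeded, even asking whether the root is a `Node` forces the whole traversal.

```haskell
-- Hypothetical tree type: leaves hold one element taken from the context.
data Tree = Leaf Int | Node Tree Tree
  deriving (Show, Eq)

-- Take one element of the context, failing with Nothing if it is empty.
takeHead :: [Int] -> Maybe ([Int], Int)
takeHead (x : rest) = Just (rest, x)
takeHead []         = Nothing

-- Monadic context passing with do-notation in the Maybe monad.
listToTreeM :: Int -> [Int] -> Maybe ([Int], Tree)
listToTreeM 0 ctx0 = do
  (ctx1, x) <- takeHead ctx0
  pure (ctx1, Leaf x)
listToTreeM n ctx0 = do
  (ctx1, lft) <- listToTreeM (n - 1) ctx0
  (ctx2, rgt) <- listToTreeM (n - 1) ctx1
  pure (ctx2, Node lft rgt)

isNode :: Tree -> Bool
isNode (Node _ _) = True
isNode (Leaf _) = False
```

For instance, `fmap (isNode . snd) (listToTreeM 1 [1])` is `Nothing`: the root would be a `Node`, but the answer cannot be given until the missing second leaf has been discovered.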

General approach: "pure mutable references"

For a more general option, HVM has a clever trick that allows us to pass a context around linearly, even in a hairy monadic context. First, we observe that we can exploit globally-scoped lambdas to emulate mutable references, in a way that preserves purity (!?). To quote T6 on HOC's Discord:

you can represent a mutable reference by @$final_value initial_value

// ref.mut: &T -> (T -> T) -> ()
ref.mut ref fn = let $new = (fn (ref $new));

// ref.split: &T -> (&T, &T)
ref.split ref = let x = (ref $z); (@$y x, @$z $y)

// ref.drop: &T -> ()
ref.drop ref = let $x = (ref $x); *

basically, with the scopeless lambda implementation of mutable references, you pass the mutable reference the value you want to write to it, and it returns the old value. It's like its own std::mem::replace function. So then:

  • to implement ref.drop, we can write let $x = (ref $x), which sets the reference to its current value, getting rid of the mutable reference
  • to implement ref.mut, we can write:
let $new = (fn $old); // the new value will be the callback applied to the old value
let $old = (ref $new); // write $new into the mutable reference, getting out $old

(the code block I posted inlined $old, so it was just let $new = (fn (ref $new)), but they're equivalent)

  • to implement ref.split (which takes a mutable reference and splits it into two), we will have three values:
      • v0 is the initial value of the mutable reference
      • v1 is the intermediate value, after being mutated by the first mutable reference
      • v2 is the final value, after being mutated by both
    in code:
ref.split ref =
  let $v0 = (ref $v2);
  (@$v1 $v0, @$v2 $v1)
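The three values above can be sanity-checked with a small pure Haskell model. It is not the scopeless-lambda encoding (Haskell has no globally scoped variables); it only models the value flow, describing each linear use of a reference by the update it performs from the old value to the new one. The names `Use`, `mut`, `dropRef`, `split` and `runRef` are hypothetical.

```haskell
-- Model of one linear use of a reference: the update it performs,
-- reading the old value and producing the value written back.
newtype Use a = Use (a -> a)

-- ref.mut: write (fn old) back into the reference.
mut :: (a -> a) -> Use a
mut fn = Use fn

-- ref.drop: write the current value back, i.e. leave it unchanged.
dropRef :: Use a
dropRef = Use id

-- ref.split: the first half turns v0 into v1, the second turns v1 into v2.
split :: Use a -> Use a -> Use a
split (Use first) (Use second) = Use (second . first)

-- Apply a use to the initial value v0 and return the final value.
runRef :: a -> Use a -> a
runRef v0 (Use f) = f v0
```

Under this model, `runRef 0 (split (mut (+ 1)) (mut (+ 1)))` evaluates to `2`: the first half sees v0 = 0 and writes v1 = 1, and the second sees v1 and writes v2 = 2.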

This is really handy, and can be expressed in HVM3's new low-level syntax as:

@mut(ref fn) = !! $new = (fn (ref $new)) *
@spt(ref fn) = (fn λ$y(ref $z) λ$z($y))

@main =
  ! $X = λ$x(0) // u32* X = &0;
  !! @spt($X λ$X0 λ$X1 *) // u32* X0 = X; u32* X1 = X;
  !! @mut($X0 λx(+ x 1)) // *X0 += 1;
  !! @mut($X1 λx(+ x 1)) // *X1 += 1;
  $x // *X

// The '!! x = val' notation represents a seq operator.
// It reduces 'val' to whnf and assigns the result to 'x'.
// The '!! val' notation is a shortcut for '!! _ = val'.
// The '$var' notation is for globally scoped variables.
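Tracing `@main` by hand shows how the final value reaches `$x`. In this sketch the `$` sigils are dropped, and $n_0$, $n_1$ are names introduced here for the `$new` bound by the first and second `@mut` call. Since `@mut` evaluates `(fn (ref $new))`, each line applies the `+ 1` callback to whatever the reference returns.

```latex
\begin{aligned}
X &= \lambda x.\,0, \qquad X_0 = \lambda y.\,(X\ z), \qquad X_1 = \lambda z.\,y
  && \text{(after @spt)} \\
n_0 &= (X_0\ n_0) + 1 = (X\ z) + 1 = 0 + 1 = 1
  && \text{binds } y := n_0,\ x := z \\
n_1 &= (X_1\ n_1) + 1 = y + 1 = 1 + 1 = 2
  && \text{binds } z := n_1 \\
x &= z = n_1 = 2
\end{aligned}
```

So `@main` evaluates to 2: both increments land in the original reference.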

This allows us to actually return #Node directly, without ever involving pairs. Here are all 3 versions in HVM3's new low-level syntax:

// Optimal recursive context passing with HVM's "pure mutable references"
// Article: https://gist.github.com/VictorTaelin/fb798a5bd182f8c57dd302380f69777a

data Pair { #Pair{fst snd} }
data List { #Nil #Cons{head tail} }
data Tree { #Leaf{val} #Node{lft rgt} }

// Utils
// -----

@is_node(tree) = ~tree {
  #Leaf{val}: 0
  #Node{lft rgt}: 1
}

@range(n r) = ~n !r {
  0: r
  p: !&0{p0 p1}=p @range(p0 #Cons{p1 r})
}

@fst(p) = ~p {
  #Pair{fst snd}: fst
}

@snd(p) = ~p {
  #Pair{fst snd}: snd
}

@tm0(sup) = !&0{tm0 tm1}=sup tm0
@tm1(sup) = !&0{tm0 tm1}=sup tm1

// Mutable references
@mut(ref fn) = !! $new = (fn (ref $new)) *
@spt(ref fn) = (fn λ$y(ref $z) λ$z($y))

// Slow Version
// ------------

// The slow version passes a context monadically, with a pair state.
@list_to_tree_slow(n ctx) = ~n !ctx {
  // Base Case:
  // - take the ctx's head
  // - return the context's tail and '#Leaf{head}'
  0: ~ctx {
    #Nil: *
    #Cons{head tail}: #Pair{tail #Leaf{head}}
  }
  // Step Case:
  // - recurse to the lft, get the new ctx and 'lft' tree
  // - recurse to the rgt, get the new ctx and 'rgt' tree
  // - return the final context and a '#Node{lft rgt}'
  p:
    !&0{p0 p1}=p
    ~ @list_to_tree_slow(p0 ctx) {
      #Pair{ctx lft}: ~ @list_to_tree_slow(p1 ctx) {
        #Pair{ctx rgt}: #Pair{ctx #Node{lft rgt}}
      }
    }
}

// Fast Version: parallel destructing
// ----------------------------------

// This version uses a superposition instead of a pair. It is faster because it
// allows us to destruct in parallel (which isn't available for native ADTs),
// preventing the sequential chaining issue.
@list_to_tree_fast_par(n ctx) = ~n !ctx {
  0: ~ctx {
    #Nil: *
    #Cons{head tail}: &0{tail #Leaf{head}}
  }
  p:
    ! &0{p0 p1}   = p
    ! &0{ctx lft} = @list_to_tree_fast_par(p0 ctx)
    ! &0{ctx rgt} = @list_to_tree_fast_par(p1 ctx)
    &0{ctx #Node{lft rgt}}
}


// Fast Version: mutable references
// --------------------------------

// This version passes the context as a mutable reference.
// It avoids pairs entirely.
@list_to_tree_fast_mut(n ctx) = ~n !ctx {
  // Base case:
  // - mutably replace the context by its tail, and extract its head
  // - return just '#Leaf{head}' (no pairs!)
  0: 
    !! @mut(ctx λctx ~ctx { #Nil:* #Cons{$head tail}:tail })
    #Leaf{$head}
  // Step Case:
  // - split the mutable reference into two
  // - recurse to the lft and rgt, passing the split mut refs
  // - return just '#Node{lft rgt}' directly (no pairs!)
  p:
    !&0{pL pR}=p
    !! @spt(ctx λ$ctxL λ$ctxR *)
    #Node{
      @list_to_tree_fast_mut(pL $ctxL)
      @list_to_tree_fast_mut(pR $ctxR)
    }
}

// Main
// ----

// Tree Depth
@depth = 16

// Tests slow version
//@main = @is_node(@snd(@list_to_tree_slow(@depth (@range((<< 1 @depth) 0)))))

// Tests fast version with parallel destruct
//@main = @is_node(@tm1(@list_to_tree_fast_par(@depth (@range((<< 1 @depth) 0)))))

// Tests fast version with mutable refs
//@main = @is_node(@list_to_tree_fast_mut(@depth λ$ctx(@range((<< 1 @depth) 0))))

All 3 versions are available in the book/ directory of the HVM3 repository.
