@VictorTaelin
Last active September 4, 2024 10:22
Accelerating Discrete Program Search with SUP Nodes

Accelerating Discrete Program Search

I am investigating how to use Bend (a parallel language) to accelerate Symbolic AI; in particular, Discrete Program Search. Think of it as an alternative to LLMs, GPTs and NNs that is also capable of generating code, but by entirely different means. This kind of approach was never scaled with mass compute before - it wasn't possible! - but Bend changes this. So, my idea was to do it, and see where it goes.

Now, while I was implementing some candidate algorithms in Bend, I realized that, rather than mass parallelism, I could use an entirely different mechanism to speed things up: SUP Nodes. This is a feature that Bend inherited from its underlying model ("Interaction Combinators") that, in simple terms, allows us to combine multiple functions into a single superposed one, and apply them all to an argument "at the same time". In short, it allows us to call N functions at a fraction of the expected cost. Or, put differently: why parallelize when we can share?
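
As a rough mental model of what "apply them all at the same time" means, here is a minimal Haskell sketch. The names `Sup`, `One`, `Two`, `applySup` and `toList` are made up for this illustration and are not part of HVM or Bend. It only models the meaning of a superposition: plain Haskell still evaluates each branch separately, so it shows none of the sharing that interaction combinators provide.

```haskell
-- A toy model of superposition semantics (illustrative only; not HVM's API).
-- A superposed value is a binary tree of alternatives.
data Sup a
  = One a               -- a single, ordinary value
  | Two (Sup a) (Sup a) -- a superposition of two alternatives

-- Applies every function in a superposition to the same argument.
-- Plain Haskell evaluates each branch on its own; no work is shared here.
applySup :: Sup (a -> b) -> a -> Sup b
applySup (One f)   x = One (f x)
applySup (Two f g) x = Two (applySup f x) (applySup g x)

-- Lists the alternatives from left to right.
toList :: Sup a -> [a]
toList (One x)   = [x]
toList (Two a b) = toList a ++ toList b

main :: IO ()
main = print (toList (applySup (Two (One (+ 1)) (Two (One (* 2)) (One negate))) 10))
-- prints [11,20,-10]
```

The result is itself a superposition, which is why a separate "collapse" step is needed to get a single answer back out.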

As you can already imagine, this could be extremely important for Symbolic AI algorithms like Discrete Program Search (DPS), one of the top solutions to ARC-AGI (a memorization-resistant reasoning benchmark on which LLMs struggle). That's because DPS works by enumerating a massive number of candidate programs and trying each one on the test cases, until one succeeds. Obviously, this approach is dumb, exponential and computationally prohibitive, which is why ARC-AGI was created to begin with: to motivate better (NN-based) algorithms.
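
To get a feel for how fast the search space grows, the sketch below counts the terms of a small DSL by nesting depth. It assumes the five constructors of the `Term` type from the Haskell listing in this gist (`Rec`, `Ret`, `MkO`, `MkI`, `Mat`) and, as a simplification, ignores the rule that `Rec` only appears under a `Mat`, so the real enumerator visits somewhat fewer terms. `countTerms` is a hypothetical helper written for this illustration.

```haskell
-- Counts the terms of nesting depth at most d in a DSL with two leaves
-- (Rec, Ret), two unary nodes (MkO, MkI) and one binary node (Mat).
-- Illustrative arithmetic only; assumes d >= 0.
countTerms :: Int -> Integer
countTerms 0 = 2                 -- Rec, Ret
countTerms d = 2 + 2 * c + c * c -- leaves + MkO/MkI + Mat
  where c = countTerms (d - 1)

main :: IO ()
main = mapM_ (print . countTerms) [0 .. 4]
-- prints 2, 10, 122, 15130 and 228947162, one per line
```

Each extra level roughly squares the count, and every one of those terms reuses sub-terms that already appeared at shallower depths.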

But is brute-force really that bad? While there is a massive number of terms to explore, these terms aren't random; they're structured, redundant and highly repetitive. This suggests some mechanism to optimally share intermediate computations could greatly improve the search. And that's exactly what SUP Nodes do! This sounded so compelling that I spent the last few weeks trying to come up with the right way to do it. A few days ago, I had a huge breakthrough, but there was still one last issue: how to collapse the superpositions back into a single result, without generating more work? Today, I just realized how, and all the pieces fell into place.

As a result, I can now search programs that would take trillions of interactions in less than a million interactions. To validate it, I implemented the exact same algorithm (a very simple enumerator) in Haskell, as well as I could. Using the Omega Monad, Haskell takes about 2.8 seconds to find a sample XOR-XNOR function, while HVM takes 0.0085s. So, as far as I can tell, this does seem like a speedup. I'm publishing below the Haskell code as a sanity check. Am I doing something wrong? Without changing the algorithm (i.e., keeping it a brute-force search), can we bring this time down?

Haskell Algorithm (slow, via Omega Monad)

-- A demo, minimal Program Search in Haskell

-- Given test (input/output) pairs, it will find a function that passes them.
-- This file is for demo purposes, so, it is restricted to just simple, single
-- pass recursive functions. The idea is to use HVM superpositions to try many
-- functions "at once". Obviously, Haskell does not have them, so, we just use
-- the Omega Monad to convert to a list of functions, and try each separately.

import Control.Monad (forM_)

-- PRELUDE
----------

newtype Omega a = Omega { runOmega :: [a] }

instance Functor Omega where
  fmap f (Omega xs) = Omega (map f xs)

instance Applicative Omega where
  pure x                = Omega [x]
  Omega fs <*> Omega xs = Omega [f x | f <- fs, x <- xs]

instance Monad Omega where
  Omega xs >>= f = Omega $ diagonal $ map (\x -> runOmega (f x)) xs

diagonal :: [[a]] -> [a]
diagonal xs = concat (stripe xs) where
  stripe []             = []
  stripe ([]     : xss) = stripe xss
  stripe ((x:xs) : xss) = [x] : zipCons xs (stripe xss)
  zipCons []     ys     = ys
  zipCons xs     []     = map (:[]) xs
  zipCons (x:xs) (y:ys) = (x:y) : zipCons xs ys

-- ENUMERATOR
-------------

-- A bit-string
data Bin
  = O Bin
  | I Bin
  | E

-- A simple DSL for `Bin -> Bin` terms
data Term
  = MkO Term      -- emits the bit 0
  | MkI Term      -- emits the bit 1
  | Mat Term Term -- pattern-matches on the argument
  | Rec           -- recurses on the argument
  | Ret           -- returns the argument
  | Sup Term Term -- a superposition of two functions

-- Checks if two Bins are equal
bin_eq :: Bin -> Bin -> Bool
bin_eq (O xs) (O ys) = bin_eq xs ys
bin_eq (I xs) (I ys) = bin_eq xs ys
bin_eq E      E      = True
bin_eq _      _      = False

-- Stringifies a Bin
bin_show :: Bin -> String
bin_show (O xs) = "O" ++ bin_show xs
bin_show (I xs) = "I" ++ bin_show xs
bin_show E      = "E"

-- Checks if two terms are equal
term_eq :: Term -> Term -> Bool
term_eq (Mat l0 r0) (Mat l1 r1) = term_eq l0 l1 && term_eq r0 r1
term_eq (MkO t0)    (MkO t1)    = term_eq t0 t1
term_eq (MkI t0)    (MkI t1)    = term_eq t0 t1
term_eq Rec         Rec         = True
term_eq Ret         Ret         = True
term_eq _           _           = False

-- Stringifies a term
term_show :: Term -> String
term_show (MkO t)    = "(O " ++ term_show t ++ ")"
term_show (MkI t)    = "(I " ++ term_show t ++ ")"
term_show (Mat l r)  = "{O:" ++ term_show l ++ "|I:" ++ term_show r ++ "}"
term_show (Sup a b)  = "{" ++ term_show a ++ "|" ++ term_show b ++ "}"
term_show Rec        = "@"
term_show Ret        = "*"

-- Enumerates all terms
enum :: Bool -> Term
enum s = (if s then Sup Rec else id) $ Sup Ret $ Sup (intr s) (elim s) where
  intr s = Sup (MkO (enum s)) (MkI (enum s))
  elim s = Mat (enum True) (enum True)

-- Converts a Term into a native function
make :: Term -> (Bin -> Bin) -> Bin -> Bin
make Ret       _ x = x
make Rec       f x = f x
make (MkO trm) f x = O (make trm f x)
make (MkI trm) f x = I (make trm f x)
make (Mat l r) f x = case x of
  O xs -> make l f xs
  I xs -> make r f xs
  E    -> E

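-- Finds a program that satisfies a test
search :: Int -> (Term -> Bool) -> [Term] -> IO ()
-- Candidate list exhausted: report failure instead of crashing on a missing pattern
search n _    []       =
  putStrLn $ "NOT FOUND (after " ++ show n ++ " guesses)"
search n test (tm:tms) = do
  if test tm then
    putStrLn $ "FOUND " ++ term_show tm ++ " (after " ++ show n ++ " guesses)"
  else
    search (n+1) test tms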

-- Collapses a superposed term to a list of terms, diagonalizing
collapse :: Term -> Omega Term
collapse (MkO t) = do
  t' <- collapse t
  return $ MkO t'
collapse (MkI t) = do
  t' <- collapse t
  return $ MkI t'
collapse (Mat l r) = do
  l' <- collapse l
  r' <- collapse r
  return $ Mat l' r'
collapse (Sup a b) =
  let a' = runOmega (collapse a) in
  let b' = runOmega (collapse b) in
  Omega (diagonal [a',b'])
collapse Rec = return Rec
collapse Ret = return Ret

-- Some test cases:
-- ----------------

test_not :: Term -> Bool
test_not tm = e0 && e1 where
  fn = make tm fn
  x0 = (O (I (O (O (O (I (O (O E))))))))
  y0 = (I (O (I (I (I (O (I (I E))))))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (I (I (I (O (O (I (I (I E))))))))
  y1 = (O (O (O (I (I (O (O (O E))))))))
  e1 = (bin_eq (fn x1) y1)

test_inc :: Term -> Bool
test_inc tm = e0 && e1 where
  fn = make tm fn
  x0 = (O (I (O (O (O (I (O (O E))))))))
  y0 = (I (I (O (O (O (I (O (O E))))))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (I (I (I (O (O (I (I (I E))))))))
  y1 = (O (O (O (I (O (I (I (I E))))))))
  e1 = (bin_eq (fn x1) y1)

test_mix :: Term -> Bool
test_mix tm = e0 && e1 where
  fn = make tm fn
  x0 = (O (I (O (O (O (I (O (O E))))))))
  y0 = (I (O (I (I (I (O (I (O (I (O (I (I (I (O (I (O E))))))))))))))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (I (I (I (O (O (I (I (I E))))))))
  y1 = (I (I (I (I (I (I (I (O (I (O (I (I (I (I (I (I E))))))))))))))))
  e1 = (bin_eq (fn x1) y1)

test_xors :: Term -> Bool
test_xors tm = e0 && e1 where
  fn = make tm fn
  x0 = (I (I (O (O (O (I (O (O E))))))))
  y0 = (I (I (O (I E))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (I (O (O (I (I (I (O (I E))))))))
  y1 = (O (O (I (O E))))
  e1 = (bin_eq (fn x1) y1)

test_xor_xnor :: Term -> Bool
test_xor_xnor tm = e0 && e1 where
  fn = make tm fn
  x0 = (I (O (O (I (I (O (I (I (I (O E))))))))))
  y0 = (I (O (I (O (I (O (O (I (I (O E))))))))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (O (I (O (O (O (I (O (I (O (O E))))))))))
  y1 = (I (O (O (I (I (O (I (O (O (I E)))))))))) 
  e1 = (bin_eq (fn x1) y1)

main :: IO ()
main = search 0 test_xor_xnor $ runOmega $ collapse $ enum False
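
To see what the diagonalization in the listing above buys, the snippet below runs `diagonal` on its own. The function is copied from the listing so the snippet runs alone, and the inputs are made up for illustration. It contrasts diagonal order with plain list-comprehension order, which never leaves the first row when the rows are infinite.

```haskell
-- Same `diagonal` as in the listing above, repeated so this snippet runs alone.
diagonal :: [[a]] -> [a]
diagonal xs = concat (stripe xs) where
  stripe []             = []
  stripe ([]     : xss) = stripe xss
  stripe ((x:xs) : xss) = [x] : zipCons xs (stripe xss)
  zipCons []     ys     = ys
  zipCons xs     []     = map (:[]) xs
  zipCons (x:xs) (y:ys) = (x:y) : zipCons xs ys

main :: IO ()
main = do
  -- Two finite rows are interleaved along anti-diagonals.
  print (diagonal [[1, 2, 3], [10, 20, 30]])
  -- prints [1,2,10,3,20,30]

  -- List-comprehension order gets stuck on a = 0 forever.
  print (take 4 [(a, b) | a <- [0 ..], b <- [0 ..]])
  -- prints [(0,0),(0,1),(0,2),(0,3)]

  -- Diagonal order reaches every pair after finitely many steps.
  print (take 6 (diagonal [[(a, b) | b <- [0 ..]] | a <- [0 ..]]))
  -- prints [(0,0),(0,1),(1,0),(0,2),(1,1),(2,0)]
```

This matters for `collapse`, because each `Sup` branch can itself expand into an unbounded stream of candidate terms.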

HVM Algorithm (fast, with SUP Nodes)

NOTE: to run this, you must use the dup_labels branch of HVM1. We'll soon adapt Bend to work for this too, but it is currently lacking lazy mode, 64-bit dup labels and SUP-node quoting.

// HVM Prelude
// -----------

(Fix f) = (f (Fix f))

(U60.if 0 t f) = f
(U60.if x t f) = t

(U60.show n) = (U60.show.go n "")

(U60.show.go n x) = (λx(U60.if (> n 9) (U60.show.go (/ n 10) x) x) (String.cons (+ 48 (% n 10)) x))

(U60.seq 0 cont) = (cont 0)
(U60.seq n cont) = (cont n)

(And 0 x) = 0
(And 1 x) = x

(Or 0 x) = x
(Or 1 x) = 1

(List.get (List.nil) _)       = (Err "out-of-bounds")
(List.get (List.cons x xs) 0) = x
(List.get (List.cons x xs) n) = (List.get xs (- n 1))

(List.map f List.nil)         = List.nil
(List.map f (List.cons x xs)) = (List.cons (f x) (List.map f xs))

(List.imap f List.nil)         = List.nil
(List.imap f (List.cons x xs)) = (List.cons (f 0 x) (List.imap λiλx(f (+ i 1) x) xs))

(List.concat (List.nil)       ys) = ys
(List.concat (List.cons x xs) ys) = (List.cons x (List.concat xs ys))

(List.flatten List.nil)         = List.nil
(List.flatten (List.cons x xs)) = (List.concat x (List.flatten xs))

(List.length List.nil)         = 0
(List.length (List.cons x xs)) = (+ 1 (List.length xs))

(List.take 0 xs)               = List.nil
(List.take n List.nil)         = List.nil
(List.take n (List.cons x xs)) = (List.cons x (List.take (- n 1) xs))

(List.head (List.cons x xs)) = x
(List.tail (List.cons x xs)) = xs

(List.push x List.nil)         = (List.cons x List.nil)
(List.push x (List.cons y ys)) = (List.cons y (List.push x ys))

(List.diagonal xs) = (List.flatten (List.stripe xs))

(List.stripe List.nil)                         = []
(List.stripe (List.cons List.nil         xss)) = (List.stripe xss)
(List.stripe (List.cons (List.cons x xs) xss)) = (List.cons [x] (List.zip_cons xs (List.stripe xss)))

(List.zip_cons []               ys)               = ys
(List.zip_cons xs               [])               = (List.map λk(List.cons k []) xs)
(List.zip_cons (List.cons x xs) (List.cons y ys)) = (List.cons (List.cons x y) (List.zip_cons xs ys))

(Omega.pure x)    = [x]
(Omega.bind xs f) = (List.diagonal (List.map f xs))

(String.concat String.nil         ys) = ys
(String.concat (String.cons x xs) ys) = (String.cons x (String.concat xs ys))

(String.join List.nil)         = String.nil
(String.join (List.cons x xs)) = (String.concat x (String.join xs))

(String.seq (String.cons x xs) cont) = (U60.seq x λx(String.seq xs λxs(cont (String.cons x xs))))
(String.seq String.nil         cont) = (cont String.nil)

(String.eq String.nil         String.nil)         = 1
(String.eq (String.cons x xs) (String.cons y ys)) = (And (== x y) (String.eq xs ys))
(String.eq xs                 ys)                 = 0

(String.take 0 xs)                 = String.nil
(String.take n String.nil)         = String.nil
(String.take n (String.cons x xs)) = (String.cons x (String.take (- n 1) xs))

(Tup2.match (Tup2.new fst snd) fn) = (fn fst snd)
(Tup2.fst (Tup2.new fst snd))      = fst
(Tup2.snd (Tup2.new fst snd))      = snd

(Join None     b) = b
(Join (Some x) b) = (Some x)

(Print []  value) = value
(Print msg value) = (String.seq (String.join msg) λstr(HVM.log str value))

// Priority Queue
// data PQ = Empty | Node U60 U60 PQ PQ

// PQ.new: Create a new empty Priority Queue
(PQ.new) = Empty

// PQ.put: Add a new (key, val) pair to the Priority Queue
(PQ.put key val Empty)              = (Node key val Empty Empty)
(PQ.put key val (Node k v lft rgt)) = (PQ.put.aux (< key k) key val k v lft rgt)

(PQ.put.aux 1 key val k v lft rgt) = (Node key val (Node k v lft rgt) Empty)
(PQ.put.aux 0 key val k v lft rgt) = (Node k v (PQ.put key val lft) rgt)

// PQ.get: Get the smallest element and return it with the updated queue
(PQ.get Empty)              = (HVM.LOG ERR 0)
(PQ.get (Node k v lft rgt)) = λs(s k v (PQ.merge lft rgt))

// Helper function to merge two priority queues
(PQ.merge Empty              rgt)                = rgt
(PQ.merge lft                Empty)              = lft
(PQ.merge (Node k1 v1 l1 r1) (Node k2 v2 l2 r2)) = (PQ.merge.aux (< k1 k2) k1 v1 l1 r1 k2 v2 l2 r2)

(PQ.merge.aux 1 k1 v1 l1 r1 k2 v2 l2 r2) = (Node k1 v1 (PQ.merge r1 (Node k2 v2 l2 r2)) l1)
(PQ.merge.aux 0 k1 v1 l1 r1 k2 v2 l2 r2) = (Node k2 v2 (PQ.merge (Node k1 v1 l1 r1) r2) l2)

// Collapser
(Collapse (HVM.SUP k a b) pq) = (Collapse None (PQ.put k a (PQ.put k b pq)))
(Collapse (Some x)        pq) = x
(Collapse None            pq) = ((PQ.get pq) λkλxλpq(Collapse x pq))

// Bin Enumerator
// --------------

(O xs) = λo λi λe (o xs)
(I xs) = λo λi λe (i xs)
E      = λo λi λe e

(Bin.eq xs ys) = (xs
  λxsp λys (ys λysp(Bin.eq xsp ysp) λysp(0) 0) 
  λxsp λys (ys λysp(0) λysp(Bin.eq xsp ysp) 0) 
  λys (ys λysp(0) λysp(0) 1)
  ys)

(Term.eq (Mat l0 r0) (Mat l1 r1)) = (And (Term.eq l0 l1) (Term.eq r0 r1))
(Term.eq (MkO t0)    (MkO t1))    = (Term.eq t0 t1)
(Term.eq (MkI t0)    (MkI t1))    = (Term.eq t0 t1)
(Term.eq Rec         Rec)         = 1
(Term.eq Ret         Ret)         = 1
(Term.eq _           _)           = 0

Zero = (O Zero)
Neg1 = (I Neg1)

(L0 x) = (+ (* x 2) 0)
(L1 x) = (+ (* x 2) 1)

(ENUM lab s) =
  let lA = (+ (* lab 4) 0)
  let lB = (+ (* lab 4) 1)
  let lC = (+ (* lab 4) 2)
  let rc = (U60.if s λx(HVM.SUP lB Rec x) λx(x))
  let rt = λx(HVM.SUP lC Ret x)
  (rt (rc (HVM.SUP lA
    (INTR (L0 lab) s)
    (ELIM (L1 lab) s))))
(INTR lab s) =
  let lA = (+ (* lab 4) 3)
  (HVM.SUP lA
    (MkO (ENUM (L0 lab) s))
    (MkI (ENUM (L1 lab) s)))
(ELIM lab s) =
  (Mat (ENUM (L0 lab) 1)
       (ENUM (L1 lab) 1))

(Make Ret      ) = λfλx(x)
(Make Rec      ) = λfλx(f x)
(Make (MkO trm)) = λfλx(O ((Make trm) f x))
(Make (MkI trm)) = λfλx(I ((Make trm) f x))
(Make (Mat l r)) = λfλx(x λx((Make l) f x) λx((Make r) f x) (E))

(Bin.show xs) = (xs λxs(String.join ["O" (Bin.show xs)]) λxs(String.join ["I" (Bin.show xs)]) "E")
(Bin.view xs) = (xs λxs(B0 (Bin.view xs)) λxs(B1 (Bin.view xs)) BE)

(COL (HVM.SUP k a b)) = (Join (COL a) (COL b))
(COL x)               = x

(Flat (HVM.SUP k a b)) = (List.diagonal [(Flat a) (Flat b)])
(Flat Ret)             = (Omega.pure Ret)
(Flat Rec)             = (Omega.pure Rec)
(Flat (MkO trm))       = (Omega.bind (Flat trm) λtrm(Omega.pure (MkO trm)))
(Flat (MkI trm))       = (Omega.bind (Flat trm) λtrm(Omega.pure (MkI trm)))
(Flat (Mat l r))       = (Omega.bind (Flat l) λl(Omega.bind (Flat r)λr(Omega.pure (Mat l r))))

// Stringifies a term
(Term.show Ret)              = "Ret"
(Term.show Rec)              = "Rec"
(Term.show (MkO term))       = (String.join ["(O " (Term.show term) ")"])
(Term.show (MkI term))       = (String.join ["(I " (Term.show term) ")"])
(Term.show (Mat left right)) = (String.join ["{" (Term.show left) "|" (Term.show right) "}"])

(Test_same g cd fn) =
  let x0 = (I (O (O (I (I (O (I (I (I (O E))))))))))
  let y0 = (g x0)
  let e0 = (Bin.eq (fn x0) y0)
  let x1 = (O (I (O (O (O (I (O (I (O (O E))))))))))
  let y1 = (g x1)
  let e1 = (Bin.eq (fn x1) y1)
  (U60.if (And e0 e1) (Some cd) None)

(Test_not cd fn) =
  let x0 = (O (I (O (O (O (I (O (O E))))))))
  let y0 = (I (O (I (I (I (O (I (I E))))))))
  let e0 = (Bin.eq (fn x0) y0)
  let x1 = (I (I (I (O (O (I (I (I E))))))))
  let y1 = (O (O (O (I (I (O (O (O E))))))))
  let e1 = (Bin.eq (fn x1) y1)
  (U60.if (And e0 e1) (Some cd) None)

(Test_inc cd fn) =
  let x0 = (O (I (O (O (O (I (O (O E))))))))
  let y0 = (I (I (O (O (O (I (O (O E))))))))
  let e0 = (Bin.eq (fn x0) y0)
  let x1 = (I (I (I (O (O (I (I (I E))))))))
  let y1 = (O (O (O (I (O (I (I (I E))))))))
  let e1 = (Bin.eq (fn x1) y1)
  (U60.if (And e0 e1) (Some cd) None)

(Test_mix cd fn) =
  let x0 = (O (I (O (O (O (I (O (O E))))))))
  let y0 = (I (O (I (I (I (O (I (O (I (O (I (I (I (O (I (O E))))))))))))))))
  let e0 = (Bin.eq (fn x0) y0)
  let x1 = (I (I (I (O (O (I (I (I E))))))))
  let y1 = (I (I (I (I (I (I (I (O (I (O (I (I (I (I (I (I E))))))))))))))))
  let e1 = (Bin.eq (fn x1) y1)
  (U60.if (And e0 e1) (Some cd) None)

(Test_xors cd fn) =
  let x0 = (I (I (O (O (O (I (O (O E))))))))
  let y0 = (I (I (O (I E))))
  let e0 = (Bin.eq (fn x0) y0)
  let x1 = (I (O (O (I (I (I (O (I E))))))))
  let y1 = (O (O (I (O E))))
  let e1 = (Bin.eq (fn x1) y1)
  (U60.if (And e0 e1) (Some cd) None)

(Test_xor_xnor cd fn) =
  let x0 = (I (O (O (I (I (O (I (I (I (O E))))))))))
  let y0 = (I (O (I (O (I (O (O (I (I (O E))))))))))
  let e0 = (Bin.eq (fn x0) y0)
  let x1 = (O (I (O (O (O (I (O (I (O (O E))))))))))
  let y1 = (I (O (O (I (I (O (I (O (O (I E)))))))))) 
  let e1 = (Bin.eq (fn x1) y1)
  (U60.if (And e0 e1) (Some cd) None)

Main =
  //let term = (Mat (Mat (MkO (MkI (Rec))) (MkI (MkO (Rec)))) (Mat (MkI (MkO (Rec))) (MkO (MkI (Rec)))))
  //let bits = (O (I (O (O (O (I (O (I (O (O E))))))))))
  //let func = (Fix (Make term))
  //(Bin.show (func bits))
  //(Test_xor_xnor term (Fix (Make term)))

  //let term = (Mat (Mat (MkI (MkO (Rec))) (MkO (MkI (Rec)))) (Mat (MkO (MkI (Rec))) (MkI (MkO (Rec)))))
  //let func = (Fix (Make term))

  let terms = (ENUM 1 0)
  let funcs = (Fix (Make terms))
  let found = (Test_xor_xnor terms funcs)
  (Collapse found PQ.new)

Edit: HVM source published. Announcement on https://x.com/VictorTaelin/status/1829143659440144493

@tankorsmash

fff7d700df6dd39718a6e0e12252f611e7da14fc5ff8d363dc36ee6b64f9786a

Is this a git commit somewhere? I'm not sure how to read this

@LiamGoodacre

LiamGoodacre commented Aug 8, 2024

Bit of a speed up on my laptop via some further unrolling / specialising / "DList" / etc.
( & I'm assuming you're already using -O2. )

Your original Applicative Omega instance currently isn't diagonal;
so using <*> for Mat makes it take forever.

I think your stripe had a possible minor bug:
The stripe ([] : xss) = stripe xss case I believe should be = [] : stripe xss.
Though this perhaps doesn't matter that much.

zipCons :: [a] -> [[a] -> [a]] -> [[a] -> [a]]
zipCons (x:xs) (ys:yss) = ((x:) . ys) : zipCons xs yss
zipCons (x:xs) ys = (x:) : zipCons xs ys
zipCons [] ys = ys

diagonalBind :: forall x a . (x -> Omega a) -> Omega x -> Omega a
diagonalBind f = Omega . foldr ($) [] . stripeBind f . runOmega where
  stripeBind :: (x -> Omega a) -> [x] -> [[a] -> [a]]
  stripeBind _ [] = []
  stripeBind g (v:xss) = case g v of
    Omega [] -> id : stripeBind g xss
    Omega (x:xs) -> (x:) : zipCons xs (stripeBind g xss)

diagonalAlt :: forall a . Omega a -> Omega a -> Omega a
diagonalAlt (Omega a) (Omega b) = Omega $ foldr ($) [] $ stripeAlt [a, b] where
  stripeAlt :: [[a]] -> [[a] -> [a]]
  stripeAlt [] = []
  stripeAlt ([]:xss) = id : stripeAlt xss
  stripeAlt ((x:xs):xss) = (x:) : zipCons xs (stripeAlt xss)

newtype Omega a = Omega { runOmega :: [a] }
  deriving newtype (Functor)

instance Applicative Omega where
  pure x = Omega [x]
  fs <*> xs = diagonalBind (<$> xs) fs

instance Alternative Omega where
  empty = Omega []
  a <|> b = diagonalAlt a b

instance Monad Omega where
  xs >>= f = diagonalBind f xs

collapse :: Term -> Omega Term
collapse Rec = pure Rec
collapse Ret = pure Ret
collapse (MkO t) = MkO <$> collapse t
collapse (MkI t) = MkI <$> collapse t
collapse (Mat l r) = Mat <$> collapse l <*> collapse r
collapse (Sup a b) = collapse a <|> collapse b
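
The effect of the empty-row case mentioned above can be checked in isolation. The sketch below puts the two variants side by side: `skipEmpty` mirrors the `stripe` from the gist, and `keepEmpty` uses the suggested `[] : stripe xss` case. Both names, and the sample input, are made up for this comparison.

```haskell
-- Two versions of the striping step, differing only in the empty-row case.
zipCons :: [a] -> [[a]] -> [[a]]
zipCons []     ys     = ys
zipCons xs     []     = map (:[]) xs
zipCons (x:xs) (y:ys) = (x:y) : zipCons xs ys

-- Mirrors the gist's `stripe`: an empty row is dropped entirely.
skipEmpty :: [[a]] -> [[a]]
skipEmpty []             = []
skipEmpty ([]     : xss) = skipEmpty xss
skipEmpty ((x:xs) : xss) = [x] : zipCons xs (skipEmpty xss)

-- Suggested variant: an empty row still occupies its diagonal slot.
keepEmpty :: [[a]] -> [[a]]
keepEmpty []             = []
keepEmpty ([]     : xss) = [] : keepEmpty xss
keepEmpty ((x:xs) : xss) = [x] : zipCons xs (keepEmpty xss)

main :: IO ()
main = do
  let rows = [[1, 2, 3], [], [4, 5]] :: [[Int]]
  print (concat (skipEmpty rows)) -- prints [1,2,4,3,5]
  print (concat (keepEmpty rows)) -- prints [1,2,3,4,5]
```

With `skipEmpty`, rows that follow an empty row move one diagonal earlier, so only the visiting order changes; no element is lost in either version.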

@7ab901a7933419b5

Hi @VictorTaelin,

I'm trying to follow your recent advances – from what I'm able to comprehend (as someone who is not familiar with Haskell and interaction nets), for a set of inputs and outputs, presented as binary strings, your solution enumerates the space of functions with appropriate signature and returns the first candidate that maps given inputs into given outputs. Just in case this could help to explore the limitations of your approach, I'd like to share one possible benchmark that came to my mind – feel free to try it if it makes sense to you. I'd like to apologize in advance if my suggestion is irrelevant.

Consider a two-player game (e.g. the Iterated Prisoner's Dilemma) in which a player's move can be either 0 or 1, and a single round goes like this: both players secretly choose their moves, which are then revealed. The history available after $n$ rounds is $n$ moves by the first player and $n$ moves by the second player; this can be represented as a binary string of length $2n$, i.e. every integer from 0 to $2^{2n} - 1$ uniquely corresponds to a certain history.

Now consider a deterministic strategy for the first (w.l.o.g) player after $n$ rounds. Any such strategy can be described as follows: for each history, specify whether the next move will be 0 or 1. Since there are $2^{2n}$ histories, a strategy can be represented as a binary string of length $2^{2n}$. Here, again, integers from 0 to $2^{2^{2n}} - 1$ are in one-to-one correspondence with strategies.

The benchmark I'd like to suggest is to discover the function that applies strategy to history and returns the first player's next move. That is, given a binary string of length $2^{2n} + 2n$, which is a concatenation of a certain strategy and a certain history, the function would split it into strategy and history, then return the bit in the strategy at index int(history).

Here is a sample implementation in Python:

from random import randint

to_digits = lambda x, n: bin(x)[2:].zfill(n)
get_random = lambda n: to_digits(randint(0, 2**n - 1), n)
get_random_history = lambda n: get_random(2*n)
get_random_strategy = lambda n: get_random(2**(2*n))

def apply_strategy_to_history(strategy_and_history):
    split = next(
        x for x in range(0, len(strategy_and_history), 2)
        if len(strategy_and_history) == 2**x + x
    )
    strategy = strategy_and_history[:-split]
    history = strategy_and_history[-split:]
    return strategy[int(history, base=2)]

def make_test_case(n):
    history = get_random_history(n)
    strategy = get_random_strategy(n)
    output = apply_strategy_to_history(strategy + history)
    return f"f {strategy}{history} = {output}"

Here are a few randomized test cases for small values of $n$:

>>> for _ in range(3): print(make_test_case(1))
... 
f 011001 = 1
f 101110 = 1
f 010011 = 0
>>> for _ in range(3): print(make_test_case(2))
... 
f 00001010000001011100 = 0
f 01001000010110100111 = 0
f 10100100011110110000 = 1

I wonder if your solution will be able to discover a function equivalent to apply_strategy_to_history above, and if yes, how would the running time depend on $n$. As the input size grows exponentially in $n$, I believe this might be a challenge for an enumerative algorithm, so that it would either overfit to the given set of test cases or, for larger values of $n$, take forever to complete. Would be great if you could check.

Thanks

@VictorTaelin
Author

Very cool example! I have a lot of things to do now (including extending this technique to the general dependently typed case), but I'll bookmark your post to try it in the future!

@AHartNtkn

Have you looked at Vandenbroucke, Schrijvers, and Piessens's work on better nondeterminism monads?

It seems to me they cover much of what SUP nodes do that's missing from the Omega monad (on its own). It may be a fairer comparison (and also may clarify what one gets for free by using SUP nodes).
