Skip to content

Instantly share code, notes, and snippets.

@aradarbel10
Last active June 14, 2023 20:52
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aradarbel10/9f5259f4b52b9dae804483517a1cf868 to your computer and use it in GitHub Desktop.
Save aradarbel10/9f5259f4b52b9dae804483517a1cf868 to your computer and use it in GitHub Desktop.
Type directed program synthesis for STLC
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use list comprehension" #-}
import Data.IORef ( IORef, newIORef, readIORef, writeIORef )
import GHC.IO ( unsafePerformIO )
import Control.Applicative ( Alternative(..) )
import Control.Monad ( when )
import Debug.Trace
--- global fresh name source ---
-- | Identifiers are plain strings.
type Name = String

-- | Global mutable counter backing the fresh-name supply; it starts at 0.
--
-- The @NOINLINE@ pragma is required: if this binding were inlined, each use
-- site could allocate its own 'IORef' and the counter would stop being shared.
{-# NOINLINE freshi #-}
freshi :: IORef Int
freshi = unsafePerformIO $ newIORef 0

-- | Return the current value of the counter and advance it by one.
--
-- The @()@ argument carries no information; it is kept so existing callers
-- continue to work.
nexti :: () -> IO Int
nexti () = do
  i <- readIORef freshi
  -- force the increment now so the reference never holds a chain of thunks
  writeIORef freshi $! i + 1
  return i

-- | Append the next value of the global counter to the given base name.
--
-- The 'unsafePerformIO' action deliberately mentions @str@: an action that
-- does not depend on the function argument is a constant expression that GHC
-- may float out of the lambda and evaluate only once, so that every call
-- would receive the same number. The @NOINLINE@ pragma keeps the call from
-- being duplicated or merged at use sites.
--
-- NOTE: a call site that passes a constant argument (e.g. @freshen "x"@) is
-- itself a constant expression and may still be shared when compiling with
-- optimisation; use @-fno-cse -fno-full-laziness@ for such modules, or prefer
-- a pure naming scheme.
{-# NOINLINE freshen #-}
freshen :: Name -> Name
freshen str = unsafePerformIO $ do
  i <- nexti ()
  return (str ++ show i)
--- nondeterministic computations ---
-- | A nondeterministic computation, represented as the list of its results.
-- An empty list means there is no result.
type Nondet a = [a]
--- language description ---
-- | Types of the object language.
data Ty
  = Base Name    -- ^ an atomic type, identified by its name
  | Arrow Ty Ty  -- ^ function type: argument type, then result type
  | Prod Ty Ty   -- ^ product type: first component, then second component
  deriving (Show, Eq)

-- | Values: introduction forms, or a neutral term embedded via 'Neut'.
data Vl
  = Neut Ne      -- ^ a neutral term used as a value
  | Lam Name Vl  -- ^ lambda abstraction: binder name and body
  | Pair Vl Vl   -- ^ pair of two values

-- | Neutral terms: a variable with eliminators applied to it.
data Ne
  = Var Name   -- ^ a variable
  | App Ne Vl  -- ^ application of a neutral term to an argument value
  | Fst Ne     -- ^ first projection
  | Snd Ne     -- ^ second projection
--- pretty printing ---
-- | Render a value; every form is wrapped in parentheses.
instance Show Vl where
  show :: Vl -> String
  show value = case value of
    Neut ne -> concat ["(", show ne, ")"]
    Lam binder body -> concat ["(\\", binder, ". ", show body, ")"]
    Pair left right -> concat ["(", show left, ", ", show right, ")"]
-- | Render a neutral term: application is juxtaposition, and the projections
-- are written as the postfix markers @.1@ and @.2@.
instance Show Ne where
  show :: Ne -> String
  show neutral = case neutral of
    Var name -> name
    App fun arg -> unwords [show fun, show arg]
    Fst inner -> show inner ++ " .1"
    Snd inner -> show inner ++ " .2"
--- typing contexts ---
-- | A typing context: variable names paired with their types.
type Ctx = [(Name, Ty)]

-- | Extend a context with a new binding, placed at the front.
assume :: Ctx -> Name -> Ty -> Ctx
assume ctx name ty = binding : ctx
  where
    binding = (name, ty)
--- eliminator shapes ---
-- this just keeps track of the shape, not the eliminated term itself,
-- but once we have the shape we can very easily apply concrete eliminators later on.
-- note the *head* of the shape is the eliminator applied *first*.
-- | A single eliminator, recorded without the term it is applied to.
data Elim
  = EApp Ty  -- ^ function application; carries the type of the argument to supply
  | EFst     -- ^ first projection
  | ESnd     -- ^ second projection
  deriving Show

-- | A sequence of eliminators.
type Shape = [Elim]
--- program synthesis ---
-- this is the main part of the code, implementing type-directed program synthesis, or equivalently
-- type directed proof search. it's based on two functions, `synth` and `search`, which roughly follow
-- the same pattern as bidirectional typechecking but are relationally dual.
-- `synth` takes a context and a goal type, and tries to build an introduction form of that type.
-- it is the program-synthesis analogue of bidi's `check`.
-- | Try to build an introduction form of the goal type.
--
-- * For a function type, bind a new variable of the argument type and search
--   for a body of the result type in the extended context.
-- * For a product type, search for both components and pair them.
-- * A base type has no introduction form, so there are no results.
--
-- The binder of a lambda is chosen purely from the context: it is the first
-- name of the form @x<i>@ that is not already bound. A binder only has to
-- differ from the names in scope, and a pure choice cannot be shared or
-- reordered by the optimiser the way a constant 'unsafePerformIO' call can.
synth :: Ctx -> Ty -> Nondet Vl
synth ctx (Arrow t1 t2) = do
  let binder = unusedName 0
  body <- search (assume ctx binder t1) t2
  return $ Lam binder body
  where
    -- first candidate "x<i>", counting up from the given index, not bound in ctx
    unusedName :: Int -> Name
    unusedName i
      | candidate `elem` map fst ctx = unusedName (i + 1)
      | otherwise = candidate
      where
        candidate = "x" ++ show i
synth ctx (Prod t1 t2) = Pair <$> search ctx t1 <*> search ctx t2
-- in the last case, this type has no introduction forms and thus cannot be synthesized
synth _ (Base _) = []
-- `search` takes a context and a goal type, and tries to search variables in the context fitting that goal,
-- possibly applying eliminators to potential variables.
-- it is the program-synthesis analogue of bidi's `infer`, and similarly falls back to `synth`.
-- | Find values of the goal type: first the ones obtained from context
-- entries, then the introduction forms produced by 'synth'.
search :: Ctx -> Ty -> Nondet Vl
search ctx goal = fromContext ++ synth ctx goal
  where
    -- candidates from the context, visiting the entries in order
    fromContext :: Nondet Vl
    fromContext = concatMap (\(name, ty) -> map Neut (viaEntry name ty)) ctx

    -- a single entry is tried by either using it directly as a variable or by
    -- applying eliminators to it first, so the result is always a neutral term
    viaEntry :: Name -> Ty -> Nondet Ne
    viaEntry name ty = reachable goal ty >>= eliminate (Var name)

    -- apply the eliminators of a shape to a neutral term, head first;
    -- the argument of every application is itself found by 'search'
    eliminate :: Ne -> Shape -> Nondet Ne
    eliminate ne shape = case shape of
      [] -> [ne]
      EApp argTy : more -> search ctx argTy >>= \arg -> eliminate (App ne arg) more
      EFst : more -> eliminate (Fst ne) more
      ESnd : more -> eliminate (Snd ne) more
-- before starting to apply eliminators on a context entry, we should check if the goal is even
-- reachable from its type. this is important because, eg, mindlessly eliminating function types will
-- require repeatedly synthesizing the argument of an application, but that loops infinitely.
-- | Enumerate the eliminator shapes that turn a term of the given type (second
-- argument) into a term of the goal type (first argument).
--
-- The empty shape is produced when the two types are already equal. Function
-- types are then eliminated by application and product types by either
-- projection, continuing from the resulting type.
reachable :: Ty -> Ty -> Nondet Shape
reachable target ty = direct ++ viaEliminator
  where
    -- the goal might be reached without eliminating anything
    direct
      | target == ty = [[]]
      | otherwise = []
    -- otherwise one eliminator may bring us closer to the goal
    viaEliminator = case ty of
      Arrow argTy resTy -> map (EApp argTy :) (reachable target resTy)
      Prod left right ->
        map (EFst :) (reachable target left) ++ map (ESnd :) (reachable target right)
      Base _ -> []
--- aesthetic helpers ---
-- | Four base types used by the examples.
w, x, y, z :: Ty
w = Base "W"
x = Base "X"
y = Base "Y"
z = Base "Z"

-- | Right-associative infix notation for function types.
infixr 5 ~>
(~>) :: Ty -> Ty -> Ty
argTy ~> resTy = Arrow argTy resTy
--- examples ---
-- | Search for terms of the given type in the given context and print the
-- first 7 solutions found.
runExample :: Ctx -> Ty -> IO ()
runExample = runExampleWith 7 -- default cap

-- | Like 'runExample', but with an explicit cap on the number of solutions
-- printed. Only the first @cap@ solutions are taken from the result of
-- 'search', so the rest of the list is never demanded.
runExampleWith :: Int -> Ctx -> Ty -> IO ()
runExampleWith cap ctx t = print (take cap (search ctx t))
-- | Run every example, printing the solutions found for each one.
main :: IO ()
main = do
  putStrLn "welcome to proof search!"
  -- finding appropriate variables in scope
  runExample [("a", x)] x
  runExample [("a", x), ("f", y ~> z)] (y ~> z)
  -- synthesizing under a binder
  runExample [] (x ~> x)
  runExample [] (x ~> y ~> x)
  -- eliminating assumptions from the context
  runExample [("a", Prod y z)] z
  runExample [("a", x), ("f", x ~> y)] y
  runExample [("a", x), ("b", x), ("f", x ~> x ~> y)] y
  runExample [("h", y ~> z), ("g", x ~> y), ("f", w ~> x)] (w ~> z)
  -- repeatedly composing a function with itself
  runExample [("a", x ~> x)] (x ~> x)
  -- applying a higher order function
  runExample [("f", x ~> y), ("g", (x ~> y) ~> y ~> z)] (x ~> z)
  -- this example came out really interesting! exercise for the reader: try to come up with as many
  -- of your own solutions (distinct up to beta-eta equivalence) before looking at the synthesis results.
  -- misc
  runExample [] (Prod x y ~> Prod y x)
  runExample [("b", Prod (Prod x (y ~> z)) (Prod (Prod y z) x))] z
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment