Skip to content

Instantly share code, notes, and snippets.

@cheery
Created November 22, 2023 12:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cheery/b3e9eb1058d52267021c94f42f1221a9 to your computer and use it in GitHub Desktop.
Save cheery/b3e9eb1058d52267021c94f42f1221a9 to your computer and use it in GitHub Desktop.
Pattern unification
module CatPu where
import Control.Applicative (Alternative (..))
import Control.Monad (MonadPlus (..), foldM, forM)
import Control.Monad.State
import Control.Monad.Except
import Data.List (intersect, elemIndex)
-- | A goal maps one solver state to a stream of resulting solver states.
type Goal = SolverState -> Stream SolverState
-- | Terms handled by the unifier.
data Tm
  = Lam Tm
    -- ^ Lambda abstraction; the binder carries no name.
  | Var Int [Tm]
    -- ^ A variable, identified by an integer index, applied to arguments.
  | Meta Int Ren
    -- ^ A metavariable, identified by an integer id, applied to a renaming.
  deriving (Show, Eq)
-- | A renaming: the list of variable indices a metavariable is applied to.
type Ren = [Int]
-- | Types. The only constructor carries a list of argument types; 'lams'
-- builds one 'Lam' per entry of that list.
data Ty = Iot [Ty] deriving (Show, Eq)
-- | A context as a list of types.
type Ctx = [Ty]
-- | Metavariable context: association list from metavariable id to its type.
type MCtx = [(Int,Ty)]
-- | Substitution: association list from metavariable id to its solution.
type Subs = [(Int,Tm)]
-- | Solver state: metavariable context, substitution, and the id that
-- 'fresh'' hands out next.
type SolverState = (MCtx, Subs, Int)
-- | Reasons why unification of two terms is abandoned.
data Blocker
  = Occurs           -- ^ thrown by 'assign' when the occurs check fails
  | InversionFailed  -- ^ thrown by 'replace_var' when a variable is not covered by the renaming
  | Different        -- ^ thrown by 'unify' when no unification rule applies
  deriving (Show)
-- | The unification monad: state of type 'SolverState' layered over
-- exceptions of type 'Blocker'.
newtype Solver a = Solver (StateT SolverState (Except Blocker) a)
-- | Run a solver action from the given state, returning either the blocker
-- that stopped it or the result paired with the final state.
runSolver :: SolverState -> Solver a -> Either Blocker (a, SolverState)
runSolver initial (Solver action) = runExcept $ runStateT action initial
-- | Like 'runSolver', but keep only the final state and drop the result.
execSolver :: SolverState -> Solver a -> Either Blocker SolverState
execSolver initial (Solver action) = runExcept $ execStateT action initial
-- Class instances for 'Solver', all derived from the wrapped transformer
-- stack.
-- NOTE(review): no LANGUAGE pragma is visible in this file; these presumably
-- rely on StandaloneDeriving and GeneralizedNewtypeDeriving being enabled
-- elsewhere (e.g. in the build configuration) -- confirm.
deriving instance Functor Solver
deriving instance Applicative Solver
deriving instance Monad Solver
deriving instance MonadState SolverState Solver
deriving instance MonadError Blocker Solver
-- | The type with no arguments.
iot :: Ty
iot = Iot []
-- | Resolve every solved metavariable in a term using the substitution of
-- the given state. A solved metavariable is replaced by its solution applied
-- (via 'mapp') to the metavariable's renaming, and the result is resolved
-- again; unsolved metavariables are left untouched.
walk :: Tm -> SolverState -> Tm
walk u (_, s, _) = go u
  where
    go :: Tm -> Tm
    go (Lam body) = Lam (go body)
    go (Var x args) = Var x (map go args)
    go tm@(Meta i e) = maybe tm (\z -> go (mapp z e)) (lookup i s)
-- | Apply a term to a renaming: strip one leading 'Lam' per entry of the
-- renaming and replace the variable bound by that lambda with the entry.
--
-- The first entry of the renaming corresponds to the outermost lambda.
-- All lambdas are stripped first and the body is then renamed in a single
-- simultaneous pass. Renaming one lambda at a time is wrong: an index that
-- was substituted (and shifted) under a still-present lambda is never
-- shifted back when that lambda is stripped later, and it can also collide
-- with an index that is still bound.
--
-- Calls 'error' when the renaming has more entries than the term has
-- leading lambdas.
mapp :: Tm -> [Int] -> Tm
mapp u [] = u
mapp u p = strip [] u p
  where
    -- env collects the entries of the stripped lambdas, innermost first, so
    -- that position j of env is the replacement for index j of the body.
    strip :: [Int] -> Tm -> [Int] -> Tm
    strip env t [] = rename env 0 t
    strip env (Lam t) (x : xs) = strip (x : env) t xs
    strip _ _ _ = error "mapp: renaming is longer than the term's lambda prefix"

    -- d counts the binders crossed inside the body; indices below d are
    -- bound inside the body and must not be touched.
    rename :: [Int] -> Int -> Tm -> Tm
    rename env d (Lam t) = Lam (rename env (d + 1) t)
    rename env d (Var i e) = Var (target env d i) (fmap (rename env d) e)
    rename env d (Meta k q) = Meta k (fmap (target env d) q)

    target :: [Int] -> Int -> Int -> Int
    target env d i
      | i < d = i
      | otherwise = case drop (i - d) env of
          (y : _) -> y + d
          -- Index refers past the stripped lambdas: account for the
          -- binders that were removed.
          [] -> i - length env
-- | Goal that unifies two terms. Both terms are walked in the incoming
-- state before 'unify' runs; a blocked unification yields the empty stream,
-- a successful one yields the single updated state.
(===) :: Tm -> Tm -> Goal
(t1 === t2) st = either (const Nil) pure outcome
  where
    outcome = execSolver st (unify (walk t1 st) (walk t2 st))
-- | Allocate a new metavariable of the given type. Returns its id together
-- with the state in which the id is recorded in the metavariable context
-- and the id counter is advanced by one.
fresh' :: Ty -> SolverState -> (Int, SolverState)
fresh' ty (metas, subs, next) = (next, grown)
  where
    grown = ((next, ty) : metas, subs, next + 1)
-- | Look up the type of a metavariable in the metavariable context of the
-- state.
--
-- Calls 'error' naming the offending id when the metavariable is not
-- present in the context.
mlookup :: Int -> SolverState -> Ty
mlookup k (m, _, _) = case lookup k m of
  Just ty -> ty
  Nothing -> error ("mlookup: unknown metavariable " ++ show k)
-- | Record a solution for metavariable @k@ by pushing it onto the front of
-- the substitution; the other state components are unchanged.
extS :: Int -> Tm -> Solver ()
extS k u = modify $ \(metas, subs, next) -> (metas, (k, u) : subs, next)
-- | Wrap a term in one 'Lam' per element of the list. Only the length of
-- the list matters; the types themselves are ignored.
lams :: [Ty] -> Tm -> Tm
lams tys u = foldr (const Lam) u tys
-- | Unify two terms, recording metavariable solutions in the solver state.
--
-- '===' walks both terms before calling this function. Arguments of
-- matching 'Var' heads are walked again before each recursive call because
-- earlier arguments may have solved metavariables in the meantime.
--
-- Throws 'Different' when no rule applies; 'Occurs' and 'InversionFailed'
-- can be thrown through 'assign'.
unify :: Tm -> Tm -> Solver ()
unify (Lam t) (Lam u) = unify t u
unify (Var x e) (Var y e')
  | x == y && length e == length e' =
      forM_ (zip e e') $ \(a, b) -> do
        a' <- gets (walk a)
        b' <- gets (walk b)
        unify a' b'
unify (Meta k p) (Meta k' p')
  | k == k' = do
      -- Same metavariable under two renamings: keep only the argument
      -- positions on which both renamings agree.
      Iot tys <- gets (mlookup k)
      let mi = length tys - 1
          kept =
            [ (i, t)
            | (i, t, True) <- zip3 (reverse [0 .. mi]) tys (zipWith (==) p p')
            ]
      -- Introduce a new metavariable that sees only the kept positions.
      j <- state (fresh' (Iot (fmap snd kept)))
      extS k (lams tys (Meta j (fmap fst kept)))
  | otherwise = do
      -- Two different metavariables: both are solved by one new
      -- metavariable applied to the variables visible in both renamings.
      Iot tys <- gets (mlookup k)
      -- The solution of k' must bind as many lambdas as k' takes
      -- arguments, so its own type is needed; the arity of k may differ.
      Iot tys' <- gets (mlookup k')
      let sect = intersect p p'
          -- Position of a shared variable inside a renaming. Every element
          -- of sect occurs in both renamings, so the lookup cannot fail.
          posIn r i =
            maybe (error "unify: variable missing from renaming") id
                  (elemIndex i r)
          tyvec i = tys !! posIn p i
          -- Index, under the lambdas built by 'lams', of the argument
          -- standing at that position of the renaming.
          vec r i = length r - posIn r i - 1
      j <- state (fresh' (Iot (fmap tyvec sect)))
      extS k (lams tys (Meta j (fmap (vec p) sect)))
      extS k' (lams tys' (Meta j (fmap (vec p') sect)))
unify (Meta k p) e = assign k p e
unify e (Meta k p) = assign k p e
unify _ _ = throwError Different
-- | Solve metavariable @k@, applied to renaming @p@, with term @u@.
--
-- Throws 'Occurs' when @k@ occurs in @u@. Otherwise the renaming is
-- inverted and applied to @u@ with 'replace' (which may throw
-- 'InversionFailed'), and the result, wrapped in lambdas by 'lams', is
-- recorded as the solution of @k@.
assign :: Int -> Ren -> Tm -> Solver ()
assign k p u = do
  Iot tys <- gets (mlookup k)
  when (occurs k u) $ throwError Occurs
  -- Pair every variable of the renaming with the index its argument has
  -- underneath the lambdas: the first entry gets the highest index.
  let arity = length p
      inverse = zip p [arity - 1, arity - 2 .. 0]
  replace inverse 0 u >>= extS k . lams tys
-- | Translate a single variable index through the inverted renaming @m@ at
-- binder depth @d@. Indices below @d@ are returned unchanged; any other
-- index is looked up (after subtracting @d@) and the result is shifted back
-- by @d@. Throws 'InversionFailed' when the lookup finds nothing.
replace_var :: [(Int,Int)] -> Int -> Int -> Solver Int
replace_var m d i
  | i < d = pure i
  | otherwise =
      maybe (throwError InversionFailed) (pure . (+ d)) (lookup (i - d) m)
-- | Apply the inverted renaming @m@ to a whole term; @d@ is the number of
-- binders crossed so far.
--
-- For a metavariable whose renaming cannot be translated directly, the
-- metavariable is first pruned to the variables that can be translated,
-- then resolved again with 'walk', and the translation is retried on the
-- result.
replace :: [(Int,Int)] -> Int -> Tm -> Solver Tm
replace m d (Lam body) = Lam <$> replace m (d + 1) body
replace m d (Var x args) = do
  x' <- replace_var m d x
  args' <- forM args (replace m d)
  pure (Var x' args')
replace m d tm@(Meta k p) = catchError direct (const viaPrune)
  where
    direct = Meta k <$> forM p (replace_var m d)
    viaPrune = do
      -- Indices allowed to stay: those below d, plus the variables of the
      -- renaming shifted by d.
      prune ([0 .. d - 1] <> fmap ((+ d) . fst) m) k p
      gets (walk tm) >>= replace m d
-- | Prune metavariable @k@ (applied to renaming @p@) down to the variables
-- listed in @m@: a new metavariable receiving only the arguments whose
-- variables occur in both @m@ and @p@ is created, and @k@ is solved with it.
prune :: [Int] -> Int -> Ren -> Solver ()
prune m k p = do
  Iot tys <- gets (mlookup k)
  let kept = m `intersect` p
      -- For every kept variable: the type at its position in p, and the
      -- index that position has underneath the lambdas built by 'lams'.
      slots =
        [ (tys !! q, length p - q - 1)
        | i <- kept
        , Just q <- [elemIndex i p]
        ]
  -- Introduce the new metavariable over the kept arguments only.
  j <- state (fresh' (Iot (fmap fst slots)))
  extS k (lams tys (Meta j (fmap snd slots)))
-- | Does metavariable @i@ occur anywhere in the term?
--
-- Only metavariable ids are compared; the renaming attached to a
-- metavariable holds variable indices and is not inspected.
occurs :: Int -> Tm -> Bool
occurs i (Meta j _) = i == j
-- 'any' stops at the first argument that contains the metavariable.
occurs i (Var _ e) = any (occurs i) e
occurs i (Lam u) = occurs i u
-- | Allocate a new metavariable of the given type and run the goal built by
-- the continuation in the extended state. The continuation receives the
-- metavariable as a function awaiting its renaming.
fresh :: Ty -> ((Ren -> Tm) -> Goal) -> Goal
fresh ty f st = f (Meta newId) extended
  where
    (newId, extended) = fresh' ty st
-- | Disjunction: run both goals on the same state and combine their result
-- streams with 'mplus'.
disj :: Goal -> Goal -> Goal
disj g1 g2 = \st -> mplus (g1 st) (g2 st)
-- | Conjunction: run the first goal, then run the second goal on every
-- state the first one produces.
conj :: Goal -> Goal -> Goal
conj g1 g2 = \st -> g2 =<< g1 st
-- | Result streams produced by goals.
data Stream a
  = Nil
    -- ^ The empty stream.
  | Cons a (Stream a)
    -- ^ One result followed by the rest of the stream.
  | Delayed (Stream a)
    -- ^ A marker wrapped around the rest of the stream; it carries no result.
  deriving (Eq, Show)
-- | Bind feeds every element to the continuation and merges the resulting
-- streams with 'mplus'; 'Delayed' markers are kept in place.
instance Monad Stream where
  stream >>= f = case stream of
    Nil -> Nil
    Cons x rest -> f x `mplus` (rest >>= f)
    Delayed rest -> Delayed (rest >>= f)
-- | 'mzero' and 'mplus' are the class defaults, i.e. 'empty' and '<|>' of
-- the 'Alternative' instance.
instance MonadPlus Stream
-- | Merging two streams interleaves them: after every element or 'Delayed'
-- marker taken from the left stream, the two streams swap sides.
instance Alternative Stream where
  empty = Nil
  left <|> right = case left of
    Nil -> right
    Cons x rest -> Cons x (right <|> rest)
    Delayed rest -> Delayed (right <|> rest)
-- | Map over the elements, leaving the shape of the stream (including
-- 'Delayed' markers) unchanged.
instance Functor Stream where
  fmap f stream = case stream of
    Nil -> Nil
    Cons a rest -> Cons (f a) (fmap f rest)
    Delayed rest -> Delayed (fmap f rest)
-- | 'pure' is the one-element stream. Application maps each function over
-- the argument stream and merges the results with '<|>'.
instance Applicative Stream where
  pure a = Cons a Nil
  funs <*> args = case (funs, args) of
    (Nil, _) -> Nil
    (_, Nil) -> Nil
    (Cons f rest, _) -> fmap f args <|> (rest <*> args)
    (Delayed rest, _) -> Delayed (rest <*> args)
-- | The goal that never succeeds: it yields the empty stream for any state.
failure :: Goal
failure = const Nil
-- | Wrap the result stream of a goal in a 'Delayed' marker.
delay :: Goal -> Goal
delay goal st = Delayed (goal st)
-- | The starting state: no metavariables, no solutions, and 0 as the first
-- metavariable id.
initialState :: SolverState
initialState = (noMetas, noSolutions, firstId)
  where
    noMetas = []
    noSolutions = []
    firstId = 0
-- | Collect up to @n@ results from the front of a stream into a list.
-- 'Delayed' markers are skipped and do not count towards @n@. When @n@ is
-- 0 the stream is not inspected at all.
takeS :: Int -> Stream a -> [a]
takeS 0 _ = []
takeS n stream = case stream of
  Nil -> []
  Delayed rest -> takeS n rest
  Cons a rest -> a : takeS (n - 1) rest
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment