-- Created March 25, 2016 07:19
-- Wadler's classic pattern-matching algorithm, implemented for a core language using the Bound library.
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module PatCompile where
import Bound
import Bound.Var
import Bound.Scope
import Control.Monad (ap)
import Prelude.Extras (Show1 (..), Eq1 (..))
import Data.List
import Debug.Trace
-- | Identifier of a data constructor, represented as a plain 'Int' tag.
newtype Con = MkCon Int deriving (Eq, Show, Ord, Enum)
-- | Constants of the core language: literals, constructor references and
-- the built-in names listed below.
data Constant
  = IntLit Int   -- ^ Integer literal.
  | CharLit Char -- ^ Character literal.
  | Con Con      -- ^ Reference to a data constructor.
  -- Non-pure data constants
  | TrueLit
  | FalseLit
  | If
  | EqInt
  | EqChar
  | Plus
  | Minus
  | Times
  | Div
  | RaiseError
  | MatchFail
  deriving (Eq, Show)
-- | Source-level patterns; these may be nested arbitrarily.
data Pattern
  = PVar Int           -- ^ Variable pattern, identified by an index.
  | PCon Con [Pattern] -- ^ Constructor applied to sub-patterns.
  | PLit Constant      -- ^ Literal pattern.
  deriving (Eq, Show)
-- | Expressions of the core language, parameterised over the type @a@ of
-- free variables. Binding forms keep their bodies in a 'Scope'.
data Exp a
  = Var a
    -- ^ A free variable.
  | Const Constant
    -- ^ A constant.
  | App (Exp a) (Exp a)
    -- ^ Application of the first expression to the second.
  | Let Pattern (Exp a) (Scope Int Exp a)
    -- ^ Pattern, bound expression, and a body whose bound variables are
    -- indexed by 'Int'.
  | LetRec [Bind a] (Scope RecBV Exp a)
    -- ^ A list of bindings and a body, both scoped over 'RecBV'.
  | Lambda Pattern (Scope Int Exp a)
    -- ^ Abstraction over a pattern.
  | Bar (Exp a) (Exp a)
    -- ^ Combination of two expressions; 'match' chains alternatives with
    -- it, ending in the default.
    -- NOTE(review): presumably Wadler's "fat bar" (try left, fall back to
    -- right on match failure) - confirm.
  | Case (Exp a) [Alt a]
    -- ^ Invariant: Exhaustive, non-overlapping, "simple"
  deriving (Eq, Show, Functor, Foldable, Traversable)
-- | Name of a variable bound inside a 'LetRec' (the bound-variable type of
-- the scopes in 'LetRec' and 'Bind').
-- NOTE(review): presumably 'patternNum' picks the binding and 'varNum' the
-- variable within that binding's pattern - confirm.
data RecBV = RecBV {patternNum :: Int, varNum :: Int} deriving (Eq, Show)
-- | One binding of a 'LetRec': a pattern together with its right-hand
-- side, which is scoped over 'RecBV' variables.
data Bind a
  = Bind Pattern (Scope RecBV Exp a)
  deriving (Eq, Show, Functor, Foldable, Traversable)
-- | One alternative of a 'Case': a pattern and the body it guards. The
-- body's bound variables are indexed by 'Int'.
data Alt a
  = Alt Pattern (Scope Int Exp a)
  deriving (Eq, Show, Functor, Foldable, Traversable)
-- Lifted 'Show'/'Eq' instances for 'Exp'. No methods are given here, so
-- both instances rely entirely on whatever defaults the classes provide.
-- NOTE(review): presumably needed so the derived instances can compare and
-- show the 'Scope' fields - confirm against the bound version in use.
instance Show1 Exp where
instance Eq1 Exp where
-- | 'pure' wraps a value as a variable; '<*>' is derived from the 'Monad'
-- instance.
instance Applicative Exp where
  -- Canonical direction: 'pure' is the primitive and 'return' defers to it.
  pure = Var
  (<*>) = ap

-- | Substitution: @e >>= f@ replaces every free variable @a@ of @e@ by
-- @f a@. Bodies held in a 'Scope' are handled with '>>>=', which applies
-- the substitution underneath the binder.
instance Monad Exp where
  return = pure

  Var a >>= f = f a
  Const c >>= _ = Const c
  App l r >>= f = App (l >>= f) (r >>= f)
  Let p rhs body >>= f = Let p (rhs >>= f) (body >>>= f)
  LetRec binds body >>= f =
    LetRec [Bind p (rhs >>>= f) | Bind p rhs <- binds] (body >>>= f)
  Lambda p body >>= f = Lambda p (body >>>= f)
  Bar l r >>= f = Bar (l >>= f) (r >>= f)
  Case scrut alts >>= f =
    Case (scrut >>= f) [Alt p (body >>>= f) | Alt p body <- alts]
-- | Build the occurrence of a free variable as it appears inside a scope
-- body: the variable wrapped in 'Var', then 'F', then 'Var' again.
mkVar :: a -> Exp (Var b (Exp a))
mkVar x = Var (F (Var x))
-- | Replace one bound variable of a scope by a free variable.
--
-- Every occurrence of the bound variable @b@ becomes the free variable
-- @new@; all other bound variables stay bound and free variables are
-- kept as they are.
fillBound :: Eq b => b -> a -> Scope b Exp a -> Scope b Exp a
fillBound b new s = Scope (splat mkVar rebind s)
  where
    -- Decide what a single bound occurrence turns into.
    rebind b'
      | b == b' = mkVar new
      | otherwise = Var (B b')
-- | Drop the 'Scope' wrapper from a body that no longer mentions any bound
-- variable. Calls 'error' with the shown scope if a bound variable is
-- still present.
stripScope :: (Show a, Show b) => Scope b Exp a -> Exp a
stripScope s =
  maybe (error ("Not closed: " ++ show s)) id $
    -- Free variables become 'Just', bound ones 'Nothing'; sequencing then
    -- yields 'Nothing' as soon as a single bound variable survives.
    sequenceA (splat (Var . Just) (const (Var Nothing)) s)
-- | A row of the pattern matrix handled by 'match': one pattern per
-- scrutinee, plus the body to run when all of them match. Bound variable
-- @i@ of the body stands for the pattern variable @PVar i@.
data FlexibleAlt a
  = FAlt [Pattern] (Scope Int Exp a)
  deriving (Functor, Foldable, Traversable, Show)
-- | Compile a parallel match into nested, simple 'Case' expressions.
--
-- @match scruts alts def@ matches the variables @scruts@, left to right,
-- against the rows @alts@. Every row must carry exactly one pattern per
-- scrutinee (the pattern matrix is a rectangle); 'error' is raised
-- otherwise. @def@ is the expression to fall through to when no row
-- applies.
--
-- The first column decides which rule fires:
--
-- * No scrutinees left: the row bodies are closed with 'stripScope' and
--   chained with 'Bar', ending in @def@.
-- * All variables: each pattern variable is replaced by the scrutinee
--   and matching continues with the remaining columns.
-- * All constructors: rows are gathered per constructor and one 'Alt' is
--   emitted for each, binding fresh variables for the constructor
--   arguments and matching the sub-patterns against them.
-- * All literals: as for constructors, but without arguments.
-- * Mixture: rows are split into runs of the same kind, and each run
--   uses the compilation of the following runs as its default.
--
-- NOTE(review): the emitted 'Case' only lists constructors/literals that
-- occur in the rows; @def@ is threaded into the branches but there is no
-- catch-all alternative, so exhaustiveness relies on the input.
match :: (Eq a, Show a) => [a] -> [FlexibleAlt a] -> Exp a -> Exp a
match [] alts def = foldr (Bar . closeRow) def alts
  where
    closeRow (FAlt [] body) = stripScope body
    closeRow (FAlt _ _) =
      error "match: a row has more patterns than there are scrutinees"
match scruts@(scrut : remaining) alts def
  -- All variable branches
  | Just rows <- traverse varRow alts =
      match remaining [FAlt ps (fillBound i scrut s) | (i, ps, s) <- rows] def
  -- First scrutinee is matched against constructors only
  | Just rows <- traverse conRow alts =
      Case (Var scrut) (map conAlt (gatherBy rows))
  -- A degenerate version of the above where we're matching on literals
  | Just rows <- traverse litRow alts =
      Case (Var scrut) (map litAlt (gatherBy rows))
  -- A row without patterns can never be consumed by any rule; recursing
  -- on it through the mixture rule would not terminate.
  | any (\(FAlt ps _) -> null ps) alts =
      error "match: a row has fewer patterns than there are scrutinees"
  -- Mixture: split into chunks of one kind and process them separately.
  | otherwise = foldr (match scruts) def (splitChunks alts)
  where
    -- Classify a row by the head of its first pattern.
    varRow (FAlt (PVar i : ps) s) = Just (i, ps, s)
    varRow _ = Nothing
    conRow (FAlt (PCon c args : ps) s) = Just (c, (args, ps, s))
    conRow _ = Nothing
    litRow (FAlt (PLit l : ps) s) = Just (l, (ps, s))
    litRow _ = Nothing

    -- Build the alternative for one constructor. The arity is taken from
    -- this constructor's own rows, since constructors differ in arity.
    conAlt (c, rows) =
      let arity = case rows of
            (args, _, _) : _ -> length args
            [] -> 0 -- unreachable: gatherBy never yields an empty group
          fresh = [0 .. arity - 1]
      in Alt (PCon c (map PVar fresh)) . Scope $
           match
             -- Constructor arguments are the bound variables of the new
             -- alternative; everything else is lifted under the binder.
             (map B fresh ++ map (F . Var) remaining)
             [FAlt (args ++ ps) (F . Var <$> s) | (args, ps, s) <- rows]
             (F . Var <$> def)

    -- Build the alternative for one literal; it binds no variables.
    litAlt (l, rows) =
      Alt (PLit l) . abstract (const Nothing) $
        match remaining [FAlt ps s | (ps, s) <- rows] def

    -- Stable grouping by key, merging non-adjacent rows as well so that
    -- every key yields exactly one alternative. Rows headed by different
    -- constructors (or literals) cannot match the same value, so moving
    -- them past each other is safe; order within a group is preserved.
    gatherBy [] = []
    gatherBy ((k, v) : rest) =
      let (same, other) = partition ((== k) . fst) rest
      in (k, v : map snd same) : gatherBy other

    -- Maximal runs of adjacent rows whose first pattern is of one kind.
    splitChunks = groupBy sameKind
    sameKind (FAlt (PVar _ : _) _) (FAlt (PVar _ : _) _) = True
    sameKind (FAlt (PCon _ _ : _) _) (FAlt (PCon _ _ : _) _) = True
    sameKind (FAlt (PLit _ : _) _) (FAlt (PLit _ : _) _) = True
    sameKind _ _ = False
-- | Bind the named variables of an expression: a variable found in @vars@
-- becomes the bound variable numbered by its position in that list, and
-- all other variables stay free.
abstractF :: [String] -> Exp String -> Scope Int Exp String
abstractF vars = abstract (`elemIndex` vars)
-- | Apply an expression to a list of arguments, nesting to the left:
-- @app f [x, y]@ is @App (App f x) y@.
app :: Exp a -> [Exp a] -> Exp a
app fn args = foldl App fn args
-- | The expression that refers to a data constructor.
con :: Con -> Exp a
con c = Const (Con c)
-- | Lets integer literals be written where a 'Con' is expected. Only
-- 'fromInteger' is defined; no other 'Num' method is implemented.
instance Num Con where
  fromInteger n = MkCon (fromIntegral n)
-- | Sample input for 'match': two scrutinees and three rows, the first
-- two each pairing constructor 0 with a variable and the third matching
-- constructor 1 on both sides. @Const MatchFail@ is the default.
test :: Exp String
test = match ["hello", "world"] [leftEmpty, rightEmpty, bothCons] (Const MatchFail)
  where
    leftEmpty = FAlt [PCon 0 [], PVar 0] (abstractF ["x"] (Var "x"))
    rightEmpty = FAlt [PVar 0, PCon 0 []] (abstractF ["x"] (Var "x"))
    bothCons =
      FAlt
        [PCon 1 [PVar 0, PVar 1], PCon 1 [PVar 2, PVar 3]]
        (abstractF ["x", "xs", "y", "ys"] body)
    -- Constructor 1 applied to x, then to y, then to a call of "rec" on
    -- the two tails.
    body = cons (Var "x") (cons (Var "y") (app (Var "rec") [Var "xs", Var "ys"]))
    cons hd tl = app (con 1) [hd, tl]
