Skip to content

Instantly share code, notes, and snippets.

@gatlin
Last active August 20, 2022 01:27
Show Gist options
  • Star 7 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save gatlin/c56a12abc386fc9df498ec2d7292a9a5 to your computer and use it in GitHub Desktop.
Save gatlin/c56a12abc386fc9df498ec2d7292a9a5 to your computer and use it in GitHub Desktop.
{-|
- Example of using free constructions to build a flexible little compiler.
-
- The goal here is not necessarily efficiency but readability and flexibility.
-
- The language grammar is represented by an ADT; however, instead of
- referring to itself recursively, it references a type variable.
-
- We derive instances of 'Functor' and 'Traversable' for this type.
-
- A free monad over this grammar allows us to build expressions legibly and
- simply during the parsing stage.
-
- Then we can transform this into a cofree comonad over the same grammar type
- which, along with the tools from the @comonad@ package, allows us to annotate
- our expressions easily.
-
- At the bottom, the beginnings of a simple type inference scheme are
- sketched out to demonstrate the use of this technique.
-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TemplateHaskell#-}
{- cabal:
build-depends: base, comonad, free, text, attoparsec, mtl, containers, deriving-compat
-}
module Main where
import Control.Monad (join, (>=>), forM, foldM, mapAndUnzipM)
import Control.Monad.Free
import Control.Monad.State
import Control.Monad.Except
import Control.Comonad
import Control.Comonad.Cofree
import Control.Applicative
import Data.Functor.Identity
import Data.Traversable
import Data.Monoid ((<>))
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Attoparsec.Text
import Data.Map (Map)
import qualified Data.Map as M
import Text.Show.Deriving
import Data.Ord.Deriving
import Data.Eq.Deriving
-- * Core syntax representation
-- | Identifiers in the source language (variables and operators).
type Symbol = String

-- | One layer of the expression language. Sub-expressions are not 'CoreAst'
-- itself but the type parameter @a@, so the recursion is supplied from
-- outside: by 'Free' while parsing and by 'Cofree' while annotating.
data CoreAst a
  = IntC Integer     -- ^ integer literal
  | FloatC Double    -- ^ floating-point literal
  | BoolC Bool       -- ^ boolean literal, written @#t@ / @#f@
  | SymC Symbol      -- ^ reference to a variable or operator
  | AppC a [a]       -- ^ operator applied to its operands
  | FunC [Symbol] a  -- ^ lambda: parameter names and body
  | IfC a a a        -- ^ condition, then-branch, else-branch
  deriving (Show, Eq, Ord, Functor, Foldable, Traversable)
-- ** Convenience constructors for our expression language
-- | Smart constructors. Each one lifts a single 'CoreAst' layer into any
-- @MonadFree CoreAst@ monad, so the parsers can build trees without
-- spelling out 'Free' constructors.
cInt :: (MonadFree CoreAst m) => Integer -> m a
cInt = liftF . IntC

cFloat :: (MonadFree CoreAst m) => Double -> m a
cFloat = liftF . FloatC

cBool :: (MonadFree CoreAst m) => Bool -> m a
cBool = liftF . BoolC

cSym :: (MonadFree CoreAst m) => Symbol -> m a
cSym = liftF . SymC

-- Nodes with children receive sub-trees that are already built, so exactly
-- one 'wrap' layer goes on top of them.
cApp :: (MonadFree CoreAst m) => m a -> [m a] -> m a
cApp f operands = wrap (AppC f operands)

cFun :: (MonadFree CoreAst m) => [Symbol] -> m a -> m a
cFun params body = wrap (FunC params body)

cIf :: (MonadFree CoreAst m) => m a -> m a -> m a -> m a
cIf c t e = wrap (IfC c t e)

-- | Parsed expressions: 'CoreAst' trees built with the free monad.
type CoreExpr = Free CoreAst
-- * Parsing.
-- | Numeric literals. A run of digits is an integer. A run of digits, a dot,
-- and another run of digits is a double. A dot that is not followed by
-- digits makes this parser fail.
num_parser :: Parser (CoreExpr a)
num_parser = do
  intDigits <- many1 digit
  next <- peekChar
  if next == Just '.'
    then do
      _ <- char '.'
      fracDigits <- many1 digit
      return (cFloat (read (intDigits ++ '.' : fracDigits)))
    else return (cInt (read intDigits))
-- | Punctuation allowed in a 'Symbol', in addition to letters (and, after
-- the first character, digits).
symchars :: String
symchars = "=<>.!@#$%^&*{}[]+-/\\"

-- | Accept one character drawn from 'symchars'.
symchar :: Parser Char
symchar = satisfy (`elem` symchars)

-- | A symbol token: a letter or symbol character, followed by any run of
-- letters, digits and symbol characters.
sym :: Parser String
sym = (:) <$> (letter <|> symchar) <*> many' (letter <|> digit <|> symchar)
-- | Parse a symbol token into a 'SymC' node.
sym_parser :: Parser (CoreExpr a)
sym_parser = cSym <$> sym
-- | Parse boolean literals, written @#t@ and @#f@.
--
-- Each alternative yields its own value directly. This avoids a
-- non-exhaustive @case@ on the matched character.
bool_parser :: Parser (CoreExpr a)
bool_parser =
  char '#' *> (cBool True <$ char 't' <|> cBool False <$ char 'f')
-- | Parse operator application: an operator expression followed by zero or
-- more operand expressions, separated by optional whitespace.
--
-- The operator is bound on its own rather than pattern-matched out of a
-- @sepBy1@ result. This removes the failable @(op:erands)@ do-pattern and
-- its dependence on 'MonadFail'. Accepted input is the same as
-- @expr_parser `sepBy1` many space@.
app_parser :: Parser (CoreExpr a)
app_parser = do
  op <- expr_parser
  erands <- many (many space *> expr_parser)
  return $ cApp op erands
-- | Match the backslash that introduces a lambda.
--
-- NOTE(review): not referenced by the other parsers in this module;
-- 'fun_parser' matches the backslash with 'char' itself.
lam_parser :: Parser Text
lam_parser = string "\\"
-- | Parse a lambda abstraction: a backslash, a parenthesised list of
-- parameter names, then the body expression.
fun_parser :: Parser (CoreExpr a)
fun_parser = do
  _ <- char '\\'
  skipSpace
  params <- parens (sym `sepBy` many space)
  skipSpace
  cFun params <$> expr_parser
-- | Parse the inside of an if-expression: the keyword @if@, then the
-- condition, the then-branch and the else-branch.
--
-- The keyword is read as a complete symbol token and compared with @if@.
-- Matching only the two-character prefix would let an application such as
-- @(iff a b)@ be mis-read as @(if f a b)@.
if_parser :: Parser (CoreExpr a)
if_parser = do
  keyword <- sym
  if keyword == "if" then pure () else empty
  skipSpace
  c <- expr_parser
  skipSpace
  t <- expr_parser
  skipSpace
  e <- expr_parser
  return $ cIf c t e
-- | Run a parser between parentheses, skipping whitespace just inside them,
-- and return the inner parser's result.
parens :: Parser a -> Parser a
parens p = char '(' *> skipSpace *> p <* skipSpace <* char ')'
-- | Parse any expression. Alternatives are tried in order, and the order
-- matters:
--
-- * The parenthesised lambda and if forms come before generic application.
--   Backslash is a symbol character and @if@ is a valid symbol, so
--   'app_parser' would otherwise accept both as ordinary applications.
--
-- * Booleans come before symbols because @#@ is also a symbol character.
expr_parser :: Parser (CoreExpr a)
expr_parser =
      parens fun_parser
  <|> parens if_parser
  <|> bool_parser
  <|> sym_parser
  <|> num_parser
  <|> parens app_parser
-- Parsing @Text@ into the core expression language is done by 'parse_expr',
-- which runs in the 'Compiler' monad defined below.
-- * Annotated expressions
-- | A representation of our core syntax tree which annotates every node with a
-- value of some arbitrary, user-defined type.
--
-- 'annotate' builds one of these with @()@ at every node, and
-- 'generateConstraints' replaces each label with that node's inferred type
-- and 'TypeResult'.
type AnnotatedExpr = Cofree CoreAst
-- * EXAMPLE: Shitty, incomplete type inference
-- This is not a full type inferencing step because I am lazy. However, it is
-- an incomplete beginning sketch of the ideas laid out in:
--
-- https://brianmckenna.org/blog/type_annotation_cofree
-- | Types of the toy language.
data Type
  = TLambda [Type]
    -- ^ function type: parameter types followed by the result type.
    -- "What if this list is empty?" Congrats, we support Void!
  | TVar Int
    -- ^ type variable, identified by a unique number
  | TInt
  | TFloat
  | TBool
  deriving Show
-- | A requirement that two types be equal.
data Constraint = EqualityConstraint Type Type
  deriving Show

-- | What inference learns about an expression: the equality constraints it
-- generates, and the types assumed for each symbol that is still free in it.
data TypeResult = TypeResult
  { constraints :: [Constraint]
  , assumptions :: Map String [Type]
  }
  deriving Show
-- | Combine two results by concatenating their constraints and taking the
-- union of their assumptions.
--
-- NOTE(review): '<>' on 'Map' is a left-biased union. If both sides assume a
-- type for the same symbol, only the left list is kept. This appears
-- harmless while children are inferred through 'memoizedTC': every
-- occurrence of a symbol is then the same memo key and shares one type
-- variable. Switch to @M.unionWith (<>)@ if that ever stops being true.
instance Semigroup TypeResult where
  a <> b = TypeResult
    { constraints = constraints a <> constraints b
    , assumptions = assumptions a <> assumptions b
    }

-- | The empty result. 'mappend' is left at its default, which is '(<>)'.
-- A hand-written copy is non-canonical (GHC warns about it), could drift
-- from '<>', and will stop compiling once 'mappend' leaves the class.
instance Monoid TypeResult where
  mempty = TypeResult mempty mempty
-- Template Haskell: derive the lifted classes 'Eq1', 'Show1' and 'Ord1' for
-- 'CoreAst'. The 'Eq', 'Ord' and 'Show' instances of 'Cofree' and 'Free' are
-- stated in terms of these lifted classes. They are needed here because the
-- memo table is keyed on @AnnotatedExpr ()@ and 'main' shows an annotated
-- tree.
-- NOTE(review): top-level splices split the module into declaration groups,
-- so these must stay after the 'CoreAst' declaration.
$(deriveEq1 ''CoreAst)
$(deriveShow1 ''CoreAst)
$(deriveOrd1 ''CoreAst)
-- * The Compiler monad
-- | Everything that can go wrong while compiling.
data CompilerError
  = ParserError String  -- ^ the parser failed; carries its error message
  | InternalError       -- ^ malformed tree, e.g. a 'Pure' leaf reached 'annotate'
  | TypeError           -- ^ two types could not be unified
  deriving (Eq, Show)
-- | State threaded through compilation. This foreshadows the type inference
-- use-case further down.
data CompilerState = CompilerState
  { typeVarId :: Int
    -- ^ the next unused type-variable number, consumed by 'freshTypeVar'
  , memo :: Map (AnnotatedExpr ()) (Type, TypeResult)
    -- ^ inference results cached per sub-tree, maintained by 'memoizedTC'
  }

-- | Initial state: type variables start at 0 and nothing is memoized yet.
defaultCompilerState :: CompilerState
defaultCompilerState = CompilerState 0 M.empty
-- | The compiler monad: 'CompilerState' threaded over computations that may
-- abort with a 'CompilerError'. Remember: failure is always an option.
type Compiler = StateT CompilerState (Except CompilerError)

-- | Run a compilation from the given initial state, keeping only the result.
runCompiler :: CompilerState -> Compiler a -> Either CompilerError a
runCompiler initial = runExcept . (`evalStateT` initial)
-- | Parse source text into a 'CoreExpr'. Parser failures are reported as
-- 'ParserError'.
--
-- NOTE(review): 'parseOnly' does not require end of input, so any text
-- after the first complete expression is silently ignored.
parse_expr :: Text -> Compiler (CoreExpr ())
parse_expr = either (throwError . ParserError) return . parseOnly expr_parser
-- | Convert a parsed 'Free' tree into a 'Cofree' tree with @()@ at every node.
--
-- A well-formed syntax tree never uses the 'Pure' constructor, so meeting
-- one is reported as 'InternalError'. Why use 'Free' at all, then? Chiefly,
-- it comes in the same package as 'Cofree'. There may also be clever uses of
-- 'Pure' that have yet to be considered.
annotate :: Traversable f => Free f () -> Compiler (Cofree f ())
annotate tree = case tree of
  Pure _ -> throwError InternalError
  Free layer -> (() :<) <$> traverse annotate layer
-- | Hand out the next type variable and advance the 'typeVarId' counter.
freshTypeVar :: Compiler Type
freshTypeVar = state $ \s ->
  let n = typeVarId s
  in (TVar n, s { typeVarId = n + 1 })
-- | Run an inference function through the 'memo' table.
--
-- Sub-trees are compared structurally (they are 'Map' keys). On a hit the
-- cached result is returned without running the function. On a miss the
-- function runs and its result is stored for later lookups.
memoizedTC
  :: (AnnotatedExpr () -> Compiler (Type, TypeResult))
  -> AnnotatedExpr ()
  -> Compiler (Type, TypeResult)
memoizedTC f c = do
  cached <- gets (M.lookup c . memo)
  case cached of
    Just hit -> return hit
    Nothing -> do
      computed <- f c
      modify $ \s -> s { memo = M.insert c computed (memo s) }
      return computed
-- | Infer the type of a single node, together with the constraints and
-- assumptions collected from its children.
--
-- Children are inferred through 'memoizedTC', so each distinct sub-tree is
-- inferred once and keeps the same type variables wherever it is looked up.
infer :: AnnotatedExpr () -> Compiler (Type, TypeResult)
-- Literals have a fixed type and contribute nothing else.
infer (_ :< IntC _) = return (TInt, mempty)
infer (_ :< FloatC _) = return (TFloat, mempty)
infer (_ :< BoolC _) = return (TBool, mempty)
-- Symbols get a fresh type variable plus an assumption recording it. An
-- enclosing lambda later turns that assumption into a constraint.
infer (_ :< SymC s) = do
  typeVar <- freshTypeVar
  return
    ( typeVar
    , TypeResult { constraints = [], assumptions = M.singleton s [typeVar] }
    )
-- Lambda abstraction: each parameter gets a fresh type variable. Every
-- assumed type for that name in the body must equal it, and the name is no
-- longer free outside the lambda.
infer (_ :< FunC args body) = do
  argVars <- forM args (const freshTypeVar)
  (bodyType, bodyResult) <- memoizedTC infer body
  -- Parameters are discharged left to right. Constraints accumulate across
  -- parameters, so those from earlier parameters are kept rather than
  -- overwritten by later ones.
  (argConstraints, freeAssumptions) <-
    foldM bindArg ([], assumptions bodyResult) (zip args argVars)
  return
    ( TLambda (argVars ++ [bodyType])
    , TypeResult
        { constraints = constraints bodyResult <> argConstraints
        , assumptions = freeAssumptions
        }
    )
  where
    bindArg (acc, assumed) (arg, var) =
      return
        ( acc <> fmap (EqualityConstraint var) (M.findWithDefault [] arg assumed)
        , M.delete arg assumed
        )
-- Application: the operator's type must be a function from the operand
-- types to a fresh result variable.
infer (_ :< AppC op erands) = do
  resultVar <- freshTypeVar
  (opType, opResult) <- memoizedTC infer op
  (argTypes, argResults) <- mapAndUnzipM (memoizedTC infer) erands
  return
    ( resultVar
    , opResult <> mconcat argResults <> TypeResult
        { constraints =
            [EqualityConstraint opType (TLambda (argTypes ++ [resultVar]))]
        , assumptions = mempty
        }
    )
-- Conditionals: the condition must be a boolean and both branches must
-- have the same type, which is the type of the whole expression. Without
-- this equation any 'IfC' node was a runtime pattern-match failure.
infer (_ :< IfC cond thenE elseE) = do
  (condType, condResult) <- memoizedTC infer cond
  (thenType, thenResult) <- memoizedTC infer thenE
  (elseType, elseResult) <- memoizedTC infer elseE
  return
    ( thenType
    , condResult <> thenResult <> elseResult <> TypeResult
        { constraints =
            [ EqualityConstraint condType TBool
            , EqualityConstraint thenType elseType
            ]
        , assumptions = mempty
        }
    )
-- | Annotate every node of the tree with its inferred type and 'TypeResult'.
--
-- 'extend' applies the inference to every sub-tree, and 'sequenceA' runs the
-- resulting actions. Each node goes through 'memoizedTC', so a node's own
-- annotation is exactly the result its parent used when building
-- constraints. Calling 'infer' directly would re-infer each child and hand
-- out fresh type variables that no constraint mentions.
generateConstraints :: AnnotatedExpr () -> Compiler (AnnotatedExpr (Type, TypeResult))
generateConstraints = sequenceA . extend (memoizedTC infer)
-- | Solve constraints left to right, producing a map from type-variable
-- number to type.
--
-- Each constraint is rewritten with the substitution found so far and then
-- unified. New bindings take precedence over old ones when the maps are
-- merged. The substitution is threaded as a plain value through 'foldM', so
-- each constraint is unified exactly once. (Accumulating a monadic action
-- and re-running it on every step costs time exponential in the number of
-- constraints.)
solveConstraints :: [Constraint] -> Compiler (Map Int Type)
solveConstraints = foldM step M.empty
  where
    step subs (EqualityConstraint a b) = do
      newSubs <- mostGeneralUnifier (substitute subs a) (substitute subs b)
      return (newSubs <> subs)
-- | Unify two types, producing a mapping from type-variable number to type.
-- Throws 'TypeError' when the types cannot be made equal.
--
-- Function types are unified element-wise. The bindings found for earlier
-- elements are applied to the remaining ones before they are unified.
mostGeneralUnifier :: Type -> Type -> Compiler (Map Int Type)
mostGeneralUnifier lhs rhs =
  case (lhs, rhs) of
    (TVar i, t) -> bindVar i t
    (t, TVar i) -> bindVar i t
    (TInt, TInt) -> return mempty
    (TFloat, TFloat) -> return mempty
    (TBool, TBool) -> return mempty
    (TLambda [], TLambda []) -> return mempty
    (TLambda (x:xs), TLambda (y:ys)) -> do
      su1 <- mostGeneralUnifier x y
      su2 <- mostGeneralUnifier (TLambda $ fmap (substitute su1) xs)
                                (TLambda $ fmap (substitute su1) ys)
      return $ su2 <> su1
    _ -> throwError TypeError
  where
    -- A variable unified with itself carries no information. Recording
    -- @i -> TVar i@ would make 'substitute' chase it forever.
    bindVar i (TVar j) | i == j = return mempty
    -- Occurs check: binding @i@ to a type that contains @i@ is an infinite
    -- type, which 'substitute' could never finish expanding.
    bindVar i t
      | occurs i t = throwError TypeError
      | otherwise = return (M.singleton i t)

    occurs i (TVar j) = i == j
    occurs i (TLambda ts) = any (occurs i) ts
    occurs _ _ = False
-- | Apply a substitution to a type. Bound variables are followed
-- recursively until an unbound variable or a non-variable type is reached,
-- and function types are rewritten element-wise.
substitute :: Map Int Type -> Type -> Type
substitute subs ty = case ty of
  TVar i -> maybe ty (substitute subs) (M.lookup i subs)
  TLambda params -> TLambda (map (substitute subs) params)
  _ -> ty
-- | An example program in this little Lisp: a one-parameter lambda whose
-- body applies the operator @*@ to its argument and 2, applied to 3.
test_expr_1 :: Text
test_expr_1 = "((\\ (x) (* x 2)) 3)"
-- | Parse, annotate and type-check the input. Returns the solved
-- substitution together with the tree in which every node carries its type
-- after that substitution has been applied.
compile :: Text -> Compiler (Map Int Type, AnnotatedExpr Type)
compile inp = do
  typed <- (parse_expr >=> annotate >=> generateConstraints) inp
  let (_, rootResult) = extract typed
  subs <- solveConstraints (constraints rootResult)
  return (subs, fmap (substitute subs . fst) typed)
-- | Compile 'test_expr_1'. Prints the error, or the substitution followed by
-- the typed tree.
main :: IO ()
main =
  either print report (runCompiler defaultCompilerState (compile test_expr_1))
  where
    report (subs, expr) = do
      putStrLn ("Subs: " ++ show subs)
      putStrLn ("Expr: " ++ show expr)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment