-- Gist by @Cedev (gist 89b0ef2b196e25e1cb2c), created January 11, 2015.
--
-- Compiler from ArrowLike primitive recursive functions to LLVM.
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ConstraintKinds #-}
import GHC.Exts (Constraint)
import Data.Proxy
import Data.Word
import Control.PrimRec
import Prelude hiding (id, (.), fst, snd, succ)
import LLVM.General.AST
import qualified LLVM.General.AST.Constant as Constant
import qualified LLVM.General.AST.IntegerPredicate as ICmp
import qualified LLVM.General.AST.Linkage as Linkage
import qualified LLVM.General.AST.Visibility as Visibility
import qualified LLVM.General.AST.CallingConvention as CallingConvention
import LLVM.General.AST.AddrSpace
import LLVM.General.PrettyPrint
import qualified Data.Map as Map
import Control.Monad.Trans.Class
import Control.Monad.Trans.State
import Data.Functor.Identity
import Control.Applicative
import LLVM_General_Pure_PrettyPrint
-- | A reified constraint: pattern matching on 'Dict' brings @c@ into scope.
data Dict c where
  Dict :: c => Dict c

-- | Evidence that @a@ is 'Registerable', bundled with the extra constraints
-- its instance declares through 'RegisterableCtx'.
type RegisterableDict a = Dict (Registerable a, RegisterableCtx a)
-- | The LLVM operands holding a value of Haskell type @r@.  The concrete
-- grouping of operands is chosen per type by 'RegisterRep'.
data Registers r where
  Registers
    :: Registerable r
    => RegisterRep r
    -> Registers r
-- | Monads in which LLVM IR can be emitted.
class Monad m => MonadLLVM m where
  -- | Obtain a new name; callers rely on it not having been used before.
  getName :: m Name
  -- | Append instructions to the code currently being emitted.
  instruct :: [Named Instruction] -> m ()
  -- | @define retTy name params body@ emits a function named @name@ whose
  -- basic blocks are the ones produced while running @body@.
  define :: Type -> Name -> [Parameter] -> m () -> m ()
  -- | @block name body@ emits a basic block labelled @name@, terminated by
  -- the terminator that @body@ returns.
  block :: Name -> m (Named Terminator) -> m ()
  -- TODO: a @branch@ primitive was sketched here but never written.
-- | Haskell types whose values are carried in a fixed sequence of LLVM
-- operands.
class Registerable a where
  -- | Extra constraints an instance depends on (e.g. both halves of a
  -- pair); recovered at use sites via 'registerableDict'.
  type RegisterableCtx a :: Constraint
  -- | How the operands for @a@ are grouped on the Haskell side.
  type RegisterRep a :: *
  -- | The LLVM types of the operands, in the order 'toOperands' lists them.
  types :: Proxy a -> [Type]
  -- | Flatten registers into their operands.
  toOperands :: Registers a -> [Operand]
  -- | Take the operands for one @a@ from the front of the list, returning
  -- the registers together with the unconsumed remainder.
  readOperands :: [Operand] -> (Registers a, [Operand])
  -- | Evidence for this instance together with its 'RegisterableCtx'.
  registerableDict :: Proxy a -> RegisterableDict a
-- | A 'Nat' is held in a single 64-bit integer operand.
instance Registerable Nat where
  type RegisterRep Nat = Operand
  type RegisterableCtx Nat = ()
  registerableDict _ = Dict
  types _ = [IntegerType 64]
  toOperands (Registers o) = [o]
  readOperands (o:os) = (Registers o, os)
  -- Running out of operands means the list was not produced from 'types' /
  -- 'toOperands' for the expected type.  Say so instead of failing with a
  -- bare non-exhaustive-pattern error.
  readOperands [] = error "readOperands @Nat: no operand left for a Nat register"
-- | The unit type is carried in no operands at all.
instance Registerable () where
  type RegisterRep () = ()
  type RegisterableCtx () = ()
  registerableDict _ = Dict
  types = const []
  toOperands = const []
  readOperands rest = (Registers (), rest)
-- | A pair is laid out as the operands of its first component followed by
-- the operands of its second.
instance (Registerable a, Registerable b) => Registerable (a, b) where
  type RegisterRep (a, b) = (Registers a, Registers b)
  type RegisterableCtx (a, b) = (Registerable a, Registerable b)
  registerableDict _ = Dict
  types _ = concat [types (Proxy :: Proxy a), types (Proxy :: Proxy b)]
  toOperands (Registers (left, right)) = toOperands left ++ toOperands right
  readOperands operands = (Registers (left, right), rest)
    where
      (left, afterLeft) = readOperands operands
      (right, rest) = readOperands afterLeft
-- | Rebuild the registers for @a@ from exactly the operands that describe
-- one @a@.  Leftover operands mean producer and consumer disagree on the
-- layout; that is reported explicitly instead of as an irrefutable-pattern
-- failure.
fromOperands :: Registerable a => [Operand] -> Registers a
fromOperands os = case readOperands os of
  (regs, []) -> regs
  (_, leftover) ->
    error $ "fromOperands: " ++ show (length leftover) ++ " operand(s) left unconsumed"
-- | A compiled arrow from @x@ to @y@.  The outer @m@ action is a set-up
-- step (for 'prec' it defines the recursive helper function).  It yields
-- the code generator that is run with the input registers.
data RegisterArrow m x y where
  RegisterArrow
    :: (Registerable x, Registerable y)
    => m (Registers x -> m (Registers y))
    -> RegisterArrow m x y
-- | The compiled form of a primitive recursive function: given evidence
-- that the input type is 'Registerable', build a 'RegisterArrow'.
data PRFCompiled m a b where
  BlockLike
    :: (RegisterableDict a -> RegisterArrow m a b)
    -> PRFCompiled m a b
{-
:: (Monad m) => (Registers b -> m (Registers c)) -> (Registers a -> m (Registers b)) -> (Registers a -> m (Registers c))
f %%% g = \a -> g a >>=
-}
-- | Recover full evidence (instance plus 'RegisterableCtx') for both ends
-- of a 'RegisterArrow', starting from the instances its constructor stores.
rarrowDict
  :: forall m x y.
     RegisterArrow m x y
  -> Dict (Registerable x, Registerable y, RegisterableCtx x, RegisterableCtx y)
rarrowDict (RegisterArrow _) =
  case (registerableDict (Proxy :: Proxy x), registerableDict (Proxy :: Proxy y)) of
    (Dict, Dict) -> Dict
-- | Evidence for a pair's first component, derived from evidence for the
-- whole pair.
fstDict :: forall a b. RegisterableDict (a, b) -> RegisterableDict a
fstDict pairDict = case pairDict of
  Dict -> case registerableDict (Proxy :: Proxy a) of
    Dict -> Dict
-- | Evidence for a pair's second component, derived from evidence for the
-- whole pair.
sndDict :: forall a b. RegisterableDict (a, b) -> RegisterableDict b
sndDict pairDict = case pairDict of
  Dict -> case registerableDict (Proxy :: Proxy b) of
    Dict -> Dict
-- | Identity hands the registers back unchanged.  For @f . g@:
--
-- * @g@ is built from the input evidence;
-- * 'rarrowDict' recovers evidence for the intermediate type;
-- * @f@ is built from that evidence.
--
-- The set-up actions run @g@'s first, then @f@'s.
instance (Monad m) => Category (PRFCompiled m) where
  id = BlockLike $ \Dict -> RegisterArrow (return return)
  BlockLike outer . BlockLike inner = BlockLike $ \Dict ->
    case inner Dict of
      innerArrow@(RegisterArrow setupInner) ->
        case rarrowDict innerArrow of
          Dict ->
            case outer Dict of
              RegisterArrow setupOuter -> RegisterArrow $ do
                runInner <- setupInner
                runOuter <- setupOuter
                return $ \regs -> runInner regs >>= runOuter
-- | Projections pick one half of the register pair.  Fan-out runs the
-- set-up actions of both arrows (left, then right).  At run time it feeds
-- the same input registers to both and pairs the results.
instance (Monad m) => ArrowLike (PRFCompiled m) where
  fst = BlockLike $ \Dict ->
    RegisterArrow $ return $ \(Registers (firstRegs, _)) -> return firstRegs
  snd = BlockLike $ \Dict ->
    RegisterArrow $ return $ \(Registers (_, secondRegs)) -> return secondRegs
  BlockLike left &&& BlockLike right = BlockLike $ \Dict ->
    case left Dict of
      RegisterArrow setupLeft -> case right Dict of
        RegisterArrow setupRight -> RegisterArrow $ do
          runLeft <- setupLeft
          runRight <- setupRight
          return $ \regs -> do
            leftOut <- runLeft regs
            rightOut <- runRight regs
            return (Registers (leftOut, rightOut))
-- | 'zero' and 'succ' compile to a constant and a 64-bit @add@.
-- @prec f g@ compiles to a recursive LLVM function (see 'defineRecursive')
-- computing
--
-- > h (0,   e) = f e
-- > h (n+1, e) = g (h (n, e), (n, e))
--
-- The step function receives the predecessor @n@, as in the standard
-- definition of primitive recursion.
instance (MonadLLVM m) => PrimRec (PRFCompiled m) where
  zero = BlockLike $ \Dict -> RegisterArrow . return $ \_ -> return . Registers . ConstantOperand . Constant.Int 64 $ 0
  succ = BlockLike $ \Dict -> RegisterArrow . return $ regSucc
    where
      regSucc (Registers op) = (>>= return . Registers) . opSucc $ op
      -- Fold constants, wrapping modulo 2^bits like the emitted 'Add' (no
      -- nuw/nsw) does at run time; a Constant.Int must fit its bit width.
      opSucc (ConstantOperand (Constant.Int b v)) = return . ConstantOperand . Constant.Int b $ (v + 1) `mod` (2 ^ b)
      opSucc op = bind (IntegerType 64) $ Add False False op (ConstantOperand $ Constant.Int 64 1) []
  prec (BlockLike df) (BlockLike dg) = BlockLike $ \d@Dict ->
    case df $ sndDict d of
      RegisterArrow mf -> case dg Dict of
        RegisterArrow mg -> RegisterArrow $ do
          f <- mf
          g <- mg
          defineRecursive $ \go readArgs ret -> do
            brName <- getName
            zeroName <- getName
            succName <- getName
            Registers (Registers n, e) <- readArgs
            -- Entry: branch on n == 0.
            block brName $ do
              cmp <- bind (IntegerType 1) $ ICmp ICmp.EQ n (ConstantOperand $ Constant.Int 64 0) []
              return . Do $ CondBr cmp zeroName succName []
            -- Base case: h (0, e) = f e.
            block zeroName $ do
              c <- f e
              ret c
            -- Step: h (n, e) = g (h (n - 1, e), (n - 1, e)).  The step
            -- function is given the predecessor, not n itself.
            block succName $ do
              predN <- bind (IntegerType 64) $ Sub False False n (ConstantOperand $ Constant.Int 64 1) []
              c <- go (Registers (Registers predN, e))
              c' <- g (Registers (c, Registers (Registers predN, e)))
              ret c'
-- | Emit a new LLVM function and return a generator for calls to it.
--
-- The function returns @void@ and takes two pointer parameters: an output
-- struct (fields typed by @y@'s operands) and an input struct (fields typed
-- by @x@'s operands).  Its body comes from @def@, which is given:
--
--   * a generator for recursive calls,
--   * an action loading the input registers from the input struct,
--   * a way to store results into the output struct, which yields a
--     @ret void@ terminator.
--
-- Each generated call @alloca@s both structs, stores the arguments, calls
-- the function, then loads the results.
defineRecursive
  :: forall x y m. (Registerable x, Registerable y, MonadLLVM m)
  => ( (Registers x -> m (Registers y))          -- recursive call
       -> m (Registers x)                         -- read parameters
       -> (Registers y -> m (Named Terminator))   -- return results
       -> m ()                                    -- function body
     )
  -> m (Registers x -> m (Registers y))           -- call function
defineRecursive def = do
  fnName <- getName
  argStructName <- getName
  resStructName <- getName
  let argStructType = StructureType False (types (Proxy :: Proxy x))
      resStructType = StructureType False (types (Proxy :: Proxy y))
      argParam = LocalReference (ptr argStructType) argStructName
      resParam = LocalReference (ptr resStructType) resStructName
      fnType = FunctionType VoidType [ptr resStructType, ptr argStructType] False
      callee = Right (ConstantOperand (Constant.GlobalReference fnType fnName))
      emitCall args = do
        argSlot <- allocPtr argStructType
        resSlot <- allocPtr resStructType
        writePtr argSlot args
        instruct [Do (Call False CallingConvention.C [] callee [(resSlot, []), (argSlot, [])] [] [])]
        readPtr resSlot
      emitReturn results = do
        writePtr resParam results
        return (Do (Ret Nothing []))
      loadArgs = readPtr argParam
      params =
        [ Parameter (ptr resStructType) resStructName []
        , Parameter (ptr argStructType) argStructName []
        ]
  define VoidType fnName params (def emitCall loadArgs emitReturn)
  return emitCall
-- | Pointer to the given type in address space 0.
ptr :: Type -> Type
ptr pointee = PointerType pointee (AddrSpace 0)
-- | Emit an @alloca@ for one value of the given type; returns the pointer.
allocPtr :: (MonadLLVM m) => Type -> m (Operand)
allocPtr slotType = bind (ptr slotType) (Alloca slotType Nothing 0 [])
-- | Given a pointer to a struct whose fields have the listed types, emit
-- one @getelementptr@ per field and return the field pointers in order.
elemPtrs :: (MonadLLVM m) => Operand -> [Type] -> m ([Operand])
elemPtrs struct ts = mapM fieldPtr (zip ts [0 ..])
  where
    fieldPtr (fieldType, ix) =
      bind (ptr fieldType) $
        GetElementPtr True struct [ConstantOperand $ Constant.Int 32 0, ConstantOperand $ Constant.Int 32 ix] []
-- | Load every field of the struct behind the pointer and reassemble the
-- loaded operands into registers for @r@.
readPtr :: forall r m. (Registerable r, MonadLLVM m) => Operand -> m (Registers r)
readPtr struct = do
  fieldPtrs <- elemPtrs struct fieldTypes
  loaded <- mapM loadField (zip fieldTypes fieldPtrs)
  return (fromOperands loaded)
  where
    fieldTypes = types (Proxy :: Proxy r)
    loadField (fieldType, p) = bind fieldType (Load False p Nothing 0 [])
-- | Store each operand of the registers into the matching field of the
-- struct behind the pointer.
writePtr :: forall r m. (Registerable r, MonadLLVM m) => Operand -> Registers r -> m ()
writePtr struct regs = do
  fieldPtrs <- elemPtrs struct (types (Proxy :: Proxy r))
  mapM_ storeField (zip fieldPtrs (toOperands regs))
  where
    storeField (p, op) = instruct [Do $ Store False p op Nothing 0 []]
{-
newtype FirstDefined k v = FirstDefined (Map.Map k v)
instance (Ord k) => Monoid (FirstDefined k v) where
mempty = FirstDefined Map.empty
(FirstDefined a) `mappend` (FirstDefined b) = FirstDefined (Map.union a b)
-}
-- | IR-building monad transformer: a state monad over 'LLVMState'.
newtype LLVMT m a = LLVMT
  { unLLVMT :: StateT LLVMState m a
  }
  deriving (Functor, Applicative, Monad)
-- | State threaded through 'LLVMT'.
data LLVMState = LLVMState
  { globals :: Map.Map Name Global
    -- ^ finished global definitions, keyed by name
  , blocks :: [BasicBlock]
    -- ^ completed basic blocks of the code currently being built
  , instructions :: [Named Instruction]
    -- ^ instructions emitted but not yet placed in a basic block
  , names :: [Name]
    -- ^ supply consumed by 'getName'; must stay lazy, because 'compile'
    -- supplies an infinite list
  }
-- | Run an 'LLVMT' computation from a starting state, returning its result
-- and the final state.
runLLVMT :: LLVMT m a -> LLVMState -> m (a, LLVMState)
runLLVMT action = runStateT (unLLVMT action)
-- | Emit an instruction under a new name and return a local reference to
-- its result.  The caller supplies the result type, which must match what
-- the instruction produces.
bind :: (MonadLLVM m) => Type -> Instruction -> m (Operand)
bind resultType instr = do
  fresh <- getName
  instruct [fresh := instr]
  return (LocalReference resultType fresh)
-- | The 'MonadLLVM' implementation backed by 'LLVMState'.
instance Monad m => MonadLLVM (LLVMT m) where
  -- Take the next name from the supply.  Exhaustion is reported
  -- explicitly instead of through an irrefutable-pattern failure.
  getName = LLVMT $ do
    st <- get
    case names st of
      name : remaining -> do
        put st{names = remaining}
        return name
      [] -> error "getName: name supply exhausted"
  -- Appending with (++) is linear in the pending instructions.
  instruct new = LLVMT $ do
    st <- get
    put st{instructions = instructions st ++ new}
  -- Instructions emitted since the previous terminator are first moved into
  -- a fresh pre-block that branches to @name@.  @definition@ then runs with
  -- an empty block list; any blocks it emits are placed before the block
  -- labelled @name@.
  block name definition = LLVMT $ do
    initialState <- get
    put initialState{instructions = []}
    case instructions initialState of
      [] -> return ()
      is -> do
        preBlockName <- unLLVMT getName
        unLLVMT . block preBlockName $ do
          instruct is
          return . Do $ Br name []
    outerState <- get
    (term, innerState) <- lift $ runLLVMT definition outerState{blocks = [], instructions = []}
    put innerState{
        blocks = blocks outerState ++ blocks innerState ++ [BasicBlock name (instructions innerState) term],
        instructions = []
      }
  -- Run the body in a fresh block/instruction context, record the finished
  -- function in 'globals', then restore the caller's blocks and pending
  -- instructions.  Instructions the body leaves outside every block have no
  -- terminator to attach to; they are reported instead of silently dropped.
  define t name parameters definition = LLVMT $ do
    outerState <- get
    (_, innerState) <- lift $ runLLVMT definition outerState{blocks = [], instructions = []}
    let functionBlocks
          | null (instructions innerState) = blocks innerState
          | otherwise =
              error ("define " ++ show name ++ ": instructions left outside any basic block")
    put innerState {
        globals = Map.insert
          name
          (Function Linkage.External Visibility.Default CallingConvention.C [] t name (parameters, False) [] Nothing 0 Nothing functionBlocks)
          (globals outerState),
        blocks = blocks outerState,
        instructions = instructions outerState
      }
-- | A global @[n x i8]@ constant built from a string, one i8 per Char.
-- Any terminating NUL must be part of the string itself.
cStringConstant :: String -> Constant.Constant
cStringConstant = Constant.Array (IntegerType 8) . map (Constant.Int 8 . fromIntegral . fromEnum)

-- | The array type matching 'cStringConstant' for the same string.  The
-- length is derived from the literal so the two cannot drift apart.
cStringType :: String -> Type
cStringType s = ArrayType (fromIntegral (length s)) (IntegerType 8)

-- | scanf format used by 'input' (NUL-terminated).
formatInString :: String
formatInString = "%llu\00"

formatInName :: Name
formatInName = Name "formatIn"

formatInType :: Type
formatInType = cStringType formatInString

formatInConstant :: Constant.Constant
formatInConstant = cStringConstant formatInString

-- | Pointer to the global holding 'formatInString'.
formatIn :: Operand
formatIn = ConstantOperand $ Constant.GlobalReference (ptr formatInType) formatInName

-- | printf format used by 'output' (NUL-terminated).
formatOutString :: String
formatOutString = "%llu\n\00"

formatOutName :: Name
formatOutName = Name "formatOut"

formatOutType :: Type
formatOutType = cStringType formatOutString

formatOutConstant :: Constant.Constant
formatOutConstant = cStringConstant formatOutString

-- | Pointer to the global holding 'formatOutString'.
formatOut :: Operand
formatOut = ConstantOperand $ Constant.GlobalReference (ptr formatOutType) formatOutName
-- | Turn a pointer to a global @[n x i8]@ array into an @i8*@ to its first
-- character, which is what the printf/scanf declarations take.
getFormatPtr :: MonadLLVM m => Operand -> m (Operand)
getFormatPtr format = bind (ptr (IntegerType 8)) (GetElementPtr True format [firstIx, firstIx] [])
  where
    firstIx = ConstantOperand (Constant.Int 32 0)
-- | Read a 'Nat' from standard input with @scanf("%llu", &dest)@.
--
-- The destination slot is set to 0 before the call.  scanf leaves it
-- untouched when no conversion succeeds (end of input, non-numeric text),
-- so the result is then 0 rather than a load of uninitialised memory.
-- scanf's return value is not inspected.
input :: (MonadLLVM m) => PRFCompiled m () Nat
input = BlockLike $ \_ -> RegisterArrow . return $ \_ -> do
  dest <- allocPtr $ IntegerType 64
  instruct [Do $ Store False dest (ConstantOperand $ Constant.Int 64 0) Nothing 0 []]
  fmt <- getFormatPtr formatIn
  instruct [
      Do (
        Call False CallingConvention.C []
          (
            Right . ConstantOperand $
              Constant.GlobalReference
                (FunctionType (IntegerType 32) [ptr (IntegerType 8)] True)
                (Name "scanf")
          )
          [(fmt, []), (dest, [])]
          [] []
      )
    ]
  r <- bind (IntegerType 64) $ Load False dest Nothing 0 []
  return (Registers r)
-- | Print a 'Nat' and a newline with @printf("%llu\n", r)@.  printf's
-- return value is not inspected.
output :: MonadLLVM m => PRFCompiled m Nat ()
output = BlockLike $ \_ -> RegisterArrow . return $ \(Registers r) -> do
  fmt <- getFormatPtr formatOut
  let printfRef = Right . ConstantOperand $
        Constant.GlobalReference
          (FunctionType (IntegerType 32) [ptr (IntegerType 8)] True)
          (Name "printf")
      printCall = Call False CallingConvention.C [] printfRef [(fmt, []), (r, [])] [] []
  instruct [Do printCall]
  return (Registers ())
-- examples
-- | Addition by recursion on the first argument:
-- @add (0, y) = y@ and @add (n+1, y) = succ (add (n, y))@.
add :: PrimRec a => a (Nat, Nat) Nat
add = prec base step
  where
    base = id
    step = succ . fst

-- | Case analysis on the 'Nat' in the first component.  @onZero@ handles
-- zero; otherwise @onSucc@ is applied and the recursive result is ignored.
match :: PrimRec a => a b c -> a (Nat, b) c -> a (Nat, b) c
match onZero onSucc = prec onZero (onSucc . snd)

-- | The constant function returning 1.
one :: PrimRec a => a b Nat
one = succ . zero

-- | 0 for a zero input, 1 otherwise.
nonZero :: PrimRec a => a Nat Nat
nonZero = match zero one . duplicate
  where
    duplicate = id &&& id

-- | 1 for a zero input, 0 otherwise.
isZero :: PrimRec a => a Nat Nat
isZero = match one zero . duplicate
  where
    duplicate = id &&& id

-- | 1 for an odd input, 0 for an even one.  Each step flips the previous
-- result with 'isZero'.
isOdd :: PrimRec a => a Nat Nat
isOdd = parity . duplicate
  where
    duplicate = id &&& id
    parity = prec zero flipBit
    flipBit = isZero . fst
--
-- | Compile a closed program (no input or output registers) into a module.
--
-- Steps:
--
-- 1. Run the program's set-up actions (e.g. the helper functions emitted
--    for 'prec').
-- 2. Emit its body as the single basic block of @main@, which returns 0.
--
-- The module also contains the scanf/printf format globals and external
-- declarations of scanf and printf.  If the body itself emitted basic
-- blocks, @main@ could not place them; that is reported instead of
-- silently dropping them.
compile :: PRFCompiled (LLVMT Identity) () () -> Module
compile (BlockLike df) = Module "main" Nothing Nothing (map GlobalDefinition globalDefs)
  where
    nameSource = map (Name . ('n':) . show) [1 ..]
    initialState = LLVMState {globals = Map.empty, blocks = [], instructions = [], names = nameSource}
    -- Set-up first (may define helper functions), then the body, starting
    -- from the state the set-up left behind.
    (runBody, setupState) = case df Dict of
      RegisterArrow setup -> runIdentity (runLLVMT setup initialState)
    (_, finalState) = runIdentity (runLLVMT (runBody (Registers ())) setupState)
    mainBlocks
      | null (blocks finalState) =
          [BasicBlock (Name "mainBlock") (instructions finalState) (Do $ Ret (Just . ConstantOperand . Constant.Int 32 $ 0) [])]
      | otherwise =
          error "compile: the program body emitted basic blocks, but main only supports straight-line code"
    mainFunction = Function Linkage.External Visibility.Default CallingConvention.C []
      (IntegerType 32) (Name "main") ([], False) [] Nothing 0 Nothing mainBlocks
    -- External C declaration @i32 fname(i8* format, ...)@.
    externalVarargs fname = Function Linkage.External Visibility.Default CallingConvention.C []
      (IntegerType 32) (Name fname) ([Parameter (ptr (IntegerType 8)) (Name "format") []], True)
      [] Nothing 0 Nothing []
    globalDefs = Map.elems (globals finalState) ++
      [ GlobalVariable
          formatInName Linkage.Internal Visibility.Default False (AddrSpace 0) False True
          formatInType (Just formatInConstant) Nothing 0
      , GlobalVariable
          formatOutName Linkage.Internal Visibility.Default False (AddrSpace 0) False True
          formatOutType (Just formatOutConstant) Nothing 0
      , externalVarargs "scanf"
      , externalVarargs "printf"
      , mainFunction
      ]
-- | Compile a @Nat -> Nat@ program.  The resulting @main@:
--
-- 1. reads one number with 'input',
-- 2. applies the program,
-- 3. prints the result with 'output'.
compileNatNat :: PRFCompiled (LLVMT Identity) Nat Nat -> Module
compileNatNat program = compile wrapped
  where
    wrapped = output . program . input
-- | Print the LLVM IR generated for the 'isOdd' example.
main :: IO ()
main = putStrLn (ppllvm (compileNatNat isOdd))
-- (end of gist)