Created
January 11, 2015 05:47
-
-
Save Cedev/89b0ef2b196e25e1cb2c to your computer and use it in GitHub Desktop.
Compiler from ArrowLike primitive recursive functions to LLVM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE ConstraintKinds #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE FunctionalDependencies #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE GeneralizedNewtypeDeriving #-} | |
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE ConstraintKinds #-} | |
import GHC.Exts (Constraint) | |
import Data.Proxy | |
import Data.Word | |
import Control.PrimRec | |
import Prelude hiding (id, (.), fst, snd, succ) | |
import LLVM.General.AST | |
import qualified LLVM.General.AST.Constant as Constant | |
import qualified LLVM.General.AST.IntegerPredicate as ICmp | |
import qualified LLVM.General.AST.Linkage as Linkage | |
import qualified LLVM.General.AST.Visibility as Visibility | |
import qualified LLVM.General.AST.CallingConvention as CallingConvention | |
import LLVM.General.AST.AddrSpace | |
import LLVM.General.PrettyPrint | |
import qualified Data.Map as Map | |
import Control.Monad.Trans.Class | |
import Control.Monad.Trans.State | |
import Data.Functor.Identity | |
import Control.Applicative | |
import LLVM_General_Pure_PrettyPrint | |
-- | A reified constraint: pattern matching on the 'Dict' constructor brings
-- the constraint @c@ into scope.
data Dict c where
  Dict :: c => Dict c
-- | Evidence that @a@ is 'Registerable' together with its associated
-- 'RegisterableCtx' constraint.
type RegisterableDict a =
  Dict (Registerable a, RegisterableCtx a)
-- | The register-level representation of a value of type @r@, packaged with
-- the 'Registerable' instance that describes it.
data Registers r where
  Registers :: Registerable r => RegisterRep r -> Registers r
-- | Monads in which LLVM code can be emitted.
class Monad m => MonadLLVM m where
  -- | Obtain a name to use for a value, block or function.
  getName :: m Name

  -- | Emit the given instructions.
  instruct :: [Named Instruction] -> m ()

  -- | Define a function from its return type, name, parameters and an action
  -- that emits its body.
  define :: Type -> Name -> [Parameter] -> m () -> m ()

  -- | Emit a named basic block; the action emits the block's instructions and
  -- yields its terminator.
  block :: Name -> m (Named Terminator) -> m ()
  --branch ::
-- | Types whose values can be represented as a flat sequence of LLVM operands.
class Registerable a where
  -- | The structured representation carried inside 'Registers'.
  type RegisterRep a :: *

  -- | Extra constraints that hold whenever @a@ is 'Registerable'
  -- (for pairs: that both components are 'Registerable').
  type RegisterableCtx a :: Constraint

  -- | Evidence for the instance and its 'RegisterableCtx'.
  registerableDict :: Proxy a -> RegisterableDict a

  -- | The LLVM types of the operands, in flattening order.
  types :: Proxy a -> [Type]

  -- | Flatten the registers into a list of operands.
  toOperands :: Registers a -> [Operand]

  -- | Rebuild registers from the front of an operand list, returning the
  -- operands that were not consumed.
  readOperands :: [Operand] -> (Registers a, [Operand])
-- | A natural number is held in a single 64-bit integer operand.
instance Registerable Nat where
  type RegisterRep Nat = Operand
  type RegisterableCtx Nat = ()
  registerableDict _ = Dict
  types _ = [IntegerType 64]
  toOperands (Registers o) = [o]
  readOperands (o : os) = (Registers o, os)
  -- Running out of operands means the caller supplied fewer values than
  -- 'types' promised; fail with a message instead of an anonymous
  -- pattern-match error.
  readOperands [] =
    error "readOperands (Nat): expected an operand but the operand list is empty"
-- | The unit type occupies no operands at all.
instance Registerable () where
  type RegisterRep () = ()
  type RegisterableCtx () = ()
  registerableDict _ = Dict
  types = const []
  toOperands = const []
  -- Nothing is consumed: the operand list is handed back untouched.
  readOperands remaining = (Registers (), remaining)
-- | A pair is laid out as the operands of its first component followed by the
-- operands of its second component.
instance (Registerable a, Registerable b) => Registerable (a, b) where
  type RegisterRep (a, b) = (Registers a, Registers b)
  type RegisterableCtx (a, b) = (Registerable a, Registerable b)
  registerableDict _ = Dict
  types _ = types (Proxy :: Proxy a) ++ types (Proxy :: Proxy b)
  toOperands (Registers (left, right)) = toOperands left ++ toOperands right
  -- Read the first component, then read the second from whatever is left.
  readOperands ops = (Registers (left, right), rest)
    where
      (left, afterLeft) = readOperands ops
      (right, rest) = readOperands afterLeft
-- | Rebuild registers from a list that must contain exactly the operands of
-- @a@ (as produced by 'toOperands').
--
-- Calls 'error' if operands are left over after reading; previously this
-- surfaced as an opaque irrefutable-pattern failure.
fromOperands :: Registerable a => [Operand] -> Registers a
fromOperands os =
  case readOperands os of
    (regs, []) -> regs
    (_, leftover) ->
      error $
        "fromOperands: " ++ show (length leftover)
          ++ " operand(s) left over after reading registers"
-- | A compiled arrow from @x@ to @y@: an outer action that yields a function
-- from input registers to an action producing the output registers. The
-- constructor also records that both endpoints are 'Registerable'.
data RegisterArrow m x y where
  RegisterArrow
    :: (Registerable x, Registerable y)
    => m (Registers x -> m (Registers y))
    -> RegisterArrow m x y
-- | A compiled primitive recursive function from @a@ to @b@: given evidence
-- that the input type is 'Registerable', it produces a 'RegisterArrow'.
data PRFCompiled m a b where
  BlockLike
    :: (RegisterableDict a -> RegisterArrow m a b)
    -> PRFCompiled m a b
{- | |
:: (Monad m) => (Registers b -> m (Registers c)) -> (Registers a -> m (Registers b)) -> (Registers a -> m (Registers c)) | |
f %%% g = \a -> g a >>= | |
-} | |
-- | Recover the 'Registerable' evidence (and associated contexts) for both
-- endpoints of a 'RegisterArrow'.
rarrowDict
  :: forall m x y.
     RegisterArrow m x y
  -> Dict (Registerable x, Registerable y, RegisterableCtx x, RegisterableCtx y)
rarrowDict (RegisterArrow _) =
  -- Matching on the constructor supplies the two 'Registerable' instances;
  -- 'registerableDict' then supplies each associated context.
  case (registerableDict (Proxy :: Proxy x), registerableDict (Proxy :: Proxy y)) of
    (Dict, Dict) -> Dict
-- | Project the evidence for the first component out of the evidence for a
-- pair.
fstDict :: forall a b. RegisterableDict (a, b) -> RegisterableDict a
fstDict Dict =
  case registerableDict (Proxy :: Proxy a) of
    Dict -> Dict
-- | Project the evidence for the second component out of the evidence for a
-- pair.
sndDict :: forall a b. RegisterableDict (a, b) -> RegisterableDict b
sndDict Dict =
  case registerableDict (Proxy :: Proxy b) of
    Dict -> Dict
-- | Identity passes registers through unchanged; composition runs the
-- right-hand arrow first and feeds its output registers to the left-hand one.
instance (Monad m) => Category (PRFCompiled m) where
  id = BlockLike $ \Dict -> RegisterArrow (return return)

  BlockLike mkOuter . BlockLike mkInner = BlockLike $ \Dict ->
    case mkInner Dict of
      inner@(RegisterArrow setupInner) ->
        -- The inner arrow's output type is the outer arrow's input type, so
        -- its evidence is what the outer builder needs.
        case rarrowDict inner of
          Dict ->
            case mkOuter Dict of
              RegisterArrow setupOuter -> RegisterArrow $ do
                runInner <- setupInner
                runOuter <- setupOuter
                return $ \input -> runInner input >>= runOuter
-- | Projections select one half of a register pair; '&&&' runs both arrows on
-- the same input registers and pairs their results.
instance (Monad m) => ArrowLike (PRFCompiled m) where
  fst = BlockLike $ \Dict ->
    RegisterArrow . return $ \(Registers (left, _)) -> return left

  snd = BlockLike $ \Dict ->
    RegisterArrow . return $ \(Registers (_, right)) -> return right

  BlockLike mkLeft &&& BlockLike mkRight = BlockLike $ \Dict ->
    case (mkLeft Dict, mkRight Dict) of
      (RegisterArrow setupLeft, RegisterArrow setupRight) -> RegisterArrow $ do
        runLeft <- setupLeft
        runRight <- setupRight
        return $ \input -> do
          leftOut <- runLeft input
          rightOut <- runRight input
          return (Registers (leftOut, rightOut))
-- | Code generation for the primitive recursive combinators.
instance (MonadLLVM m) => PrimRec (PRFCompiled m) where
  -- The constant 0 as a 64-bit integer; the input registers are ignored.
  zero = BlockLike $ \Dict ->
    RegisterArrow . return $ \_ ->
      return . Registers . ConstantOperand . Constant.Int 64 $ 0

  -- Successor: constants are incremented directly, anything else gets an
  -- emitted @add@ instruction.
  succ = BlockLike $ \Dict -> RegisterArrow . return $ regSucc
    where
      regSucc (Registers op) = opSucc op >>= return . Registers
      opSucc (ConstantOperand (Constant.Int b v)) =
        return . ConstantOperand . Constant.Int b $ v + 1
      opSucc op =
        bind (IntegerType 64) $
          Add False False op (ConstantOperand $ Constant.Int 64 1) []

  -- Primitive recursion is compiled to a recursive function (see
  -- 'defineRecursive') with three blocks: a dispatch block that tests the
  -- counter against 0, a zero block running the base case, and a successor
  -- block that recurses on the decremented counter and then runs the step.
  prec (BlockLike mkBase) (BlockLike mkStep) = BlockLike $ \d@Dict ->
    case mkBase $ sndDict d of
      RegisterArrow setupBase ->
        case mkStep Dict of
          RegisterArrow setupStep -> RegisterArrow $ do
            base <- setupBase
            step <- setupStep
            defineRecursive $ \recurse readArgs ret -> do
              dispatchName <- getName
              zeroName <- getName
              succName <- getName
              args@(Registers (Registers n, env)) <- readArgs
              block dispatchName $ do
                atZero <-
                  bind (IntegerType 1) $
                    ICmp ICmp.EQ n (ConstantOperand $ Constant.Int 64 0) []
                return . Do $ CondBr atZero zeroName succName []
              block zeroName $ do
                result <- base env
                ret result
              block succName $ do
                predN <-
                  bind (IntegerType 64) $
                    Sub False False n (ConstantOperand $ Constant.Int 64 1) []
                recursive <- recurse (Registers (Registers predN, env))
                -- NOTE(review): the step receives the original arguments
                -- (counter n, not n - 1) alongside the recursive result;
                -- confirm this matches the PrimRec contract.
                result <- step (Registers (recursive, args))
                ret result
-- | Define a (possibly recursive) LLVM function and return an action-valued
-- function that emits a call to it.
--
-- The generated function returns @void@ and takes two pointers: one to a
-- struct that receives the results and one to a struct holding the arguments.
-- The supplied body builder is given a way to call the function recursively,
-- a way to read the parameters, and a way to store results and return.
defineRecursive :: forall x y m. (Registerable x, Registerable y, MonadLLVM m) =>
  (
    (Registers x -> m (Registers y)) ->       -- recursive call
    m (Registers x) ->                        -- read parameters
    (Registers y -> m (Named Terminator)) ->  -- return results
    m ()                                      -- function body
  ) ->
  m (Registers x -> m (Registers y))          -- call function
defineRecursive body = do
  fnName <- getName
  argPtrName <- getName
  resPtrName <- getName
  let
    argType = StructureType False $ types (Proxy :: Proxy x)
    resType = StructureType False $ types (Proxy :: Proxy y)
    argPtrType = ptr argType
    resPtrType = ptr resType
    fnRef =
      ConstantOperand $
        Constant.GlobalReference
          (FunctionType VoidType [resPtrType, argPtrType] False)
          fnName
    -- Caller side: allocate both structs, store the arguments, call, and
    -- load the results back.
    invoke args = do
      argPtr <- allocPtr argType
      resPtr <- allocPtr resType
      writePtr argPtr args
      instruct
        [ Do $
            Call False CallingConvention.C []
              (Right fnRef)
              [(resPtr, []), (argPtr, [])]
              [] []
        ]
      readPtr resPtr
    -- Callee side: store the results through the result pointer parameter
    -- and return void.
    finish results = do
      writePtr (LocalReference resPtrType resPtrName) results
      return (Do (Ret Nothing []))
    -- Callee side: load the arguments through the argument pointer parameter.
    fetchArgs = readPtr (LocalReference argPtrType argPtrName)
  define
    VoidType
    fnName
    [Parameter resPtrType resPtrName [], Parameter argPtrType argPtrName []]
    (body invoke fetchArgs finish)
  return invoke
-- | The type of pointers to the given type, in address space 0.
ptr :: Type -> Type
ptr pointee = PointerType pointee (AddrSpace 0)
-- | Emit an @alloca@ for one value of the given type and return the resulting
-- pointer operand.
allocPtr :: (MonadLLVM m) => Type -> m (Operand)
allocPtr ty = bind (ptr ty) (Alloca ty Nothing 0 [])
-- | Emit one @getelementptr@ per field type, addressing fields 0, 1, 2, ... of
-- the struct behind the given pointer, and return the field pointers in order.
elemPtrs :: (MonadLLVM m) => Operand -> [Type] -> m ([Operand])
elemPtrs struct ts = mapM fieldPtr (zip ts [0 ..])
  where
    i32 = ConstantOperand . Constant.Int 32
    -- Indices are [0, field]: step through the pointer, then select the field.
    fieldPtr (t, ix) =
      bind (ptr t) $ GetElementPtr True struct [i32 0, i32 ix] []
-- | Load every field of the struct behind the pointer and reassemble the
-- loaded operands into registers of type @r@.
readPtr :: forall r m. (Registerable r, MonadLLVM m) => Operand -> m (Registers r)
readPtr struct = do
  fieldPtrs <- elemPtrs struct fieldTypes
  loaded <- mapM (uncurry load) (zip fieldTypes fieldPtrs)
  return (fromOperands loaded)
  where
    fieldTypes = types (Proxy :: Proxy r)
    load t p = bind t (Load False p Nothing 0 [])
-- | Store the operands of the given registers into the consecutive fields of
-- the struct behind the pointer.
writePtr :: forall r m. (Registerable r, MonadLLVM m) => Operand -> Registers r -> m ()
writePtr struct regs = do
  fieldPtrs <- elemPtrs struct (types (Proxy :: Proxy r))
  mapM_ (uncurry store) (zip fieldPtrs (toOperands regs))
  where
    store p v = instruct [Do (Store False p v Nothing 0 [])]
{- | |
newtype FirstDefined k v = FirstDefined (Map.Map k v) | |
instance (Ord k) => Monoid (FirstDefined k v) where | |
mempty = FirstDefined Map.empty | |
(FirstDefined a) `mappend` (FirstDefined b) = FirstDefined (Map.union a b) | |
-} | |
-- | A monad transformer for emitting LLVM code: a state monad over
-- 'LLVMState'.
newtype LLVMT m a = LLVMT
  { unLLVMT :: StateT LLVMState m a
  }
  deriving (Functor, Applicative, Monad)
-- | The code-generation state threaded through 'LLVMT'.
data LLVMState = LLVMState
  { globals :: Map.Map Name Global
    -- ^ Global definitions registered so far, keyed by name.
  , blocks :: [BasicBlock]
    -- ^ Completed basic blocks of the function currently being built.
  , instructions :: [Named Instruction]
    -- ^ Emitted instructions not yet placed in a basic block.
  , names :: [Name]
    -- ^ Supply of names handed out by 'getName'.
  }
-- | Run an 'LLVMT' action from the given state, returning its result and the
-- final state.
runLLVMT :: LLVMT m a -> LLVMState -> m (a, LLVMState)
runLLVMT action = runStateT (unLLVMT action)
-- | Emit an instruction under a newly obtained name and return a local
-- reference (of the given type) to its result.
bind :: (MonadLLVM m) => Type -> Instruction -> m (Operand)
bind t instruction =
  getName >>= \result ->
    instruct [result := instruction]
      >> return (LocalReference t result)
-- | Code emission on top of the 'LLVMState' state monad.
instance Monad m => MonadLLVM (LLVMT m) where
  -- Pop the next name off the supply. An exhausted supply is reported with a
  -- descriptive error rather than an irrefutable-pattern failure.
  getName = LLVMT $ do
    st <- get
    case names st of
      name : remaining -> do
        put st {names = remaining}
        return name
      [] -> error "getName (LLVMT): the name supply is exhausted"

  -- Append to the pending (not yet blocked) instructions.
  instruct new = LLVMT $ do
    st <- get
    put st {instructions = instructions st ++ new}

  -- Emit a basic block. Instructions that were pending when the block starts
  -- are first wrapped in a block of their own that branches to the new block,
  -- so they are not lost.
  block name definition = LLVMT $ do
    initialState <- get
    put initialState {instructions = []}
    case instructions initialState of
      [] -> return ()
      pending -> do
        preBlockName <- unLLVMT getName
        unLLVMT $ block preBlockName $ do
          instruct pending
          return . Do $ Br name []
    outerState <- get
    -- Run the block body with empty block/instruction buffers so that what it
    -- emits can be told apart from what existed before.
    (term, innerState) <-
      lift $ runLLVMT definition outerState {blocks = [], instructions = []}
    put innerState
      { blocks =
          blocks outerState
            ++ blocks innerState
            ++ [BasicBlock name (instructions innerState) term]
      , instructions = []
      }

  -- Define a global function whose basic blocks are those emitted by the
  -- definition action. The caller's blocks and pending instructions are
  -- restored afterwards.
  define t name parameters definition = LLVMT $ do
    outerState <- get
    (_, innerState) <-
      lift $ runLLVMT definition outerState {blocks = [], instructions = []}
    put innerState
      { -- Insert into the globals of the inner state: building on the outer
        -- state's map would discard any globals registered while running the
        -- definition body.
        globals =
          Map.insert
            name
            (Function Linkage.External Visibility.Default CallingConvention.C [] t name (parameters, False) [] Nothing 0 Nothing (blocks innerState))
            (globals innerState)
      , blocks = blocks outerState
      , instructions = instructions outerState
      }
-- | Name of the global holding the @scanf@ format string.
formatInName :: Name
formatInName = Name "formatIn"

-- | Type of the @scanf@ format string: an array of five bytes.
formatInType :: Type
formatInType = ArrayType 5 (IntegerType 8)

-- | The bytes of the @scanf@ format string, including the NUL terminator.
formatInConstant :: Constant.Constant
formatInConstant =
  Constant.Array
    (IntegerType 8)
    [Constant.Int 8 (fromIntegral (fromEnum c)) | c <- "%llu\00"]

-- | Operand referring to the @scanf@ format global.
formatIn :: Operand
formatIn =
  ConstantOperand (Constant.GlobalReference (ptr formatInType) formatInName)
-- | Name of the global holding the @printf@ format string.
formatOutName :: Name
formatOutName = Name "formatOut"

-- | Type of the @printf@ format string: an array of six bytes.
formatOutType :: Type
formatOutType = ArrayType 6 (IntegerType 8)

-- | The bytes of the @printf@ format string, including the NUL terminator.
formatOutConstant :: Constant.Constant
formatOutConstant =
  Constant.Array
    (IntegerType 8)
    [Constant.Int 8 (fromIntegral (fromEnum c)) | c <- "%llu\n\00"]

-- | Operand referring to the @printf@ format global.
formatOut :: Operand
formatOut =
  ConstantOperand (Constant.GlobalReference (ptr formatOutType) formatOutName)
-- | Emit a @getelementptr@ with indices [0, 0] on the given format operand,
-- yielding an @i8*@ to its first byte.
getFormatPtr :: MonadLLVM m => Operand -> m (Operand)
getFormatPtr format =
  bind (ptr (IntegerType 8)) (GetElementPtr True format [zero32, zero32] [])
  where
    zero32 = ConstantOperand (Constant.Int 32 0)
-- | Read a natural number: allocate a 64-bit slot, call @scanf@ with the
-- 'formatIn' string and the slot's address, then load the slot.
input :: (MonadLLVM m) => PRFCompiled m () Nat
input = BlockLike $ \_ -> RegisterArrow . return $ \_ -> do
  slot <- allocPtr (IntegerType 64)
  fmt <- getFormatPtr formatIn
  instruct
    [ Do $
        Call False CallingConvention.C []
          (Right scanfRef)
          [(fmt, []), (slot, [])]
          [] []
    ]
  value <- bind (IntegerType 64) (Load False slot Nothing 0 [])
  return (Registers value)
  where
    scanfRef =
      ConstantOperand $
        Constant.GlobalReference
          (FunctionType (IntegerType 32) [ptr (IntegerType 8)] True)
          (Name "scanf")
-- | Print a natural number by calling @printf@ with the 'formatOut' string
-- and the number's operand.
output :: MonadLLVM m => PRFCompiled m Nat ()
output = BlockLike $ \_ -> RegisterArrow . return $ \(Registers value) -> do
  fmt <- getFormatPtr formatOut
  instruct
    [ Do $
        Call False CallingConvention.C []
          (Right printfRef)
          [(fmt, []), (value, [])]
          [] []
    ]
  return (Registers ())
  where
    printfRef =
      ConstantOperand $
        Constant.GlobalReference
          (FunctionType (IntegerType 32) [ptr (IntegerType 8)] True)
          (Name "printf")
-- examples | |
-- | Addition by recursion on the first argument: the base case returns the
-- second argument, each step takes the successor of the recursive result.
add :: PrimRec a => a (Nat, Nat) Nat
add = prec base step
  where
    base = id
    step = succ . fst
-- | Case analysis on a natural number: @fz@ handles the base case, @fs@
-- handles the step case and ignores the recursive result.
match :: PrimRec a => a b c -> a (Nat, b) c -> a (Nat, b) c
match fz fs = prec fz step
  where
    -- Drop the recursive result, keeping only the arguments 'prec' passes on.
    step = fs . snd
-- | The constant one: the successor of 'zero', for any input.
one :: PrimRec a => a b Nat
one =
  succ . zero
-- | Zero for a zero input, one otherwise.
nonZero :: PrimRec a => a Nat Nat
nonZero = match zero one . dup
  where
    -- 'match' consumes a pair, so the input is duplicated first.
    dup = id &&& id
-- | One for a zero input, zero otherwise.
isZero :: PrimRec a => a Nat Nat
isZero = match one zero . dup
  where
    -- 'match' consumes a pair, so the input is duplicated first.
    dup = id &&& id
-- | Parity by recursion: zero in the base case, and each step flips the
-- recursive result with 'isZero'.
isOdd :: PrimRec a => a Nat Nat
isOdd = prec zero flipParity . dup
  where
    flipParity = isZero . fst
    -- 'prec' consumes a pair, so the input is duplicated first.
    dup = id &&& id
-- | |
-- | Compile a closed program into an LLVM module named @main@.
--
-- The module contains every global registered while building the program,
-- the two format-string globals, declarations of @scanf@ and @printf@, and a
-- @main@ function whose single block holds the program's top-level
-- instructions and returns 0.
compile :: PRFCompiled (LLVMT Identity) () () -> Module
compile (BlockLike df) = llvmModule
  where
    (RegisterArrow mf) = df Dict

    -- Names n1, n2, n3, ...
    freshNames = [Name ('n' : show i) | i <- [1 ..]]

    llvmModule = Module "main" Nothing Nothing (map GlobalDefinition allGlobals)

    allGlobals =
      Map.elems (globals finalState)
        ++ [ formatGlobal formatInName formatInType formatInConstant
           , formatGlobal formatOutName formatOutType formatOutConstant
           , declareVariadic "scanf"
           , declareVariadic "printf"
           , mainFn
           ]

    -- A global variable holding one of the format strings.
    formatGlobal name ty value =
      GlobalVariable
        name Linkage.Internal Visibility.Default False (AddrSpace 0) False True
        ty (Just value) Nothing 0

    -- Declaration (no body) of a variadic C function @i32 f(i8* format, ...)@.
    declareVariadic fnName =
      Function Linkage.External Visibility.Default CallingConvention.C []
        (IntegerType 32) (Name fnName) ([
          Parameter (ptr (IntegerType 8)) (Name "format") []
        ], True) [] Nothing 0 Nothing []

    mainFn =
      Function Linkage.External Visibility.Default CallingConvention.C []
        (IntegerType 32) (Name "main") ([], False) [] Nothing 0 Nothing [
          BasicBlock (Name "mainBlock") (instructions finalState) (Do $ Ret (Just . ConstantOperand . Constant.Int 32 $ 0) [])
        ]

    startState =
      LLVMState {globals = Map.empty, blocks = [], instructions = [], names = freshNames}

    -- First run the outer action, then run the resulting function on the
    -- unit registers to emit the program's top-level instructions.
    (run, setupState) = runIdentity (runLLVMT mf startState)
    (_, finalState) = runIdentity (runLLVMT (run (Registers ())) setupState)
    --(regs, innerState) =
-- | Compile a function on naturals into a complete program that reads its
-- argument with 'input' and prints its result with 'output'.
compileNatNat :: PRFCompiled (LLVMT Identity) Nat Nat -> Module
compileNatNat p = compile program
  where
    program = output . p . input
-- | Print the pretty-printed LLVM module for the 'isOdd' example.
main :: IO ()
main = putStrLn $ ppllvm (compileNatNat isOdd)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment