Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
{-# LANGUAGE FlexibleInstances #-}
-- | An implementation of Section 3, Local Type Argument Synthesis, from the
-- paper /Local Type Inference/ by Pierce and Turner.
module Infer where
import Control.Monad (foldM, join, zipWithM)
import Data.Function (on)
import Data.List (foldl', groupBy, intercalate, intersect, nub)
-- | Names of expression (term) variables.
type EVar = String

-- | Names of type variables.
type TVar = String
-- | Things that can be rendered as human-readable text.
class Typeset a where
  -- | Render a value as a 'String'.
  typeset :: a -> String
-- | Render a value with 'typeset' and print it followed by a newline.
printTypeset :: Typeset a => a -> IO ()
printTypeset x = putStrLn (typeset x)
-- | Lists render as a bracketed, comma-separated sequence of their elements.
instance Typeset a => Typeset [a] where
  typeset items = concat ["[", intercalate ", " (typeset <$> items), "]"]
-- | A 'Left' error message renders with a @*** @ prefix; a 'Right' value
-- renders via its own 'Typeset' instance. The concrete 'String' in the head is
-- why the file enables FlexibleInstances.
instance Typeset b => Typeset (Either String b) where
  typeset = either ("*** " ++) typeset
-- | Types of the internal language.
data Type
  = Top
    -- ^ The maximal type: anything is a subtype of it (rule CG-Top).
  | Bot
    -- ^ The minimal type: it is a subtype of anything (rule CG-Bot).
  | TVar TVar
    -- ^ A type variable.
  | TAll [TVar] [Type] Type
    -- ^ A polymorphic function type: type parameters, argument types and
    -- result type, rendered as @All(X, ...)(T, ...) → R@.
  deriving (Eq, Show)
-- | Render types in the concrete syntax @All(X, ...)(T, ...) → R@.
instance Typeset Type where
  typeset ty = case ty of
    Top -> "Top"
    Bot -> "Bot"
    TVar name -> name
    TAll params args result ->
      concat
        [ "All("
        , intercalate ", " params
        , ")("
        , intercalate ", " (map typeset args)
        , ") → "
        , typeset result
        ]
-- | Expressions of the internal language.
data Expr
  = EVar EVar
    -- ^ A term variable.
  | EFn [TVar] [(EVar, Type)] Expr
    -- ^ An abstraction: type parameters, annotated value parameters, body.
  | EApp Expr [Type] [Expr]
    -- ^ An application: function, explicit type arguments, value arguments.
  deriving (Eq, Show)
-- | Render annotated bindings as @x : T, y : U, ...@.
typesetBindings :: [(EVar, Type)] -> String
typesetBindings bindings =
  intercalate ", " [name ++ " : " ++ typeset ty | (name, ty) <- bindings]
-- | Render expressions in the concrete syntax @fun[X, ...](x : T, ...) body@
-- for abstractions and @f[T, ...](e, ...)@ for applications.
instance Typeset Expr where
  typeset (EVar v) = v
  -- A space separates the parameter list from the body. Without it the body
  -- runs straight into the closing parenthesis (e.g. @fun[](x : Top)x@).
  typeset (EFn vs xts e) =
    "fun[" ++ intercalate ", " vs ++ "](" ++ typesetBindings xts ++ ") "
      ++ typeset e
  typeset (EApp e ts xs) =
    callee ++ "[" ++ intercalate ", " (map typeset ts) ++ "]("
      ++ intercalate ", " (map typeset xs) ++ ")"
    where
      -- An abstraction in function position is parenthesised, so its body is
      -- not read as extending over the arguments.
      callee = case e of
        EFn {} -> "(" ++ typeset e ++ ")"
        _ -> typeset e
-- | A context of term-variable bindings and type variables.
-- NOTE(review): not referenced by the inference code visible in this module.
data Ctx = Ctx
  [(EVar, Type)] -- ^ term variables paired with their types
  [TVar]         -- ^ type variables
  deriving (Eq, Show)
-- | A bound on an unknown type variable. @Constraint s x t@ reads
-- @s <: x <: t@.
data Constraint
  = Constraint
      Type -- ^ lower bound
      TVar -- ^ the constrained variable
      Type -- ^ upper bound
  deriving (Eq, Show)
-- | Render a constraint as @S <: X <: T@.
instance Typeset Constraint where
  typeset (Constraint lower var upper) =
    concat [typeset lower, " <: ", var, " <: ", typeset upper]
-- | A single substitution: a type variable and the type that replaces it.
type Subst = (TVar, Type)
-- | How a type variable occurs in a type (computed by 'variance').
data Variance
  = Covariant      -- ^ only in positive positions
  | Contravariant  -- ^ only in negative positions
  | Invariant      -- ^ in both positive and negative positions
  | Bivariant      -- ^ not at all
  deriving (Eq, Show)
-- | Combine the variances of two sets of occurrences.
--
-- 'Bivariant' (no occurrences) is the identity. Combining two equal
-- directions keeps that direction. Any other combination, including anything
-- with 'Invariant', yields 'Invariant'.
instance Semigroup Variance where
  Bivariant <> x = x
  x <> Bivariant = x
  Covariant <> Covariant = Covariant
  Contravariant <> Contravariant = Contravariant
  _ <> _ = Invariant

-- | Since base-4.11, 'Monoid' requires a 'Semigroup' instance, and 'mappend'
-- defaults to '(<>)'. Only the identity is defined here.
instance Monoid Variance where
  mempty = Bivariant
-- | Least upper bound (join) of two types, used to merge lower bounds.
--
-- 'Top' absorbs and 'Bot' is the identity. Two function types with the same
-- type parameters and the same number of arguments are joined pointwise:
-- argument types are met with 'boundL' (arguments are contravariant) and
-- results are joined. Any other pair of distinct types is reported as a
-- mismatch.
--
-- NOTE(review): the paper's join appears to be total (unrelated types join to
-- 'Top'). Confirm that reporting an error here is intended.
boundU :: Type -> Type -> Either String Type
boundU _ Top = return Top
boundU Top _ = return Top
boundU t Bot = return t
boundU Bot t = return t
boundU (TAll xs rs s) (TAll ys ts u)
  -- Equal arity is required, because 'zipWithM' would otherwise silently drop
  -- the surplus argument types of the longer function type.
  | xs == ys && length rs == length ts =
      TAll xs <$> zipWithM boundL rs ts <*> boundU s u
boundU t s
  | t == s = return t
  | otherwise =
      Left $ "type mismatch between ‘" ++ typeset t ++ "’ and ‘" ++ typeset s ++ "’"
-- | Greatest lower bound (meet) of two types, used to merge upper bounds.
--
-- 'Top' is the identity and 'Bot' absorbs. Two function types with the same
-- type parameters and the same number of arguments are met pointwise:
-- argument types are joined with 'boundU' (arguments are contravariant) and
-- results are met. Any other pair of distinct types is reported as a mismatch.
boundL :: Type -> Type -> Either String Type
boundL t Top = return t
boundL Top t = return t
boundL _ Bot = return Bot
boundL Bot _ = return Bot
boundL (TAll xs rs s) (TAll ys ts u)
  -- Equal arity is required, because 'zipWithM' would otherwise silently drop
  -- the surplus argument types of the longer function type.
  | xs == ys && length rs == length ts =
      TAll xs <$> zipWithM boundU rs ts <*> boundL s u
boundL t s
  | t == s = return t
  | otherwise =
      Left $ "type mismatch between ‘" ++ typeset t ++ "’ and ‘" ++ typeset s ++ "’"
-- | Eliminate the variables @vs@ from a type upward.
--
-- Free occurrences in positive positions become 'Top'. Argument types are
-- handled by 'elimD', so occurrences there become 'Bot'.
elimU :: [TVar] -> Type -> Type
elimU _ Top = Top
elimU _ Bot = Bot
elimU vs t@(TVar v)
  | v `elem` vs = Top
  | otherwise = t
elimU vs (TAll xs ts_in t_out) =
  -- The quantifier binds xs. Inside it, those names refer to the binder, not
  -- to the variables being eliminated, so they must be left alone.
  let vs' = filter (`notElem` xs) vs
  in TAll xs (map (elimD vs') ts_in) (elimU vs' t_out)
-- | Eliminate the variables @vs@ from a type downward.
--
-- Free occurrences in positive positions become 'Bot'. Argument types are
-- handled by 'elimU', so occurrences there become 'Top'.
elimD :: [TVar] -> Type -> Type
elimD _ Top = Top
elimD _ Bot = Bot
elimD vs t@(TVar v)
  | v `elem` vs = Bot
  | otherwise = t
elimD vs (TAll xs ts_in t_out) =
  -- The quantifier binds xs. Inside it, those names refer to the binder, not
  -- to the variables being eliminated, so they must be left alone.
  let vs' = filter (`notElem` xs) vs
  in TAll xs (map (elimU vs') ts_in) (elimD vs' t_out)
-- | Constraint generation @vs ⊢_xs S <: T ⇒ C@.
--
-- Computes constraints on the unknowns @xs@ under which the first type is a
-- subtype of the second. The variables @vs@ must not appear in the generated
-- bounds; they are removed with 'elimD' / 'elimU'. Rule names follow the
-- paper. Returns 'Left' when no rule applies.
genConstraints :: [TVar] -> [TVar] -> Type -> Type -> Either String [Constraint]
genConstraints _ _ _ Top = return [] -- CG-Top
genConstraints _ _ Bot _ = return [] -- CG-Bot
genConstraints _ _ y s | y == s = return [] -- CG-Refl
genConstraints vs xs (TVar v) s -- CG-Upper
  | v `elem` xs = return [Constraint Bot v (elimD vs s)]
genConstraints vs xs s (TVar v) -- CG-Lower
  | v `elem` xs = return [Constraint (elimU vs s) v Top]
genConstraints vs xs (TAll ys rs s) (TAll ys' ts u) -- CG-Fun
  | null (ys `intersect` vs)
      && null (ys `intersect` xs)
      && ys == ys'
      -- Without this check 'zipWithM' would silently ignore the surplus
      -- argument types of the longer function type.
      && length rs == length ts = do
      let vs' = nub (vs ++ ys)
      -- Arguments are compared contravariantly: ts <: rs.
      cs <- zipWithM (genConstraints vs' xs) ts rs
      d <- genConstraints vs' xs s u
      return $ nub (d ++ concat cs)
genConstraints vs xs y s =
  Left $ "genConstraints: failed to generate constraints for {"
    ++ intercalate ", " vs ++ "} ⊢_{" ++ intercalate ", " xs ++ "} "
    ++ typeset y ++ " <: " ++ typeset s
-- | Merge all constraints on the same variable into a single constraint.
--
-- Lower bounds are joined with 'boundU' and upper bounds are met with
-- 'boundL'. The result has exactly one constraint per variable, in order of
-- first occurrence. Fails if any join or meet fails.
reduceConstraints :: [Constraint] -> Either String [Constraint]
reduceConstraints cs = traverse mergeVar (nub (map constraintVar cs))
  where
    constraintVar :: Constraint -> TVar
    constraintVar (Constraint _ x _) = x

    -- Collect every constraint on x explicitly. 'groupBy' only groups
    -- /adjacent/ elements, and constraints on one variable are generally not
    -- adjacent. Starting the fold from Bot <: x <: Top changes nothing,
    -- because Bot is the identity of 'boundU' and Top is the identity of
    -- 'boundL'. It also means no empty-group case is needed.
    mergeVar :: TVar -> Either String Constraint
    mergeVar x =
      foldM appendCs (Constraint Bot x Top) [c | c <- cs, constraintVar c == x]

    appendCs :: Constraint -> Constraint -> Either String Constraint
    appendCs (Constraint s x t) (Constraint u _ v) =
      Constraint <$> boundU s u <*> pure x <*> boundL t v
-- | Swap 'Covariant' and 'Contravariant'. 'Invariant' and 'Bivariant' are
-- their own inverses.
invertVariance :: Variance -> Variance
invertVariance v = case v of
  Covariant -> Contravariant
  Contravariant -> Covariant
  other -> other
-- | The variance of type variable @x@ in a type.
--
-- A function type is contravariant in its argument types and covariant in its
-- result type. Occurrences inside arguments therefore have their variance
-- flipped, while occurrences in the result keep theirs.
variance :: TVar -> Type -> Variance
variance _ Top = Bivariant
variance _ Bot = Bivariant
variance x (TVar y)
  | x == y = Covariant
  | otherwise = Bivariant
variance x (TAll zs ys y)
  -- The quantifier rebinds x, so nothing inside refers to the outer x.
  | x `elem` zs = Bivariant
  -- Flip the argument types; keep the result type as-is.
  | otherwise = mconcat (variance x y : map (invertVariance . variance x) ys)
-- | Pick a value for each constrained variable @s <: x <: t@, based on the
-- variance of @x@ in the result type @r@:
--
--   * covariant or bivariant: the lower bound @s@;
--   * contravariant: the upper bound @t@;
--   * invariant: @s@, but only when @s == t@.
--
-- Returns 'Nothing' if an invariant variable has distinct bounds.
minimumSubst :: Type -> [Constraint] -> Maybe [Subst]
minimumSubst r constraints = mapM pick constraints
  where
    pick :: Constraint -> Maybe Subst
    pick (Constraint lower x upper)
      | v == Contravariant = Just (x, upper)
      | v == Invariant = if lower == upper then Just (x, lower) else Nothing
      | otherwise = Just (x, lower)
      where
        v = variance x r
-- | Apply substitutions to a type simultaneously.
--
-- Each free occurrence of a substituted variable is replaced by its type. The
-- replacement types are not themselves rewritten. If a variable is listed more
-- than once, the first entry wins. Inside a 'TAll', variables it rebinds are
-- left untouched.
--
-- NOTE(review): this is not capture-avoiding. It assumes replacement types do
-- not mention the 'TAll'-bound names they end up under.
applySubst :: [Subst] -> Type -> Type
applySubst substs t = case t of
  Top -> Top
  Bot -> Bot
  TVar y -> case lookup y substs of
    Just replacement -> replacement
    Nothing -> t
  TAll xs ss s ->
    -- Drop substitutions for variables shadowed by this quantifier.
    let inner = filter ((`notElem` xs) . fst) substs
    in TAll xs (map (applySubst inner) ss) (applySubst inner s)
-- | Local type argument synthesis for an application.
--
-- Given the function's type @All(xs)(ts) → r@ and the argument types @ss@,
-- the steps are:
--
--   1. Generate constraints requiring each argument type to be a subtype of
--      its parameter type.
--   2. Merge the constraints per variable.
--   3. Choose the substitution given by 'minimumSubst' for @r@.
--   4. Return @r@ with that substitution applied.
--
-- Fails on non-function types, on an argument-count mismatch, and when no
-- substitution can be chosen.
inferApp :: Type -> [Type] -> Either String Type
inferApp t@(TAll _ ts _) ss
  -- Checked first, because 'zipWithM' would otherwise silently ignore surplus
  -- arguments or parameters.
  | length ss /= length ts =
      Left $ "inferApp: arity mismatch: " ++ typeset t ++ " expects "
        ++ show (length ts) ++ " argument(s), given " ++ show (length ss)
inferApp t@(TAll xs ts r) ss = do
  cs <- concat <$> zipWithM (genConstraints [] xs) ss ts
  cs' <- reduceConstraints cs
  let constrained = [x | Constraint _ x _ <- cs']
      -- A type parameter that no argument constrains is only bounded by
      -- Bot <: x <: Top. Include that constraint so the parameter is still
      -- resolved, rather than left free in the result type.
      unconstrained = [Constraint Bot x Top | x <- xs, x `notElem` constrained]
  subst <- maybe (Left $ "inferApp: could not infer type arguments for " ++ typeset t ++ " applied to types " ++ intercalate ", " (map typeset ss))
                 Right
                 (minimumSubst r (cs' ++ unconstrained))
  return $ applySubst subst r
inferApp t _ = Left $ "inferApp: expected function type, given " ++ typeset t
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment