Created
April 12, 2017 02:17
-
-
Save lexi-lambda/045ba782c8a0d915bd8abf97167d3bb5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleInstances #-} | |
-- | An implementation of Section 3, Local Type Argument Synthesis, from the | |
-- paper /Local Type Inference/ by Pierce and Turner. | |
module Infer where | |
import Control.Monad (foldM, join, zipWithM)
import Data.Function (on)
import Data.List (foldl', groupBy, intercalate, intersect, nub, sortOn)
import Data.Semigroup (Semigroup (..))
-- | Names of term (expression) variables.
type EVar = String

-- | Names of type variables.
type TVar = String
-- | Values that can be rendered as human-readable concrete syntax.
class Typeset a where
  typeset :: a -> String

-- | Render a value with 'typeset' and print it on its own line.
printTypeset :: Typeset a => a -> IO ()
printTypeset x = putStrLn (typeset x)

-- | Lists render as a bracketed, comma-separated sequence of their elements.
instance Typeset a => Typeset [a] where
  typeset xs = concat ["[", intercalate ", " (map typeset xs), "]"]

-- | A failure renders its message behind a @***@ marker; a success renders
-- its payload.
instance Typeset b => Typeset (Either String b) where
  typeset = either ("*** " ++) typeset
-- | Types of the calculus.
data Type
  = Top
    -- ^ the greatest type
  | Bot
    -- ^ the least type
  | TVar TVar
    -- ^ a type variable
  | TAll [TVar] [Type] Type
    -- ^ @All(X̄)(T̄) → R@: a function type whose binders X̄ scope over its
    -- parameter types T̄ and its result type R
  deriving (Show, Eq)
-- | Types render as @Top@, @Bot@, the variable name, or
-- @All(X, ...)(T, ...) → R@.
instance Typeset Type where
  typeset ty = case ty of
    Top -> "Top"
    Bot -> "Bot"
    TVar name -> name
    TAll binders params result ->
      concat
        [ "All("
        , intercalate ", " binders
        , ")("
        , intercalate ", " (map typeset params)
        , ") → "
        , typeset result
        ]
-- | Terms of the calculus.
data Expr
  = EVar EVar
    -- ^ a term variable
  | EFn [TVar] [(EVar, Type)] Expr
    -- ^ @fun[X̄](x̄ : T̄) e@: an abstraction over type variables and typed
    -- term parameters
  | EApp Expr [Type] [Expr]
    -- ^ @e[T̄](ē)@: an application with type arguments and term arguments
  deriving (Show, Eq)
-- | Render term-parameter bindings as @x : T, y : U, ...@.
typesetBindings :: [(EVar, Type)] -> String
typesetBindings bindings =
  intercalate ", " [name ++ " : " ++ typeset ty | (name, ty) <- bindings]
-- | Terms render in the paper-style syntax @x@, @fun[X̄](x̄ : T̄) e@ and
-- @e[T̄](ē)@.
instance Typeset Expr where
  typeset (EVar v) = v
  -- A space separates the parameter list from the body; without it the body
  -- is glued to the closing parenthesis (e.g. @fun[](x : Top)x@).
  typeset (EFn vs xts e) =
    "fun[" ++ intercalate ", " vs ++ "](" ++ typesetBindings xts ++ ") " ++ typeset e
  typeset (EApp e ts xs) =
    typeset e
      ++ "["
      ++ intercalate ", " (map typeset ts)
      ++ "]("
      ++ intercalate ", " (map typeset xs)
      ++ ")"
-- | A context pairing term-variable bindings with a list of type variables.
-- NOTE(review): presumably the type variables in scope; not used by the
-- inference code shown here — confirm against the rest of the module.
data Ctx
  = Ctx
      [(EVar, Type)] -- ^ term variables and their types
      [TVar]         -- ^ type variables
  deriving (Show, Eq)
-- | @Constraint s x t@ records the bounds @s <: x <: t@ on the unknown @x@.
data Constraint
  = Constraint
      Type -- ^ lower bound
      TVar -- ^ the constrained unknown
      Type -- ^ upper bound
  deriving (Show, Eq)
-- | Constraints render as @S <: X <: T@.
instance Typeset Constraint where
  typeset (Constraint lower x upper) =
    concat [typeset lower, " <: ", x, " <: ", typeset upper]
-- | One substitution: replace the type variable by the paired type.
type Subst = (TVar, Type)
-- | How a type variable occurs in a type:
--
-- * 'Covariant': only positively.
-- * 'Contravariant': only negatively.
-- * 'Invariant': both positively and negatively.
-- * 'Bivariant': not at all, i.e. the type is constant in it.
data Variance = Covariant | Contravariant | Invariant | Bivariant
  deriving (Eq, Show)

-- | Combine the variances of two occurrences. 'Bivariant' (no occurrence) is
-- the identity; equal polarities are kept; mixed polarities give
-- 'Invariant'.
instance Semigroup Variance where
  Bivariant <> v = v
  v <> Bivariant = v
  Covariant <> Covariant = Covariant
  Contravariant <> Contravariant = Contravariant
  _ <> _ = Invariant

-- | Since base 4.11 (GHC 8.4) 'Monoid' has 'Semigroup' as a superclass, so
-- the combining operation lives in the 'Semigroup' instance above.
-- 'mappend' is still defined explicitly for compilers where it has no
-- default.
instance Monoid Variance where
  mempty = Bivariant
  mappend = (<>)
-- | Least upper bound (join) of two types. 'reduceConstraints' uses it to
-- merge the lower bounds on one unknown.
--
-- Rules:
--
-- * 'Top' absorbs and 'Bot' is the identity.
-- * Function types are joined pointwise when they have identical binder
--   lists and the same number of parameters. Parameter types are met with
--   'boundL'; results are joined.
-- * Any other pair of distinct types is reported as a mismatch.
--
-- NOTE(review): Pierce and Turner appear to define the join of otherwise
-- incomparable types as 'Top' rather than as an error — confirm intent.
boundU :: Type -> Type -> Either String Type
boundU _ Top = return Top
boundU Top _ = return Top
boundU t Bot = return t
boundU Bot t = return t
boundU (TAll xs rs s) (TAll ys ts u)
  -- The arity check matters: zipWithM would otherwise silently drop the
  -- surplus parameters of the longer function type.
  | xs == ys && length rs == length ts =
      TAll xs <$> zipWithM boundL rs ts <*> boundU s u
boundU t s
  | t == s = return t
  | otherwise = Left $ "type mismatch between ‘" ++ typeset t ++ "’ and ‘" ++ typeset s ++ "’"
-- | Greatest lower bound (meet) of two types. 'reduceConstraints' uses it to
-- merge the upper bounds on one unknown.
--
-- Rules:
--
-- * 'Bot' absorbs and 'Top' is the identity.
-- * Function types are met pointwise when they have identical binder lists
--   and the same number of parameters. Parameter types are joined with
--   'boundU'; results are met.
-- * Any other pair of distinct types is reported as a mismatch.
--
-- NOTE(review): Pierce and Turner appear to define the meet of otherwise
-- incomparable types as 'Bot' rather than as an error — confirm intent.
boundL :: Type -> Type -> Either String Type
boundL t Top = return t
boundL Top t = return t
boundL _ Bot = return Bot
boundL Bot _ = return Bot
boundL (TAll xs rs s) (TAll ys ts u)
  -- The arity check matters: zipWithM would otherwise silently drop the
  -- surplus parameters of the longer function type.
  | xs == ys && length rs == length ts =
      TAll xs <$> zipWithM boundU rs ts <*> boundL s u
boundL t s
  | t == s = return t
  | otherwise = Left $ "type mismatch between ‘" ++ typeset t ++ "’ and ‘" ++ typeset s ++ "’"
-- | Promotion (@S ⇑V T@): remove the variables @vs@ from a type by moving
-- upward.
--
-- * Occurrences of @vs@ in covariant positions become 'Top'.
-- * Parameter types of a function type are contravariant, so they are
--   handled by 'elimD'.
-- * A function type's own binders shadow equally named variables in @vs@,
--   so those names are not eliminated inside it.
elimU :: [TVar] -> Type -> Type
elimU _ Top = Top
elimU _ Bot = Bot
elimU vs t@(TVar v)
  | v `elem` vs = Top
  | otherwise = t
elimU vs (TAll xs ts_in t_out) =
  let vs' = filter (`notElem` xs) vs
   in TAll xs (map (elimD vs') ts_in) (elimU vs' t_out)
-- | Demotion (@S ⇓V T@): remove the variables @vs@ from a type by moving
-- downward.
--
-- * Occurrences of @vs@ in covariant positions become 'Bot'.
-- * Parameter types of a function type are contravariant, so they are
--   handled by 'elimU'.
-- * A function type's own binders shadow equally named variables in @vs@,
--   so those names are not eliminated inside it.
elimD :: [TVar] -> Type -> Type
elimD _ Top = Top
elimD _ Bot = Bot
elimD vs t@(TVar v)
  | v `elem` vs = Bot
  | otherwise = t
elimD vs (TAll xs ts_in t_out) =
  let vs' = filter (`notElem` xs) vs
   in TAll xs (map (elimU vs') ts_in) (elimD vs' t_out)
-- | Constraint generation @vs ⊢_xs S <: T@: collect bounds on the unknowns
-- @xs@ under which @S <: T@ holds.
--
-- * @vs@ are the variables bound by enclosing function types. They are
--   eliminated from every bound (via 'elimD' / 'elimU') so they cannot
--   escape.
-- * Each clause is labelled with the rule it implements.
-- * Function types only match when they bind the same variables (disjoint
--   from @vs@ and @xs@) and take the same number of parameters. Parameters
--   are compared contravariantly.
--
-- NOTE(review): the paper's CG-Upper/CG-Lower also require that the other
-- side mention none of the unknowns @xs@; that side condition is not
-- checked here.
genConstraints :: [TVar] -> [TVar] -> Type -> Type -> Either String [Constraint]
genConstraints _ _ _ Top = return [] -- CG-Top
genConstraints _ _ Bot _ = return [] -- CG-Bot
genConstraints _ _ y s | y == s = return [] -- CG-Refl
genConstraints vs xs (TVar v) s | v `elem` xs = -- CG-Upper
  let t = elimD vs s in return [Constraint Bot v t]
genConstraints vs xs s (TVar v) | v `elem` xs = -- CG-Lower
  let t = elimU vs s in return [Constraint t v Top]
genConstraints vs xs (TAll ys rs s) (TAll ys' ts u) -- CG-Fun
  | null (ys `intersect` vs)
      && null (ys `intersect` xs)
      && ys == ys'
      -- without this, zipWithM would silently ignore surplus parameters
      && length rs == length ts = do
      let vs' = nub (vs ++ ys)
      -- contravariance: each supertype parameter must be below the subtype's
      cs <- zipWithM (genConstraints vs' xs) ts rs
      d <- genConstraints vs' xs s u
      return $ nub (d ++ join cs)
genConstraints vs xs y s =
  Left $ "genConstraints: failed to generate constraints for {"
    ++ intercalate ", " vs ++ "} ⊢_{" ++ intercalate ", " xs ++ "} "
    ++ typeset y ++ " <: " ++ typeset s
-- | Merge all constraints on the same unknown into one. Lower bounds are
-- joined with 'boundU' and upper bounds are met with 'boundL'.
--
-- Fails with the first mismatch reported by 'boundU' or 'boundL'. The
-- result holds one constraint per unknown, ordered by variable name.
reduceConstraints :: [Constraint] -> Either String [Constraint]
reduceConstraints = traverse concatCs . groupCs
  where
    varOf :: Constraint -> TVar
    varOf (Constraint _ x _) = x

    -- groupBy only merges *adjacent* elements, so sort by variable first.
    -- sortOn is stable, which keeps each variable's constraints in their
    -- original relative order.
    groupCs :: [Constraint] -> [[Constraint]]
    groupCs = groupBy ((==) `on` varOf) . sortOn varOf

    appendCs :: Constraint -> Constraint -> Either String Constraint
    appendCs (Constraint s x t) (Constraint u _ v) =
      Constraint <$> boundU s u <*> pure x <*> boundL t v

    -- groupBy never yields an empty group, so the [] case is unreachable.
    concatCs :: [Constraint] -> Either String Constraint
    concatCs (c : cs) = foldM appendCs c cs
    concatCs [] = error "concatCs: internal error"
-- | Flip the polarity of a variance. 'Covariant' and 'Contravariant' swap;
-- 'Invariant' and 'Bivariant' are unchanged.
invertVariance :: Variance -> Variance
invertVariance v = case v of
  Covariant -> Contravariant
  Contravariant -> Covariant
  other -> other
-- | The variance of the type variable @x@ in a type.
--
-- * In @All(X̄)(T̄) → R@ the result @R@ is a covariant position and the
--   parameter types @T̄@ are contravariant positions, so only the
--   parameters' variances are inverted.
-- * If the function type rebinds @x@, the outer @x@ does not occur in it
--   at all.
--
-- 'minimumSubst' relies on this to pick the lower bound for covariant
-- unknowns and the upper bound for contravariant ones.
variance :: TVar -> Type -> Variance
variance _ Top = Bivariant
variance _ Bot = Bivariant
variance x (TVar y)
  | x == y = Covariant
  | otherwise = Bivariant
variance x (TAll xs ys y)
  | x `elem` xs = Bivariant
  | otherwise = mconcat (variance x y : map (invertVariance . variance x) ys)
-- | Choose a substitution for each constrained unknown that makes the
-- result type @r@ as small as possible.
--
-- * Constant or covariant in the unknown: take the lower bound.
-- * Contravariant: take the upper bound.
-- * Invariant: succeed only when both bounds coincide.
--
-- Returns 'Nothing' as soon as some unknown has no such choice.
minimumSubst :: Type -> [Constraint] -> Maybe [Subst]
minimumSubst r = mapM pick
  where
    pick :: Constraint -> Maybe Subst
    pick (Constraint lower x upper) = (,) x <$> chooseBound (variance x r)
      where
        chooseBound Contravariant = Just upper
        chooseBound Invariant | lower /= upper = Nothing
        chooseBound _ = Just lower
-- | Apply single-variable substitutions to a type, one after another from
-- left to right.
--
-- A function type that rebinds the substituted variable shadows it, so the
-- substitution does not descend into that function type.
--
-- NOTE(review): capture of the replacement's free variables by inner
-- binders is not avoided.
applySubst :: [Subst] -> Type -> Type
applySubst substs r = foldl' (flip substOne) r substs
  where
    substOne :: Subst -> Type -> Type
    substOne (x, t') = go
      where
        go Top = Top
        go Bot = Bot
        go s@(TVar y)
          | x == y = t'
          | otherwise = s
        go s@(TAll ys params body)
          | x `elem` ys = s
          | otherwise = TAll ys (map go params) (go body)
-- | Local type argument synthesis for an application. Given the callee's
-- type @All(X̄)(T̄) → R@ and the argument types @S̄@, infer the type
-- arguments and return the instantiated result type.
--
-- Steps:
--
-- 1. Rename each type parameter that also occurs in an argument type to a
--    fresh name. This keeps an outer type variable that happens to share a
--    parameter's name from being confused with the unknown.
-- 2. Generate constraints from @S_i <: T_i@ for every argument.
-- 3. Give each unknown with no constraint the trivial bounds
--    @Bot <: X <: Top@. It is then still assigned a type instead of leaking
--    into the result.
-- 4. Merge the constraints per unknown and apply the minimal substitution
--    to @R@.
--
-- Fails in any of these cases:
--
-- * the callee is not a function type;
-- * the argument count differs from the parameter count;
-- * constraint generation or merging fails;
-- * no minimal substitution exists.
inferApp :: Type -> [Type] -> Either String Type
inferApp t@(TAll xs ts r) ss
  | length ts /= length ss =
      Left $ "inferApp: " ++ typeset t ++ " expects " ++ show (length ts)
        ++ " argument(s) but was applied to " ++ show (length ss)
  | otherwise = do
      cs <- concat <$> zipWithM (genConstraints [] xs') ss ts'
      let constrained = [x | Constraint _ x _ <- cs]
          defaults = [Constraint Bot x Top | x <- xs', x `notElem` constrained]
      cs' <- reduceConstraints (cs ++ defaults)
      subst <-
        maybe
          (Left $ "inferApp: could not infer type arguments for " ++ typeset t ++ " applied to types " ++ intercalate ", " (map typeset ss))
          Right
          (minimumSubst r' cs')
      return $ applySubst subst r'
  where
    -- Every name occurring (free or bound) in the callee or in the arguments;
    -- fresh names avoid all of them, so renaming cannot capture anything.
    allNames :: [TVar]
    allNames = concatMap names (t : ss)

    argNames :: [TVar]
    argNames = concatMap names ss

    -- Map each parameter to itself, or to a fresh primed name if it clashes
    -- with a name used by the arguments.
    assign :: [TVar] -> [TVar] -> [(TVar, TVar)]
    assign _ [] = []
    assign taken (x : rest)
      | x `notElem` argNames = (x, x) : assign taken rest
      | otherwise =
          let x' = until (`notElem` taken) (++ "'") (x ++ "'")
           in (x, x') : assign (x' : taken) rest

    renaming :: [(TVar, TVar)]
    renaming = assign allNames xs

    xs' :: [TVar]
    xs' = map snd renaming

    ts' :: [Type]
    ts' = map (rename renaming) ts

    r' :: Type
    r' = rename renaming r

    names :: Type -> [TVar]
    names Top = []
    names Bot = []
    names (TVar v) = [v]
    names (TAll vs params body) = vs ++ concatMap names params ++ names body

    -- Rename free occurrences only; an inner binder shadows the renamed name.
    rename :: [(TVar, TVar)] -> Type -> Type
    rename _ Top = Top
    rename _ Bot = Bot
    rename m (TVar v) = TVar (maybe v id (lookup v m))
    rename m (TAll vs params body) =
      let m' = filter ((`notElem` vs) . fst) m
       in TAll vs (map (rename m') params) (rename m' body)
inferApp t _ = Left $ "inferApp: expected function type, given " ++ typeset t
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment