Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
diff --git a/compiler/typecheck/TcInteract.lhs b/compiler/typecheck/TcInteract.lhs
index 377cd2d..a742f96 100644
--- a/compiler/typecheck/TcInteract.lhs
+++ b/compiler/typecheck/TcInteract.lhs
@@ -31,6 +31,8 @@ import FamInstEnv ( FamInstEnvs, instNewTyConTF_maybe )
import TcEvidence
import Outputable
+import TcTypeNats ( evBySOP, sopRelevantTyCon, sopToProp, sopToExpr, solveSOP )
+
import TcRnTypes
import TcErrors
import TcSMonad
@@ -526,14 +528,21 @@ interactFunEq inerts workItem@(CFunEqCan { cc_ev = ev, cc_fun = tc
; traceTcS "builtInCandidates: " $ ppr is
; let interact = sfInteractInert ops args rhs
; impMbs <- sequence
- [ do mb <- newDerived (ctev_loc iev) (mkTcEqPred lhs_ty rhs_ty)
+ [ do mb <- if isGiven ev && isGiven iev
+ then do let newEv = EvCoercion $ mkTcAxiomRuleCo coAxR
+ (args ++ iargs)
+ [ evTermCoercion (ctev_evtm ev)
+ , evTermCoercion (ctev_evtm iev)]
+ d <- newGivenEvVar (ctev_loc iev) (mkTcEqPred lhs_ty rhs_ty,newEv)
+ return (Just d)
+ else newDerived (ctev_loc iev) (mkTcEqPred lhs_ty rhs_ty)
case mb of
Just x -> return $ Just $ mkNonCanonical x
Nothing -> return Nothing
| CFunEqCan { cc_tyargs = iargs
, cc_rhs = ixi
, cc_ev = iev } <- is
- , Pair lhs_ty rhs_ty <- interact iargs ixi
+ , (coAxR, Pair lhs_ty rhs_ty) <- interact iargs ixi
]
; let imps = catMaybes impMbs
; unless (null imps) $ updWorkListTcS (extendWorkListEqs imps)
@@ -1498,7 +1507,7 @@ doTopReactFunEq _ct fl fun_tc args xi
-- Look up in top-level instances, or built-in axiom
do { match_res <- matchFam fun_tc args -- See Note [MATCHING-SYNONYMS]
; case match_res of {
- Nothing -> do { try_improvement; return NoTopInt } ;
+ Nothing -> do { try_improvement_and_return } ;
Just (co, ty) ->
-- Found a top-level instance
@@ -1509,14 +1518,23 @@ doTopReactFunEq _ct fl fun_tc args xi
where
loc = ctev_loc fl
- try_improvement
+ try_improvement_and_return
| Just ops <- isBuiltInSynFamTyCon_maybe fun_tc
- = do { let eqns = sfInteractTop ops args xi
- ; impsMb <- mapM (\(Pair x y) -> newDerived loc (mkTcEqPred x y)) eqns
- ; let work = map mkNonCanonical (catMaybes impsMb)
- ; unless (null work) (updWorkListTcS (extendWorkListEqs work)) }
+ = do { res <- if sopRelevantTyCon fun_tc
+ then do (gis,wis) <- getRelevantCtsForSOP
+ decideEqualSOP gis wis _ct
+ else return NoTopInt
+ ; case res of { NoTopInt -> (do { let eqns = sfInteractTop ops args xi
+ ; impsMb <- mapM (\(Pair x y) -> newDerived loc (mkTcEqPred x y)) eqns
+ ; let work = map mkNonCanonical (catMaybes impsMb)
+ ; unless (null work) (updWorkListTcS (extendWorkListEqs work))
+ })
+ ; _ -> return ()
+ }
+ ; return res
+ }
| otherwise
- = return ()
+ = return NoTopInt
succeed_with :: String -> TcCoercion -> TcType -> TcS TopInteractResult
succeed_with str co rhs_ty -- co :: fun_tc args ~ rhs_ty
@@ -1534,6 +1552,56 @@ doTopReactFunEq _ct fl fun_tc args xi
xcomp [x] = EvCoercion (co `mkTcTransCo` evTermCoercion x)
xcomp _ = panic "No more goals!"
xev = XEvTerm [mkTcEqPred rhs_ty xi] xcomp xdecomp
+
+-- | Try to decide a function equation with the sum-of-products (SOP) solver.
+--
+-- The goal constraint and the supplied given/wanted constraints are turned
+-- into SOP propositions and passed to 'solveSOP'.  Depending on the answer:
+--
+-- * @Just True@: a wanted goal gets its evidence bound to 'evBySOP'; the
+--   result is @SomeTopInt "SOP" Stop@ for wanted and non-wanted goals alike.
+-- * @Just False@: the goal is passed to 'emitInsoluble'; the result is again
+--   @SomeTopInt "SOP" Stop@.
+-- * @Nothing@: 'NoTopInt' is returned, leaving the decision to the caller.
+--
+-- Every constraint handed in must be a 'CFunEqCan'; any other constraint
+-- reaches the 'pprPanic' in @ctToTerm@.
+decideEqualSOP :: [Ct] -> [Ct] -> Ct -> TcS TopInteractResult
+decideEqualSOP givenCts wantedCts goalCt = do
+ traceTcS "decideEqualSOP" $ text "given: " <+> ppr givenCts <+>
+ text ", wanted: " <+> ppr wantedCts <+>
+ text ", goal: " <+> ppr goalCt <+>
+ text ", res: " <+> text (show proved)
+ case proved of
+ Just True -> do when (isWantedCt goalCt) (solve goalCt)
+ done Stop
+ Just False -> do emitInsoluble goalCt
+ done Stop
+ Nothing -> do return NoTopInt
+ where
+ -- Wanted constraints only count as assumptions when the goal itself is
+ -- wanted; any other goal is checked against the givens alone.
+ assumptions
+ | isWantedCt goalCt = givens ++ wanteds
+ | otherwise = givens
+
+ givens = map ctToTerm givenCts
+ wanteds = map ctToTerm wantedCts
+ goal = ctToTerm goalCt
+
+ -- Translate @F ts ~ xi@ into an (expression, expression) proposition.
+ ctToTerm CFunEqCan { cc_fun = tc, cc_tyargs = ts, cc_rhs = xi} =
+ sopToProp tc (map sopToExpr ts) (sopToExpr xi)
+ ctToTerm ct = pprPanic "ctToTerm" (text "SOP" <+> ppr ct)
+
+ proved = solveSOP assumptions goal
+
+ -- Bind the constraint's evidence variable to an 'evBySOP' term built from
+ -- the two types returned by 'getEqPredTys' for the constraint's predicate.
+ solve :: Ct -> TcS ()
+ solve ct = setEvBind (ctEvId $ ctEvidence ct)
+ $ uncurry evBySOP $ getEqPredTys $ ctPred ct
+
+ done :: StopOrContinue -> TcS TopInteractResult
+ done x = return (SomeTopInt "SOP" x)
+
+-- | Collect the inert function equations that the SOP solver can use.
+--
+-- Returns @(givens, wanteds)@.  Only constraints for which
+-- 'isCFunEqCan_maybe' succeeds and whose type constructor satisfies
+-- 'sopRelevantTyCon' are selected.
+getRelevantCtsForSOP :: TcS ([Ct], [Ct])
+getRelevantCtsForSOP = do
+ inerts <- getTcSInerts
+ let cans = inert_cans inerts
+ feqs = inert_funeqs cans
+ -- Split off the relevant wanteds first, then pick the relevant givens
+ -- out of whatever is left.
+ (wanteds, rest_funeqs) = partitionFunEqs (consider isWantedCt) feqs
+ givens = fst $ partitionFunEqs (consider isGivenCt) rest_funeqs
+ return (bagToList givens,bagToList wanteds)
+ where
+ -- @consider fl ct@ holds when @ct@ has the flavour tested by @fl@ and is a
+ -- function equation over an SOP-relevant type constructor.
+ consider :: (Ct -> Bool) -> Ct -> Bool
+ consider fl ct
+ | fl ct
+ , Just (tc,_) <- isCFunEqCan_maybe ct = sopRelevantTyCon tc
+ | otherwise = False
\end{code}
Note [Cached solved FunEqs]
diff --git a/compiler/typecheck/TcSMonad.lhs b/compiler/typecheck/TcSMonad.lhs
index 634e926..2e9225f 100644
--- a/compiler/typecheck/TcSMonad.lhs
+++ b/compiler/typecheck/TcSMonad.lhs
@@ -38,6 +38,7 @@ module TcSMonad (
deferTcSForAllEq,
setEvBind,
+ newGivenEvVar,
XEvTerm(..),
MaybeNew (..), isFresh, freshGoal, freshGoals, getEvTerm, getEvTerms,
diff --git a/compiler/typecheck/TcTypeNats.hs b/compiler/typecheck/TcTypeNats.hs
index c19164b..5050f01 100644
--- a/compiler/typecheck/TcTypeNats.hs
+++ b/compiler/typecheck/TcTypeNats.hs
@@ -2,31 +2,44 @@ module TcTypeNats
( typeNatTyCons
, typeNatCoAxiomRules
, BuiltInSynFamily(..)
+ , sopRelevantTyCon
+ , sopToProp
+ , sopToExpr
+ , solveSOP
+ , evBySOP
) where
import Type
import Pair
-import TcType ( TcType, tcEqType )
-import TyCon ( TyCon, SynTyConRhs(..), mkSynTyCon, TyConParent(..) )
-import Coercion ( Role(..) )
-import TcRnTypes ( Xi )
-import CoAxiom ( CoAxiomRule(..), BuiltInSynFamily(..) )
-import Name ( Name, BuiltInSyntax(..) )
-import TysWiredIn ( typeNatKind, mkWiredInTyConName
- , promotedBoolTyCon
- , promotedFalseDataCon, promotedTrueDataCon
- )
-import TysPrim ( tyVarList, mkArrowKinds )
-import PrelNames ( gHC_TYPELITS
- , typeNatAddTyFamNameKey
- , typeNatMulTyFamNameKey
- , typeNatExpTyFamNameKey
- , typeNatLeqTyFamNameKey
- , typeNatSubTyFamNameKey
- )
-import FastString ( FastString, fsLit )
+import TcType ( TcType, tcEqType )
+import TyCon ( TyCon, SynTyConRhs(..), mkSynTyCon, TyConParent(..) )
+import Coercion ( Role(..) )
+import TcEvidence ( mkTcAxiomRuleCo, EvTerm(..) )
+import TcRnTypes ( Xi )
+import CoAxiom ( CoAxiomRule(..), BuiltInSynFamily(..) )
+import Name ( Name, BuiltInSyntax(..) )
+import TysWiredIn ( typeNatKind, mkWiredInTyConName
+ , promotedBoolTyCon
+ , promotedFalseDataCon, promotedTrueDataCon
+ )
+import TysPrim ( tyVarList, mkArrowKinds )
+import PrelNames ( gHC_TYPELITS
+ , typeNatAddTyFamNameKey
+ , typeNatMulTyFamNameKey
+ , typeNatExpTyFamNameKey
+ , typeNatLeqTyFamNameKey
+ , typeNatSubTyFamNameKey
+ )
+import FastString ( FastString, fsLit )
+import Outputable ( Outputable, ppr, pprPanic, (<+>), text, int , integer
+ , parens, hcat, punctuate )
+import UniqSet ( UniqSet, emptyUniqSet, unionUniqSets, unitUniqSet)
+import Data.Either ( partitionEithers )
+import Data.Function ( on )
+import Data.List ( sort, (\\) )
import qualified Data.Map as Map
-import Data.Maybe ( isJust )
+import Data.Maybe ( catMaybes, isJust )
+import Control.Arrow ( (***), second )
{-------------------------------------------------------------------------------
Built-in type constructors for functions on type-level nats
@@ -140,6 +153,17 @@ axAddDef
, axLeq0L
, axSubDef
, axSub0R
+ , axAddCancelL
+ , axAddCancelR
+ , axSubCancelL
+ , axSubCancelR
+ , axMulCancelL
+ , axMulCancelR
+ , axExpCancelL
+ , axExpCancelR
+ , axLeqAsym
+ , axLeqTrans
+ , axLeqTransSym
:: CoAxiomRule
axAddDef = mkBinAxiom "AddDef" typeNatAddTyCon $
@@ -170,6 +194,22 @@ axExp1R = mkAxiom1 "Exp1R" $ \t -> (t .^. num 1) === t
axLeqRefl = mkAxiom1 "LeqRefl" $ \t -> (t <== t) === bool True
axLeq0L = mkAxiom1 "Leq0L" $ \t -> (num 0 <== t) === bool True
+-- Implication axioms used as evidence when two given function equations
+-- interact.  'interactFunEq' instantiates them with the work item's type
+-- arguments followed by the inert's, i.e. @[x1,y1,x2,y2]@ for the equations
+-- @F x1 y1 ~ z1@ and @F x2 y2 ~ z2@.  Each axiom therefore has to prove
+-- exactly the equation that the matching @interactInert*@ function emits.
+--
+-- NOTE(review): the premises are not inspected by 'mkImpAxiom4'; only the
+-- number of type and coercion arguments is checked.
+
+-- (x + y1 ~ z, x + y2 ~ z) => y1 ~ y2
+axAddCancelL = mkImpAxiom4 "AddCancelL" $ \([_,y1,_,y2]) -> y1 === y2
+-- (x1 + y ~ z, x2 + y ~ z) => x1 ~ x2
+axAddCancelR = mkImpAxiom4 "AddCancelR" $ \([x1,_,x2,_]) -> x1 === x2
+
+-- (x - y1 ~ z, x - y2 ~ z) => y1 ~ y2
+axSubCancelL = mkImpAxiom4 "SubCancelL" $ \([_,y1,_,y2]) -> y1 === y2
+-- (x1 - y ~ z, x2 - y ~ z) => x1 ~ x2
+axSubCancelR = mkImpAxiom4 "SubCancelR" $ \([x1,_,x2,_]) -> x1 === x2
+
+-- (x * y1 ~ z, x * y2 ~ z) => y1 ~ y2
+axMulCancelL = mkImpAxiom4 "MulCancelL" $ \([_,y1,_,y2]) -> y1 === y2
+-- (x1 * y ~ z, x2 * y ~ z) => x1 ~ x2
+axMulCancelR = mkImpAxiom4 "MulCancelR" $ \([x1,_,x2,_]) -> x1 === x2
+
+-- (x ^ y1 ~ z, x ^ y2 ~ z) => y1 ~ y2
+axExpCancelL = mkImpAxiom4 "ExpCancelL" $ \([_,y1,_,y2]) -> y1 === y2
+-- (x1 ^ y ~ z, x2 ^ y ~ z) => x1 ~ x2
+axExpCancelR = mkImpAxiom4 "ExpCancelR" $ \([x1,_,x2,_]) -> x1 === x2
+
+-- (x1 <= y1, y1 <= x1) => x1 ~ y1
+axLeqAsym = mkImpAxiom4 "LeqAsym" $ \([x1,y1,_,_]) -> x1 === y1
+-- (x1 <= y1, y1 <= y2) => x1 <= y2
+axLeqTrans = mkImpAxiom4 "LeqTrans" $ \([x1,_,_,y2]) -> (x1 <== y2) === bool True
+-- (x1 <= y1, x2 <= x1) => x2 <= y1
+axLeqTransSym = mkImpAxiom4 "LeqTransSym" $ \([_,y1,x2,_]) -> (x2 <== y1) === bool True
+
typeNatCoAxiomRules :: Map.Map FastString CoAxiomRule
typeNatCoAxiomRules = Map.fromList $ map (\x -> (coaxrName x, x))
[ axAddDef
@@ -188,9 +228,35 @@ typeNatCoAxiomRules = Map.fromList $ map (\x -> (coaxrName x, x))
, axLeqRefl
, axLeq0L
, axSubDef
+ , eqSOP
+ , axAddCancelL
+ , axAddCancelR
+ , axSubCancelL
+ , axSubCancelR
+ , axMulCancelL
+ , axMulCancelR
+ , axExpCancelL
+ , axExpCancelR
+ , axLeqAsym
+ , axLeqTrans
+ , axLeqTransSym
]
+-- | Axiom rule asserting that its two type arguments are equal.  It takes no
+-- coercion assumptions; any other argument shape proves nothing.
+eqSOP :: CoAxiomRule
+eqSOP = CoAxiomRule
+  { coaxrName      = fsLit "EqSOP"
+  , coaxrTypeArity = 2
+  , coaxrAsmpRoles = []
+  , coaxrRole      = Nominal
+  , coaxrProves    = proves
+  }
+  where
+    proves [s, t] [] = Just (s === t)
+    proves _      _  = Nothing
+-- | Evidence for @t1 ~ t2@: the 'eqSOP' axiom rule applied to the two types
+-- with an empty list of coercion arguments.
+evBySOP :: Type -> Type -> EvTerm
+evBySOP t1 t2 = EvCoercion $ mkTcAxiomRuleCo eqSOP [t1,t2] []
{-------------------------------------------------------------------------------
Various utilities for making axioms and types
@@ -271,6 +337,19 @@ mkAxiom1 str f =
_ -> Nothing
}
+-- | Build an implication axiom over four type arguments and two nominal
+-- coercion assumptions.
+--
+-- The rule proves @f ts@ when it is applied to exactly four types and two
+-- coercions, and nothing otherwise.  The declared type arity is 4 so that it
+-- agrees with the four-element pattern below (the rule can never fire with
+-- any other number of types).
+mkImpAxiom4 :: String -> ([Type] -> Pair Type) -> CoAxiomRule
+mkImpAxiom4 str f =
+  CoAxiomRule
+    { coaxrName      = fsLit str
+    , coaxrTypeArity = 4
+    , coaxrAsmpRoles = [Nominal,Nominal]
+    , coaxrRole      = Nominal
+    , coaxrProves    = \ts cs ->
+        case (ts,cs) of
+          ([_,_,_,_],[_,_]) -> return (f ts)
+          _                 -> Nothing
+    }
+
{-------------------------------------------------------------------------------
Evaluation
@@ -334,9 +413,9 @@ Interact with axioms
interactTopAdd :: [Xi] -> Xi -> [Pair Type]
interactTopAdd [s,t] r
- | Just 0 <- mbZ = [ s === num 0, t === num 0 ] -- (s + t ~ 0) => (s ~ 0, t ~ 0)
- | Just x <- mbX, Just z <- mbZ, Just y <- minus z x = [t === num y] -- (5 + t ~ 8) => (t ~ 3)
- | Just y <- mbY, Just z <- mbZ, Just x <- minus z y = [s === num x] -- (s + 5 ~ 8) => (s ~ 3)
+ | Just 0 <- mbZ = [ s === num 0, t === num 0 ] -- (s + t ~ 0) => (s ~ 0, t ~ 0)
+ | Just x <- mbX, Just z <- mbZ, Just y <- minus z x = [t === num y] -- (5 + t ~ 8) => (t ~ 3)
+ | Just y <- mbY, Just z <- mbZ, Just x <- minus z y = [s === num x] -- (s + 5 ~ 8) => (s ~ 3)
where
mbX = isNumLitTy s
mbY = isNumLitTy t
@@ -421,42 +500,42 @@ interactTopLeq _ _ = []
Interaction with inerts
-------------------------------------------------------------------------------}
-interactInertAdd :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
+interactInertAdd :: [Xi] -> Xi -> [Xi] -> Xi -> [(CoAxiomRule, Pair Type)]
interactInertAdd [x1,y1] z1 [x2,y2] z2
- | sameZ && tcEqType x1 x2 = [ y1 === y2 ]
- | sameZ && tcEqType y1 y2 = [ x1 === x2 ]
+ | sameZ && tcEqType x1 x2 = [ (axAddCancelL, y1 === y2) ]
+ | sameZ && tcEqType y1 y2 = [ (axAddCancelR, x1 === x2) ]
where sameZ = tcEqType z1 z2
interactInertAdd _ _ _ _ = []
-interactInertSub :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
+interactInertSub :: [Xi] -> Xi -> [Xi] -> Xi -> [(CoAxiomRule, Pair Type)]
interactInertSub [x1,y1] z1 [x2,y2] z2
- | sameZ && tcEqType x1 x2 = [ y1 === y2 ]
- | sameZ && tcEqType y1 y2 = [ x1 === x2 ]
+ | sameZ && tcEqType x1 x2 = [ (axSubCancelL, y1 === y2) ]
+ | sameZ && tcEqType y1 y2 = [ (axSubCancelR, x1 === x2) ]
where sameZ = tcEqType z1 z2
interactInertSub _ _ _ _ = []
-interactInertMul :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
+interactInertMul :: [Xi] -> Xi -> [Xi] -> Xi -> [(CoAxiomRule, Pair Type)]
interactInertMul [x1,y1] z1 [x2,y2] z2
- | sameZ && known (/= 0) x1 && tcEqType x1 x2 = [ y1 === y2 ]
- | sameZ && known (/= 0) y1 && tcEqType y1 y2 = [ x1 === x2 ]
+ | sameZ && known (/= 0) x1 && tcEqType x1 x2 = [ (axMulCancelL, y1 === y2) ]
+ | sameZ && known (/= 0) y1 && tcEqType y1 y2 = [ (axMulCancelR, x1 === x2) ]
where sameZ = tcEqType z1 z2
interactInertMul _ _ _ _ = []
-interactInertExp :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
+interactInertExp :: [Xi] -> Xi -> [Xi] -> Xi -> [(CoAxiomRule, Pair Type)]
interactInertExp [x1,y1] z1 [x2,y2] z2
- | sameZ && known (> 1) x1 && tcEqType x1 x2 = [ y1 === y2 ]
- | sameZ && known (> 0) y1 && tcEqType y1 y2 = [ x1 === x2 ]
+ | sameZ && known (> 1) x1 && tcEqType x1 x2 = [ (axExpCancelL, y1 === y2) ]
+ | sameZ && known (> 0) y1 && tcEqType y1 y2 = [ (axExpCancelR, x1 === x2) ]
where sameZ = tcEqType z1 z2
interactInertExp _ _ _ _ = []
-interactInertLeq :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
+interactInertLeq :: [Xi] -> Xi -> [Xi] -> Xi -> [(CoAxiomRule, Pair Type)]
interactInertLeq [x1,y1] z1 [x2,y2] z2
- | bothTrue && tcEqType x1 y2 && tcEqType y1 x2 = [ x1 === y1 ]
- | bothTrue && tcEqType y1 x2 = [ (x1 <== y2) === bool True ]
- | bothTrue && tcEqType y2 x1 = [ (x2 <== y1) === bool True ]
+ | bothTrue && tcEqType x1 y2 && tcEqType y1 x2 = [ (axLeqAsym, x1 === y1) ]
+ | bothTrue && tcEqType y1 x2 = [ (axLeqTrans, (x1 <== y2) === bool True) ]
+ | bothTrue && tcEqType y2 x1 = [ (axLeqTransSym, (x2 <== y1) === bool True) ]
where bothTrue = isJust $ do True <- isBoolLitTy z1
True <- isBoolLitTy z2
return ()
@@ -540,8 +619,336 @@ genLog x base = Just (exactLoop 0 x)
| otherwise = let s1 = s + 1 in s1 `seq` underLoop s1 (div i base)
+{- -----------------------------------------------------------------------------
+Equal Sum-of-Product representation
+----------------------------------------------------------------------------- -}
+
+{-
+SOP Grammar:
+
+SOP ::= Product+
+Product ::= Symbol+
+Symbol ::= Integer
+ | Variable
+ | SimpleExp
+ | ComplexExp
+
+SimpleExp ::= Variable "^" Integer
+ComplexExp ::= SOPS "^" ProductS
+
+SOPS ::= ProductS+
+ProductS ::= SymbolS+
+SymbolS ::= Integer
+ | Variable
+Valid SOP terms:
+xy + y^2
+(x+y)^(kz)
+Invalid SOP terms:
+(xy)^2
+
+-}
+-- | Outcome of the SOP solver: @Just b@ when it reached a decision,
+-- 'Nothing' when it could not decide.
+type Result = Maybe Bool
+
+-- | Binary arithmetic operators that can appear in an 'Expr'.
+data Op = Add | Sub | Mul | Exp
+ deriving Eq
+
+-- | Arithmetic expression tree, before conversion to SOP form.
+data Expr
+ = Lit Integer
+ | Var TyVar
+ | Op Op Expr Expr
+ deriving Eq
+
+-- | A factor of a 'Product': an integer, a variable, or an exponentiation
+-- whose base is a 'SOP' and whose exponent is a 'Product'.
+data Symbol = I Integer
+ | V TyVar
+ | E SOP Product
+ deriving (Eq,Ord)
+
+-- | A product of symbols (printed with @*@ between them).
+newtype Product = P { unP :: [Symbol] }
+ deriving (Eq,Ord)
+
+-- | A sum of products (printed with @+@ between them).
+newtype SOP = S { unS :: [Product] }
+ deriving (Eq,Ord)
+
+-- Pretty printers for the SOP data types.
+
+instance Outputable Expr where
+  ppr expr = case expr of
+    Lit i       -> integer i
+    Var v       -> ppr v
+    Op op e1 e2 -> parens (ppr e1) <+> ppr op <+> parens (ppr e2)
+
+instance Outputable Op where
+  ppr op = text opName
+    where
+      opName = case op of
+        Add -> "+"
+        Sub -> "-"
+        Mul -> "*"
+        Exp -> "^"
+
+instance Outputable SOP where
+  ppr (S prods) = hcat (punctuate (text " + ") (map ppr prods))
+
+instance Outputable Product where
+  ppr (P syms) = hcat (punctuate (text "*") (map ppr syms))
+
+instance Outputable Symbol where
+  ppr (I i)   = integer i
+  ppr (V s)   = ppr s
+  ppr (E b e) = pprSimple b <+> text "^" <+> pprSimple (S [e])
+    where
+      -- A lone literal or variable is printed bare, anything else bracketed.
+      pprSimple (S [P [I i]]) = integer i
+      pprSimple (S [P [V v]]) = ppr v
+      pprSimple sop           = text "(" <+> ppr sop <+> text ")"
+
+-- | Is this one of the type-level arithmetic families handled by the SOP
+-- solver (addition, subtraction, multiplication, exponentiation)?
+sopRelevantTyCon :: TyCon -> Bool
+sopRelevantTyCon tc = tc `elem` sopTyCons
+  where
+    sopTyCons =
+      [ typeNatAddTyCon
+      , typeNatSubTyCon
+      , typeNatMulTyCon
+      , typeNatExpTyCon
+      ]
+
+-- | Turn a binary function equation @tc e1 e2 ~ e@ into the proposition
+-- @(e1 `op` e2, e)@.  Panics when the argument list is not of length two.
+sopToProp :: TyCon -> [Expr] -> Expr -> (Expr,Expr)
+sopToProp tc es e = case es of
+  [e1, e2] -> (Op (sopToOp tc) e1 e2, e)
+  _        -> pprPanic "sopArith" $
+                text "Unexpected arity for" <+> ppr tc <+> int (length es)
+
+-- | Map an arithmetic type family to its 'Op'.  Panics on any other tycon.
+sopToOp :: TyCon -> Op
+sopToOp tc = case lookup tc opTable of
+  Just op -> op
+  Nothing -> pprPanic "tyConToOp" (ppr tc)
+  where
+    opTable =
+      [ (typeNatAddTyCon, Add)
+      , (typeNatSubTyCon, Sub)
+      , (typeNatMulTyCon, Mul)
+      , (typeNatExpTyCon, Exp)
+      ]
+
+-- | Convert a type into an 'Expr'.  Only type variables and numeric literals
+-- are accepted (variables are tried first); anything else panics.
+sopToExpr :: Type -> Expr
+sopToExpr ty = case getTyVar_maybe ty of
+  Just tv -> Var tv
+  Nothing -> case isNumLitTy ty of
+    Just n  -> Lit n
+    Nothing -> pprPanic "typeToTerm" (ppr ty)
+
+-- | Repeatedly merge the head of the list into the remaining elements.
+--
+-- Every remaining element @y@ is combined with the head @x@ as @y `op` x@.
+-- If no combination yields 'Left', the head is kept and merging continues on
+-- the original tail.  Otherwise the head is dropped and merging restarts on
+-- the 'Left' results followed by the 'Right' results.
+mergeWith :: (a -> a -> Either a a) -> [a] -> [a]
+mergeWith _  []         = []
+mergeWith op (x : rest)
+  | null merged = x : mergeWith op rest
+  | otherwise   = mergeWith op (merged ++ kept)
+  where
+    (merged, kept) = partitionEithers [ y `op` x | y <- rest ]
+
+-- | A symbol is simple when it is a literal, a variable, or an exponentiation
+-- whose base consists of a single symbol.
+isSimple :: Symbol -> Bool
+isSimple sym = case sym of
+  I _             -> True
+  V _             -> True
+  E (S [P [_]]) _ -> True
+  _               -> False
+
+-- | Simplify 'complex' symbols.
+--
+-- The clauses are tried top to bottom, so @0^0@ reduces to @1@ (first
+-- clause) rather than @0@.  Symbols matching none of the rewrite rules are
+-- returned unchanged.
+reduceSymbol :: Symbol -> Symbol
+reduceSymbol (E _ (P [(I 0)])) = I 1 -- x^0 ==> 1
+reduceSymbol (E (S [P [I 0]]) _ ) = I 0 -- 0^x ==> 0
+-- NOTE(review): Prelude's (^) rejects negative exponents; this presumably
+-- relies on @j@ never being negative -- confirm.
+reduceSymbol (E (S [P [(I i)]]) (P [(I j)])) = I (i ^ j) -- 2^3 ==> 8
+
+-- (k ^ i) ^ j ==> k ^ (i * j)
+-- The symbols of both exponents are concatenated, merged with 'mergeS',
+-- reduced and sorted to form the new exponent product.
+reduceSymbol (E (S [P [(E k i)]]) j ) = E k (P . sort . map reduceSymbol
+ $ mergeWith mergeS (unP i ++ unP j))
+
+reduceSymbol s = s
+
+-- | Merge two symbols of a Product term.
+--
+-- Returns @Left s@ when the two symbols combine into the single symbol @s@,
+-- and @Right l@ (the first argument, unchanged) when they do not.  The
+-- clauses overlap and are tried in order: literal folding first, then the
+-- unit and zero rules, then the exponent rules.
+mergeS :: Symbol -> Symbol -> Either Symbol Symbol
+mergeS (I i) (I j) = Left (I (i * j)) -- 8 * 7 ==> 56
+mergeS (I 1) r = Left r -- 1 * x ==> x
+mergeS l (I 1) = Left l -- x * 1 ==> x
+mergeS (I 0) _ = Left (I 0) -- 0 * x ==> 0
+mergeS _ (I 0) = Left (I 0) -- x * 0 ==> 0
+
+-- x * x^4 ==> x^5
+mergeS s (E (S [P [s']]) (P [I i]))
+ | s == s'
+ = Left (E (S [P [s']]) (P [I (i + 1)]))
+
+-- x^4 * x ==> x^5
+mergeS (E (S [P [s']]) (P [I i])) s
+ | s == s'
+ = Left (E (S [P [s']]) (P [I (i + 1)]))
+
+-- y*y ==> y^2
+-- Only applies to symbols accepted by 'isSimple'.
+mergeS l r
+ | l == r && isSimple l
+ = Left (E (S [P [l]]) (P [I 2]))
+
+-- No rule applies: keep the first symbol as it is.
+mergeS l _ = Right l
+
+-- | Merge two products of a SOP term.
+--
+-- Two products merge when they are equal apart from a leading integer
+-- coefficient; the merged product carries the sum of the coefficients, with
+-- a missing coefficient counting as 1.  Returns @Left merged@ on success and
+-- @Right@ of the first product otherwise.  Coefficients are only recognised
+-- at the head of the symbol list.
+mergeP :: Product -> Product -> Either Product Product
+-- 2xy + 3xy ==> 5xy
+mergeP (P ((I i):is)) (P ((I j):js))
+ | is == js = Left . P $ (I (i + j)) : is
+-- 2xy + xy ==> 3xy
+mergeP (P ((I i):is)) (P js)
+ | is == js = Left . P $ (I (i + 1)) : is
+-- xy + 2xy ==> 3xy
+mergeP (P is) (P ((I j):js))
+ | is == js = Left . P $ (I (j + 1)) : is
+-- xy + xy ==> 2xy
+mergeP (P is) (P js)
+ | is == js = Left . P $ (I 2) : is
+ | otherwise = Right $ P is
+
+-- | Expand or simplify the exponentiation @b ^ e@ of two SOP terms.
+--
+-- Literal exponents are expanded by repeated multiplication, a sum in the
+-- exponent is split into a product of powers, and everything else is kept
+-- as a single (reduced) exponentiation symbol.
+--
+-- The clauses are tried in order.  Exponents that are an empty sum, a
+-- literal zero or a single product with several symbols are handled
+-- explicitly so that no case reaches 'foldr1' with an empty list or recurses
+-- on its own arguments.
+expandExp :: SOP -> SOP -> SOP
+-- b^0 ==> 1, where zero is the empty sum ('simplifySOP' drops zero products)
+expandExp _ (S []) = S [P [I 1]]
+
+-- b^1 ==> b
+expandExp b (S [P [I 1]]) = b
+
+-- x^y ==> x^y
+expandExp b@(S [P [_]]) (S [e@(P (_:_))]) = S [P [reduceSymbol (E b e)]]
+
+-- (x + 2)^0 ==> 1
+-- (x + 2)^2 ==> x^2 + 4x + 4
+-- A negative literal falls through to the symbolic clause below.
+expandExp b (S [P [I i]])
+  | i == 0 = S [P [I 1]]
+  | i > 0  = foldr1 mergeSOPMul (replicate (fromInteger i) b)
+
+-- (x + 2)^x    ==> (x + 2)^x
+-- (x + y)^(kz) ==> (x + y)^(kz)
+expandExp b (S [e]) = S [P [reduceSymbol (E b e)]]
+
+-- (x + 2)^(x + 2) ==> (x + 2)^x * (x^2 + 4x + 4)
+-- Only reached with two or more products in the exponent.
+expandExp b (S e) = foldr1 mergeSOPMul (map (expandExp b . S . (:[])) e)
+
+-- | Convert an expression into sum-of-products form.  Subtraction is encoded
+-- as addition of the right operand multiplied by @-1@.
+toSOP :: Expr -> SOP
+toSOP expr = case expr of
+  Lit i       -> S [P [I i]]
+  Var s       -> S [P [V s]]
+  Op op e1 e2 ->
+    let l = toSOP e1
+        r = toSOP e2
+    in case op of
+         Add -> mergeSOPAdd l r
+         Sub -> mergeSOPAdd l (mergeSOPMul (S [P [I (-1)]]) r)
+         Mul -> mergeSOPMul l r
+         Exp -> expandExp l r
+
+-- | Does the product start with the literal zero?
+zeroP :: Product -> Bool
+zeroP (P syms) = case syms of
+  I 0 : _ -> True
+  _       -> False
+
+-- | Normalise a SOP term: merge, reduce and sort the symbols inside every
+-- product, merge the products, drop products that start with zero and sort
+-- what remains.
+simplifySOP :: SOP -> SOP
+simplifySOP (S prods) = S (sort nonZero)
+  where
+    normalise (P syms) = P (sort (map reduceSymbol (mergeWith mergeS syms)))
+    merged             = mergeWith mergeP (map normalise prods)
+    nonZero            = filter (not . zeroP) merged
+
+-- | Add two SOP terms: concatenate their products and simplify.
+mergeSOPAdd :: SOP -> SOP -> SOP
+mergeSOPAdd l r = simplifySOP (S (unS l ++ unS r))
+
+-- | Multiply two SOP terms: pair every product of the first term with every
+-- product of the second, concatenate their symbols, and simplify.
+mergeSOPMul :: SOP -> SOP -> SOP
+mergeSOPMul (S sop1) (S sop2) = simplifySOP (S crossProducts)
+  where
+    crossProducts = [ P (unP p1 ++ unP p2) | p2 <- sop2, p1 <- sop1 ]
+
+-- | Decide a goal equation under a list of assumed equations.
+--
+-- After substituting the variable assumptions into the goal: if both sides
+-- mention the same free variables, the answer is whether their SOP forms are
+-- equal.  Otherwise the answer is the first decision that 'sharedFactor'
+-- reaches for a pair of assumptions defining the same variable, or 'Nothing'
+-- when there is no such decision.
+solveSOP :: [(Expr,Expr)] -> (Expr,Expr) -> Result
+solveSOP assmps goal
+  | eqFV goalL goalR = Just (fst goalSOP == snd goalSOP)
+  | otherwise        = firstDecided
+  where
+    -- Assumptions that have a variable as either the LHS or RHS
+    -- Transforms: (E1,E2) => (TV,E)
+    varAssmps = simpleEqs assmps
+    -- Run substitutions with `varAssmps` on `varAssmps`
+    -- Transforms: [(X1,A + B),(X2,X1 + C)] => [(X1,A + B),(X2,(A + B) + C)]
+    assmpsSub = [ (tv, substExprs varAssmps e) | (tv, e) <- varAssmps ]
+    -- Substitute assumptions in both sides of the goal
+    goalL = substExprs assmpsSub (fst goal)
+    goalR = substExprs assmpsSub (snd goal)
+    -- Put the LHS and RHS of the goal in SOP form
+    goalSOP = (toSOP goalL, toSOP goalR)
+
+    -- Connect assumptions that have the same variable LHS
+    -- Transforms: [(X1,A + B),(X1,2*C),(X2,8*D)] => [(A+B,2*C)]
+    matchingAssmpsSOP = [ (toSOP l, toSOP r) | (l, r) <- eqVarEqs assmpsSub ]
+    -- Determine if goal only differs by a common factor with one of the
+    -- assumptions; the first decision wins
+    firstDecided =
+      foldr pick Nothing [ sharedFactor a goalSOP | a <- matchingAssmpsSOP ]
+
+    pick (Just b) _    = Just b
+    pick Nothing  next = next
+
+-- | Keep the equations that have a variable on one side, each oriented as
+-- @(variable, other side)@.
+simpleEqs :: [(Expr,Expr)] -> [(TyVar,Expr)]
+simpleEqs eqs = [ eq | Just eq <- map simpleEq eqs ]
+
+-- | Orient an equation as @(variable, expression)@ if one side is a variable;
+-- a variable on the left takes precedence.
+simpleEq :: (Expr,Expr) -> Maybe (TyVar,Expr)
+simpleEq eq = case eq of
+  (Var v, e) -> Just (v, e)
+  (e, Var v) -> Just (v, e)
+  _          -> Nothing
+
+-- | Apply a list of substitutions to an expression, one after another, from
+-- left to right.
+substExprs :: [(TyVar,Expr)] -> Expr -> Expr
+substExprs es e = foldl (\acc s -> substExpr s acc) e es
+
+-- | Replace every occurrence of the variable by the paired expression.
+substExpr :: (TyVar,Expr) -> Expr -> Expr
+substExpr s expr = case expr of
+  Lit i      -> Lit i
+  Var tv'
+    | fst s == tv' -> snd s
+    | otherwise    -> Var tv'
+  Op o e1 e2 -> Op o (substExpr s e1) (substExpr s e2)
+
+-- | The set of variables occurring in an expression.
+fvExpr :: Expr -> UniqSet TyVar
+fvExpr expr = case expr of
+  Lit _      -> emptyUniqSet
+  Var v      -> unitUniqSet v
+  Op _ e1 e2 -> unionUniqSets (fvExpr e1) (fvExpr e2)
+
+-- | Do two expressions mention exactly the same variables?
+eqFV :: Expr -> Expr -> Bool
+eqFV e1 e2 = fvExpr e1 == fvExpr e2
+
+-- | Pair up the right-hand sides of all definitions of the same variable:
+-- each entry is paired with every later entry for the same variable.
+eqVarEqs :: [(TyVar,Expr)] -> [(Expr,Expr)]
+eqVarEqs []              = []
+eqVarEqs ((tv, e1) : es) =
+  [ (e1, e2) | (tv', e2) <- es, tv' == tv ] ++ eqVarEqs es
+
+type Factor = Product
+
+-- | Divide two 'Product' terms; returns 'Nothing' when the result would be a
+-- fractional term.
+--
+-- Equal products divide to @1@.  A divisor whose leading coefficient is the
+-- literal zero has no quotient ('Nothing'), which also keeps 'divMod' from
+-- being called with a zero divisor.  Otherwise leading integer coefficients
+-- must divide exactly and the divisor's symbols are removed from the
+-- dividend's with '\\'; at least one symbol has to be removed for the
+-- division to count.
+divProduct :: Product -> Product -> Maybe Factor
+divProduct as bs | as == bs = Just (P [I 1])
+
+-- Division by a zero coefficient is undefined.
+divProduct _ (P (I 0 : _)) = Nothing
+
+-- Divisor is a bare integer: only the coefficient changes.
+divProduct (P (I a : as)) (P [I b])
+  = case a `divMod` b of
+      (z,0) -> Just (P (I z : as))
+      _     -> Nothing
+
+-- Both sides carry a coefficient.
+divProduct (P (I a : as)) (P (I b : bs))
+  = case (a `divMod` b, as \\ bs) of
+      ((z,0),zs) | length zs < length as -> Just (P (I z : zs))
+      _                                  -> Nothing
+
+-- No coefficient on the dividend.
+divProduct (P as) (P bs)
+  = case as \\ bs of
+      zs | length zs < length as -> Just (P zs)
+      _                          -> Nothing
+
+-- | Divide two SOP terms product by product, returns 'Nothing' when:
+-- * The two SOP terms are of unequal length
+-- * A product division returns nothing
+-- * The factors between two subsequent product divisions differ
+--
+-- The first argument is the factor found so far; the empty product @P []@
+-- stands for "no factor yet" and is accepted against any factor.
+divSOP :: Product -> SOP -> SOP -> Maybe Factor
+divSOP p (S []) (S []) = return p
+divSOP _ (S []) _ = Nothing
+divSOP _ _ (S []) = Nothing
+
+divSOP p (S (a:as)) (S (b:bs))
+ -- The smaller product (by the derived 'Ord') is the dividend.
+ = do z <- if a < b then divProduct a b
+ else divProduct b a
+ if eqProduct p z
+ then divSOP z (S as) (S bs)
+ else Nothing
+ where
+ eqProduct (P []) _ = True
+ eqProduct a b = a == b
+-- | Strip from each SOP term the products that also occur in the other one.
+removeShift :: SOP -> SOP -> (SOP,SOP)
+removeShift l r = (S (unS l \\ unS r), S (unS r \\ unS l))
+-- | For an assumption and a goal equality, compute the factor (if any)
+-- between the two left-hand sides and between the two right-hand sides.
+-- Products common to both sides of an equality are removed first.
+factorOF :: (SOP,SOP) -> (SOP,SOP) -> (Maybe Factor,Maybe Factor)
+factorOF (aL, aR) (gL, gR) = (factor aL' gL', factor aR' gR')
+  where
+    factor     = divSOP (P [])
+    (aL', aR') = removeShift aL aR
+    (gL', gR') = removeShift gL gR
+
+-- | Decide whether an assumption and a goal equality differ only by a common
+-- factor.  The goal is tried as given and, failing that, with its sides
+-- swapped.  Returns:
+-- * Just True, if both sides share the same factor
+-- * Just False, if the factor of the LHS's differs from that of the RHS's
+-- * Nothing, if no factor can be found for either orientation
+sharedFactor :: (SOP,SOP) -> (SOP,SOP) -> Maybe Bool
+sharedFactor asmp (l, r) =
+  case factorsAgree (l, r) of
+    Nothing -> factorsAgree (r, l)
+    decided -> decided
+  where
+    factorsAgree g = case factorOF asmp g of
+      (Just fL, Just fR) -> Just (fL == fR)
+      _                  -> Nothing
diff --git a/compiler/types/CoAxiom.lhs b/compiler/types/CoAxiom.lhs
index a0a4974..743dae0 100644
--- a/compiler/types/CoAxiom.lhs
+++ b/compiler/types/CoAxiom.lhs
@@ -528,7 +528,7 @@ data BuiltInSynFamily = BuiltInSynFamily
{ sfMatchFam :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
, sfInteractTop :: [Type] -> Type -> [Eqn]
, sfInteractInert :: [Type] -> Type ->
- [Type] -> Type -> [Eqn]
+ [Type] -> Type -> [(CoAxiomRule, Eqn)]
}
-- Provides default implementations that do nothing.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment