-- Playing with sums of products
-- (GitHub gist plaidfinch/c132590f78627f76144f, by @plaidfinch)
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
module SOP where
import Data.Either
import Data.Proxy
import Data.Void
import Control.Arrow
import Data.Type.Equality
import Data.Constraint
infixr 5 :*:
infixr 3 *
infixr 2 +
infixr 3 |*|
--------------------------------------
-- Sums of products (SOPs) as GADTs --
--------------------------------------
-- Lots of this formulation is borrowed from Andres Löh's WGP 2014 paper, "True Sums of Products,"
-- and its implementation on Hackage, Generics.SOP. However, I wanted to take certain liberties with
-- notation and presentation, so I've replicated the parts I care about here.
-- | A sum of products: an n-ary 'Sum' in which every alternative is an n-ary
-- 'Product' of @f@-wrapped fields. The outer list of @alts@ enumerates the
-- alternatives; each inner list enumerates the fields of one alternative.
type SOP (f :: k -> *) (alts :: [[k]]) = Sum (Product f) alts
-- | An n-ary product of @f@-modality things, indexed by the type-level list of
-- the things it holds.
data Product (f :: k -> *) (ts :: [k]) where
  -- | The empty product.
  Nil   :: Product f '[]
  -- | Put one more @f@-wrapped field in front of an existing product.
  (:*:) :: f t -> Product f ts -> Product f (t ': ts)
-- | An n-ary sum of @f@-modality things, indexed by the type-level list of the
-- alternatives it may hold. A value picks exactly one alternative by its
-- position in the list.
data Sum (f :: k -> *) (ts :: [k]) where
  -- | Choose the first alternative.
  This :: f t -> Sum f (t ': ts)
  -- | Skip the first alternative and choose among the remaining ones.
  That :: Sum f ts -> Sum f (t ': ts)
---------------------------------------------
-- Some useful functors you've seen before --
---------------------------------------------
-- | The constant functor (cf. "Data.Functor.Constant"): stores a @c@ and
-- ignores its phantom index @x@.
newtype K c x = K
  { getK :: c
  }
-- | The identity functor (cf. "Data.Functor.Identity").
newtype I x = I
  { getI :: x
  }
-- | Functor composition (cf. "Data.Functor.Compose"), but with 100% more
-- PolyKinds.
--
-- NOTE: the visible source declares no fixity for @∘@, so chains such as
-- @Dict ∘ c ∘ f@ associate to the left (the Haskell default, infixl 9).
newtype (outer ∘ inner) x = O
  { getO :: outer (inner x)
  }
----------------------------------------
-- Some tools for constructing proofs --
----------------------------------------
-- These are nicer than constantly writing out "case x of Refl -> case y of Refl -> z"
-- Now, instead, you can do things like "x ==> y ==> z"

-- A chain has to nest to the right, i.e. x ==> (y ==> z), so that every
-- equality is in scope for z. Without a fixity declaration the Haskell default
-- (infixl 9) parses the chain as (x ==> y) ==> z, which leaves only y's
-- equality available to z.
infixr 0 ==>

-- | Given a proof of equality, discharge that proof obligation for a value
-- which needs it. Matching on 'Refl' forces the proof and brings @x ~ y@ into
-- scope for the second argument.
(==>) :: (x :~: y) -> ((x ~ y) => z) -> z
Refl ==> z = z
-- A chain "d1 ~~> d2 ~~> z" has to nest to the right, i.e. d1 ~~> (d2 ~~> z),
-- so that every constraint is in scope for z. Without a fixity declaration the
-- Haskell default (infixl 9) parses it as (d1 ~~> d2) ~~> z, which leaves only
-- d2's constraint available to z.
infixr 0 ~~>

-- | Given a proof of a constraint, discharge that proof obligation for a value
-- which needs it. Matching on 'Dict' forces the dictionary and brings @c@ into
-- scope for the second argument.
(~~>) :: Dict c -> (c => z) -> z
Dict ~~> z = z
-------------------------------------------------
-- Some simple functions and proofs about SOPs --
-------------------------------------------------
-- | Take a non-empty product apart into its first field and the product of
-- the remaining fields.
uncons :: Product f (x ': xs) -> (f x, Product f xs)
uncons = \case
  hd :*: tl -> (hd, tl)
-- | Inspect a sum over a non-empty list: 'Left' when it holds the first
-- alternative, 'Right' with the sum over the remaining alternatives otherwise.
check :: Sum f (x ': xs) -> Either (f x) (Sum f xs)
check = \case
  This here  -> Left here
  That there -> Right there
-- | Extract the payload of a sum that has exactly one alternative.
onlyOption :: Sum f '[a] -> f a
onlyOption s = case check s of
  Left payload     -> payload
  -- The remainder is a sum over '[], which 'noSumEmpty' refutes.
  Right impossible -> absurd (noSumEmpty impossible)
-- | Extract the single field of a product that has exactly one field.
onlyField :: Product f '[a] -> f a
onlyField (field :*: _) = field
-- | A proof that no empty sum can be constructed: neither 'This' nor 'That'
-- can produce a sum indexed by @'[]@, so the case analysis has no alternatives.
noSumEmpty :: Sum f '[] -> Void
noSumEmpty impossible = case impossible of {}
-- | The positive counterpart of 'noSumEmpty'. In constructive logic "no sum is
-- empty" is weaker than "every sum is non-empty", so this states the latter
-- directly: given any sum, its index list is a cons of its own 'Head' and
-- 'Tail'. Either constructor refines @xs@ to a cons cell, making 'Refl' typecheck.
anySumNonEmpty :: Sum f xs -> xs :~: (Head xs ': Tail xs)
anySumNonEmpty (This _) = Refl
anySumNonEmpty (That _) = Refl
-- | Type-level head of a list. There is no equation for @'[]@.
type family Head (list :: [k]) :: k where
  Head (first ': rest) = first

-- | Type-level tail of a list. There is no equation for @'[]@.
type family Tail (list :: [k]) :: [k] where
  Tail (first ': rest) = rest
-- | Type-level append for lists, by recursion on the left operand.
type family xs ++ ys where
  '[]         ++ ys = ys
  (x ': rest) ++ ys = x ': (rest ++ ys)
-- | Addition of sums of products: list append ('++') restricted to lists of
-- lists, for use when talking about SOP index lists.
type (xss :: [[k]]) + (yss :: [[k]]) = xss ++ yss
-------------------------------------------------------------------
-- Multiplication and addition of sums and products individually --
-------------------------------------------------------------------
-- | Multiply (concatenate) a pair of n-ary products: the fields of the first
-- component are followed by the fields of the second.
-- Dual in one sense to 'plus'; dual in another sense to 'splitProduct'.
times :: (Product f xs, Product f ys) -> Product f (xs ++ ys)
times (front, back) = case front of
  Nil          -> back
  x :*: front' -> x :*: times (front', back)
-- | Add two n-ary sums: inject a value of either summand into the sum over the
-- appended index lists. A 'Left' value keeps its position ('appendSum'); a
-- 'Right' value is shifted past all the alternatives of @xs@ ('prependSum'),
-- which is why the shape of @xs@ must be known.
-- Dual in one sense to 'times'; dual in another sense to 'splitSum'.
plus :: forall f xs ys. (KnownProduct xs) => Either (Sum f xs) (Sum f ys) -> Sum f (xs ++ ys)
plus = \case
  Left  l -> appendSum l (Proxy :: Proxy ys)
  Right r -> prependSum (productProxy :: Product Proxy xs) r
-- | Inject an n-ary sum into a larger n-ary sum by appending extra types onto
-- the end of its index list. The chosen alternative keeps its position; the
-- proxy only names the appended list @ys@.
appendSum :: forall f xs ys. Sum f xs -> Proxy ys -> Sum f (xs ++ ys)
appendSum (This x)    Proxy   = This x
appendSum (That rest) p@Proxy = That (appendSum rest p)
-- | Inject an n-ary sum into a larger n-ary sum by prepending extra types onto
-- the front of its index list. One 'That' is added for every entry of the
-- proxy product, which supplies the shape of the prepended list @xs@.
prependSum :: Product Proxy xs -> Sum f ys -> Sum f (xs ++ ys)
prependSum = \case
  Nil        -> id
  _ :*: rest -> \s -> That (prependSum rest s)
-- | Type-level lists whose shape can be reflected as a 'Product' of 'Proxy's.
--
-- In an ideal world we could show that every @(xs :: [k])@ has a
-- @Product Proxy xs@ and so never need to carry 'KnownProduct' around in
-- contexts. As far as the author knows, this is not an ideal world.
class KnownProduct (xs :: [k]) where
  -- | One 'Proxy' per element of @xs@.
  productProxy :: Product Proxy xs

-- | The empty list is reflected by the empty product.
instance KnownProduct '[] where
  productProxy = Nil

-- | A cons cell is reflected by one more 'Proxy' in front of the tail's shape.
instance KnownProduct xs => KnownProduct (x ': xs) where
  productProxy = Proxy :*: productProxy
----------------------------------------------------------------------------------
-- A bunch of constraints programming leading up to an elegant 2-D KnownProduct --
----------------------------------------------------------------------------------
-- | The unary constraint @c@ holds for every element of the list: the empty
-- list needs nothing, a cons cell needs @c@ of its head plus 'All' of its tail.
type family All (c :: k -> Constraint) (ts :: [k]) :: Constraint where
  All c '[]         = ()
  All c (t ': rest) = (c t, All c rest)
-- | @map@, lifted to the type level: apply @f@ to every element of the list.
type family Map (f :: k -> l) (ts :: [k]) :: [l] where
  Map f '[]         = '[]
  Map f (t ': rest) = f t ': Map f rest
-- | If @(c x)@ holds for every @x@ in @xs@, produce one dictionary per element
-- as evidence. The proxy product supplies the shape of @xs@ to recurse over;
-- the 'Proxy' argument only names the constraint.
dicts :: All c xs => Proxy c -> Product Proxy xs -> Product (Dict ∘ c) xs
dicts cls shape = case shape of
  Nil        -> Nil
  _ :*: rest -> O Dict :*: dicts cls rest
-- | If @(c (f x))@ holds for every @x@ in @xs@, produce one dictionary per
-- element as evidence. The proxy product supplies the shape of @xs@; the two
-- 'Proxy' arguments only name the constraint and the type constructor.
dictsF :: All c (Map f xs) => Proxy c -> Proxy f -> Product Proxy xs -> Product (Dict ∘ c ∘ f) xs
dictsF cls fun shape = case shape of
  Nil        -> Nil
  _ :*: rest -> O (O Dict) :*: dictsF cls fun rest
-- | Collapse per-element dictionaries showing @(c x)@ for each @x@ in @xs@
-- into a single dictionary showing @All c xs@.
collectDicts :: Product (Dict ∘ c) xs -> Dict (All c xs)
collectDicts = \case
  Nil -> Dict
  -- Matching both dictionaries brings (c x) and (All c rest) into scope,
  -- which together are exactly All c (x ': rest).
  O Dict :*: rest -> case collectDicts rest of
    Dict -> Dict
-- | Collapse per-element dictionaries showing @(c (f x))@ for each @x@ in @xs@
-- into a single dictionary showing @All c (Map f xs)@.
collectDictsF :: Product (Dict ∘ c ∘ f) xs -> Dict (All c (Map f xs))
collectDictsF = \case
  Nil -> Dict
  -- Matching both dictionaries brings (c (f x)) and (All c (Map f rest)) into
  -- scope, which together are exactly All c (Map f (x ': rest)).
  O (O Dict) :*: rest -> case collectDictsF rest of
    Dict -> Dict
-- | Map a function that needs @(d x)@ and @(c (f x))@ for every @x@ in @xs@
-- over a product, given explicit per-element dictionaries as evidence that
-- those constraints hold. This is usually less useful than 'mapWithConstraints'.
--
-- The dictionaries are only inspected when the product is non-empty.
mapWithDicts :: forall c d f g xs. (forall x. (d x, c (f x)) => f x -> g x)
             -> Product (Dict ∘ c ∘ f) xs -> Product (Dict ∘ d) xs
             -> Product f xs -> Product g xs
mapWithDicts _ _ _ Nil = Nil
mapWithDicts t ds1 ds2 (x :*: xs') =
  case (ds1, ds2) of
    -- Opening both dictionaries puts the constraints t needs for x in scope.
    (O (O Dict) :*: ds1', O Dict :*: ds2') ->
      t x :*: mapWithDicts t ds1' ds2' xs'
-- | Map a function requiring certain constraints over a @(Product f xs)@.
-- The function may use @(c f)@, as well as @(d (f x))@ and @(e x)@ for every
-- @x@ in @xs@. The three 'Proxy' arguments only name those constraints; the
-- per-element evidence is built with 'dictsF' and 'dicts' from the reflected
-- shape of @xs@ and handed to 'mapWithDicts'.
mapWithConstraints :: forall c d e f g xs. (c f, All d (Map f xs), All e xs, KnownProduct xs)
                   => Proxy c -> Proxy d -> Proxy e
                   -> (forall x. (c f, d (f x), e x) => f x -> g x)
                   -> Product f xs -> Product g xs
mapWithConstraints _ dProxy eProxy step fields =
  let shape = productProxy :: Product Proxy xs
      evidenceD = dictsF dProxy (Proxy :: Proxy f) shape
      evidenceE = dicts eProxy shape
  in mapWithDicts step evidenceD evidenceE fields
-- | Generate a product of products of proxies in the shape and type of any
-- list of lists @xs@: the outer shape is reflected with 'productProxy' and
-- every outer 'Proxy' is then replaced by the reflected shape of its inner list.
productProxies :: forall xs. (All KnownProduct xs, KnownProduct xs) => Product (Product Proxy) xs
productProxies =
  -- mapWithConstraints demands All d (Map f xs); with d = Trivial and f = Proxy
  -- that obligation is supplied by allTrivialF.
  case allTrivialF (Proxy :: Proxy Proxy) (Proxy :: Proxy xs) of
    Dict ->
      mapWithConstraints
        (Proxy :: Proxy Trivial)
        (Proxy :: Proxy Trivial)
        (Proxy :: Proxy KnownProduct)
        (\_ -> productProxy)
        (productProxy :: Product Proxy xs)
--------------------------------------------------------
-- A detour into proving trivial things... literally! --
--------------------------------------------------------
-- This stuff is necessary for the 'productProxies' term listed above.
-- A lot of hard work to prove something really trivial, but I'm not sure how to make it nicer.
-- | The trivial class: no members and no superclass constraints, with a single
-- instance that covers every type of every kind.
class Trivial (a :: k)

instance Trivial a
-- Some useful proofs about the trivial constraint...
-- | Anything can be proven to satisfy the trivial constraint (the proof is
-- trivial too): this is the dictionary witnessing @Trivial a@ for any @a@.
anyTrivial :: Dict (Trivial a)
anyTrivial = Dict
-- | For a @(Product f xs)@, each @x@ can be shown to satisfy the trivial
-- constraint: one 'anyTrivial' dictionary per element of @xs@, built by
-- walking the reflected shape of @xs@. (The quantified @f@ does not occur in
-- the type.)
eachTrivial :: forall f xs. (KnownProduct xs) => Product (Dict ∘ Trivial) xs
eachTrivial = witness (productProxy :: Product Proxy xs)
  where
    witness :: Product Proxy ys -> Product (Dict ∘ Trivial) ys
    witness = \case
      Nil        -> Nil
      _ :*: rest -> O anyTrivial :*: witness rest
-- | For a @(Product f xs)@, each @(f x)@ can be shown to satisfy the trivial
-- constraint: one 'anyTrivial' dictionary per element of @xs@, built by
-- walking the reflected shape of @xs@.
eachTrivialF :: forall f xs. (KnownProduct xs) => Product (Dict ∘ Trivial ∘ f) xs
eachTrivialF = witness (productProxy :: Product Proxy xs)
  where
    witness :: Product Proxy ys -> Product (Dict ∘ Trivial ∘ f) ys
    witness = \case
      Nil        -> Nil
      _ :*: rest -> O (O anyTrivial) :*: witness rest
-- | For a @(Product f xs)@, all of its elements @x@ can be shown to satisfy
-- the trivial constraint. Both 'Proxy' arguments are ignored at run time; they
-- only fix the type variables.
allTrivial :: forall f xs. (KnownProduct xs) => Proxy f -> Proxy xs -> Dict (All Trivial xs)
allTrivial _ _ = collectDicts evidence
  where evidence = eachTrivial :: Product (Dict ∘ Trivial) xs
-- | For a @(Product f xs)@, all of its elements @(f x)@ can be shown to
-- satisfy the trivial constraint. Both 'Proxy' arguments are ignored at run
-- time; they only fix the type variables.
allTrivialF :: forall f xs. (KnownProduct xs) => Proxy f -> Proxy xs -> Dict (All Trivial (Map f xs))
allTrivialF _ _ = collectDictsF evidence
  where evidence = eachTrivialF :: Product (Dict ∘ Trivial ∘ f) xs
--------------------------------------
-- Finally, multiplication of SOPs! --
--------------------------------------
-- | Multiplication of n-ary sums-of-products using distributivity: every
-- product on the left is distributed over the whole right operand ('Dist'),
-- and the results are added together ('+').
type family (xss :: [[k]]) * (yss :: [[k]]) :: [[k]] where
  '[]          * yss = '[]
  (xs ': xss') * yss = Dist xs yss + (xss' * yss)
-- | 'Dist' distributes the multiplication of a single product over a
-- sum-of-products. Think of it as the n-ary generalization of
-- @x * (a + b) ==> (x * a) + (x * b)@: the product @xs@ is appended in front
-- of every alternative of the sum.
type family Dist (xs :: [k]) (yss :: [[k]]) :: [[k]] where
  Dist xs '[]           = '[]
  Dist xs (ys ': yss')  = (xs ++ ys) ': Dist xs yss'
-- | Value-level implementation of 'Dist': distribute the multiplication of a
-- single product over a sum-of-products. The chosen alternative keeps its
-- position and has the given product concatenated in front of it ('times').
dist :: Product f xs -> SOP f yss -> SOP f (Dist xs yss)
dist factor = \case
  This fields -> This (times (factor, fields))
  That later  -> That (dist factor later)
-- | @(|*|)@ is the value-level SOP-multiplication function by distributivity!
--
-- The helper walks down the left operand. At the chosen alternative it
-- distributes that product over the right operand ('dist') and pads the index
-- list on the right ('appendSum'); for every alternative it skips it pads on
-- the left ('prependSum') using the reflected shape of that alternative's
-- 'Dist' row, taken from 'productProxies' over @DistAll xs ys@.
(|*|) :: forall f x xs ys. (All KnownProduct (DistAll xs ys), KnownProduct (DistAll xs ys))
      => SOP f xs -> SOP f ys -> SOP f (xs * ys)
lhs |*| rhs =
  case anySumNonEmpty lhs of
    Refl -> walk lhs (productProxies :: Product (Product Proxy) (DistAll xs ys))
  where
    walk :: forall zs. SOP f zs -> Product (Product Proxy) (DistAll zs ys) -> SOP f (zs * ys)
    walk (This fields) _ =
      appendSum (dist fields rhs) (Proxy :: Proxy (Tail zs * ys))
    walk (That later) shapes =
      -- Kept as a lazy binding: the shapes are only inspected by prependSum.
      let (skip, shapes') = uncons shapes
      in case anySumNonEmpty later of
           Refl -> prependSum skip (walk later shapes')
-- | For each product in the left operand, distribute it over the right operand
-- ('Dist'), and list all of these results. This is the sequence of 'dist'
-- result types that arise in @(|*|)@; it needs a name so that singletons can
-- be instantiated for it and fed to 'prependSum'.
type family DistAll (xss :: [[k]]) (yss :: [[k]]) :: [[[k]]] where
  DistAll '[]          yss = '[]
  DistAll (xs ': xss') yss = Dist xs yss ': DistAll xss' yss
---------------------------------------------
-- Splitting products and sums apart again --
---------------------------------------------
-- | Split an n-ary product into a pair of n-ary products which, when appended,
-- yield the original product. This can be thought of as an inverse to 'times';
-- because 'times' is not injective, the split point is not determined by the
-- input alone, so it is chosen through the index @a@ of the first component.
class SplitProduct a b c | a b -> c where
  splitProduct :: Product f c -> (Product f a, Product f b)

-- | Nothing is requested for the first component: it is empty and the whole
-- input becomes the second component.
instance SplitProduct '[] b b where
  splitProduct whole = (Nil, whole)

-- | The first component starts with the input's first field; the remaining
-- fields are split recursively.
instance SplitProduct as bs cs => SplitProduct (a ': as) bs (a ': cs) where
  splitProduct :: forall f. Product f (a ': cs) -> (Product f (a ': as), Product f bs)
  splitProduct (x :*: rest) =
    -- Lazy pattern binding: the recursive split is not forced until one of
    -- its halves is demanded.
    let (front, back) = splitProduct rest :: (Product f as, Product f bs)
    in (x :*: front, back)
-- | Split an n-ary sum into either of two n-ary sums which, when added, yield
-- the original sum. This can be thought of as an inverse to 'plus'; because
-- 'plus' is not injective, the split point is not determined by the input
-- alone, so it is chosen through the index @a@ of the left summand.
class SplitSum a b c | a b -> c where
  splitSum :: Sum f c -> Either (Sum f a) (Sum f b)

-- | The left summand is empty, so every value belongs to the right summand.
instance SplitSum '[] b b where
  splitSum = Right

-- | A value at the first alternative belongs to the left summand; otherwise
-- split the remainder and re-wrap a left result with 'That'.
instance SplitSum as bs cs => SplitSum (a ': as) bs (a ': cs) where
  splitSum :: forall f. Sum f (a ': cs) -> Either (Sum f (a ': as)) (Sum f bs)
  splitSum (This x) = Left (This x)
  splitSum (That rest) =
    case (splitSum rest :: Either (Sum f as) (Sum f bs)) of
      Left  inLeft  -> Left (That inLeft)
      Right inRight -> Right inRight
-- Open questions: how to encode all three functional dependencies which we'd want to: a b -> c, a c -> b, b c -> a?
-- We can't actually say this literally, because it violates the functional dependency condition in a literal sense.
-- But we really want to say that *for any given* a, b, c, two of the three uniquely determine the third. How?
-- Alternatively, is there a nice encoding using type families? Which also preserves the kind of inference we want?
-- Constraints stating that an appended product/sum can be split back into its
-- two parts, i.e. that the split of the append can be the identity.
-- We'd really like to claim that these hold for all (xs :: [k]) and (ys :: [k])
-- (and this is true), but the author is not yet sure how to encode this
-- particular universal constraint quantification.

-- | A product indexed by @xs ++ ys@ splits into products indexed by @xs@ and @ys@.
type SplitAppendProduct xs ys = SplitProduct xs ys (xs ++ ys)

-- | A sum indexed by @xs ++ ys@ splits into sums indexed by @xs@ and @ys@.
type SplitAppendSum xs ys = SplitSum xs ys (xs ++ ys)
-- (GitHub page footer removed; it was not part of the module.)