{- Source: GitHub gist by @Lysxia
   (Lysxia/f84d066ff0b32e1f888e9f689f4c8426), last active September 8, 2024 12:58.
   Generic case analysis ("gcase") for any type with a GHC.Generics instance. -}
{-# LANGUAGE
AllowAmbiguousTypes,
DeriveGeneric,
EmptyCase,
TypeOperators,
FlexibleInstances,
FlexibleContexts,
MultiParamTypeClasses,
PolyKinds,
TypeApplications,
UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}
import GHC.Generics
-- | Generic case analysis: eliminate a value of any 'Generic' type by
-- supplying one continuation per constructor, in declaration order.
-- Each continuation receives that constructor's fields as curried
-- arguments; a nullary constructor's "continuation" is just a value of
-- the result type.
--
-- For example, @gcase (Just 3) 0 id@ is @3@ and @gcase (x, y) f@ is @f x y@.
--
-- The branch result type @r@ is not mentioned in @t -> b@ (hence
-- AllowAmbiguousTypes); it gets fixed while the instances expand @b@.
gcase :: forall t b r. (Generic t, GCaseSum (Rep t) r b r) => t -> b
gcase = flip (gCaseSum @(Rep t) @r @b @r) id . from
-- Two classes: GCaseSum to decompose sums, GCaseProd to decompose products.
-- Sums: f :+: g -> (f -> r) -> (g -> r) -> r
-- - to process f, we must consume (f -> r) and ignore (g -> r)
-- - to process g, we must ignore (f -> r) and consume (g -> r)
--
-- gCaseSum consumes the arguments belonging to a sum type g; gCaseSumSkip ignores them.
-- - gCaseSum is passed a continuation (r -> a) that receives the result r of the
--   selected constructor and produces a, which accounts for the arguments that come
--   after g (at the toplevel there are none, so a = r and the continuation is id).
-- - gCaseSumSkip ignores the arguments corresponding to g,
--   and uses the function contained in a to process the remaining arguments.
--
-- For example, for (Either x y) it will look like this:
-- - At the toplevel, we have (GCaseSum (Either x y) a b) with a = r and
--   b = (x -> r) -> (y -> r) -> r:
--     gCaseSum :: Either x y -> (r -> r) -> (x -> r) -> (y -> r) -> r
--   (gCaseSumSkip is not used at the toplevel.)
-- - When processing the Left side (GCaseSum x a b), with a = (y -> r) -> r
--   and b = (x -> r) -> (y -> r) -> r:
--     gCaseSum     :: x -> (r -> (y -> r) -> r) -> (x -> r) -> (y -> r) -> r
--     gCaseSumSkip :: ((y -> r) -> r) -> (x -> r) -> (y -> r) -> r
--   (gCaseSumSkip is used when the value is a Right: it drops (x -> r).)
-- - When processing the Right side (GCaseSum y a b), with a = r and
--   b = (y -> r) -> r:
--     gCaseSum     :: y -> (r -> r) -> (y -> r) -> r
--     gCaseSumSkip :: r -> (y -> r) -> r
--   (gCaseSumSkip is used when the value is a Left: it drops (y -> r).)
-- | Case analysis on a generic sum representation @g@.
--
-- * @a@: what is left once the continuations for @g@'s constructors have
--   been consumed (the continuations for later constructors, then the result).
-- * @b@: the function type taking the continuations for @g@'s constructors
--   and then behaving like @a@.
-- * @r@: the result type shared by every constructor continuation.
class GCaseSum g a b r where
  -- | Ignore every continuation belonging to @g@, then carry on with @a@.
  gCaseSumSkip :: a -> b
  -- | Dispatch on the value: the selected constructor's result is handed to
  -- the @r -> a@ continuation, and the other continuations of @g@ are skipped.
  gCaseSum :: forall p. g p -> (r -> a) -> b
-- | Binary sum: the continuations for @f@ come first, then those for @g@.
-- @b@ is the intermediate type left after @f@'s continuations are consumed.
instance (GCaseSum f b c r, GCaseSum g a b r) => GCaseSum (f :+: g) a c r where
  gCaseSum s = case s of
    -- Value is in f: handle f, and make the continuation drop g's arguments.
    L1 inner -> \k -> gCaseSum @f @b @c @r inner (gCaseSumSkip @g @a @b @r . k)
    -- Value is in g: drop f's arguments, then handle g.
    R1 inner -> \k -> gCaseSumSkip @f @b @c @r (gCaseSum @g @a @b @r inner k)
  -- Neither side is selected: drop the arguments of both f and g.
  gCaseSumSkip rest = gCaseSumSkip @f @b @c @r (gCaseSumSkip @g @a @b @r rest)
-- | Empty sum (e.g. Void): no constructors means no continuations, so the
-- input and output types coincide. This gives gcase :: Void -> r.
instance (a ~ b) => GCaseSum V1 a b r where
  -- A V1 value cannot be constructed; the empty case needs no branches.
  gCaseSum impossible = case impossible of {}
  gCaseSumSkip rest = rest
-- | Datatype metadata wrapper at the top of the representation:
-- unwrap it and delegate to the sum inside.
instance GCaseSum f a b r => GCaseSum (M1 D w f) a b r where
  gCaseSumSkip = gCaseSumSkip @f @a @b @r
  gCaseSum = gCaseSum @f @a @b @r . unM1
-- | A single constructor (leaf of the sum). Its continuation, of type @c@,
-- receives the constructor's fields via 'GCaseProd', and the outcome is
-- passed to @k@. Skipping this constructor simply discards its continuation.
instance (ca ~ (c -> a), GCaseProd f c r) => GCaseSum (M1 C w f) a ca r where
  gCaseSum (M1 fields) k = k . gCaseProd fields
  gCaseSumSkip = const
-- Products: (x, y) -> (x -> y -> r) -> r
--
-- - Toplevel instance (GCaseProd (x, y) c d), with c = x -> y -> r and d = r:
--     gCaseProd :: (x, y) -> (x -> y -> r) -> r
-- - When processing x (instance GCaseProd x c d), with c = x -> y -> r and d = y -> r:
--     gCaseProd :: x -> (x -> y -> r) -> y -> r
-- - When processing y (instance GCaseProd y c d), with c = y -> r and d = r:
--     gCaseProd :: y -> (y -> r) -> r
--
-- In both of the leaf cases (instance for M1 S), gCaseProd = \a k -> k a (= flip ($)).
-- The instance for (:*:) combines them using (.).
-- | Feed the fields of a generic product @g@, left to right, to a curried
-- continuation of type @c@. @d@ is what remains once @g@'s fields have been
-- supplied (the branch result when @g@ is a constructor's whole field list).
class GCaseProd g c d where
  gCaseProd :: forall p. g p -> c -> d
-- | Product of two field groups: the left group turns @c@ into the
-- intermediate @d@, which the right group then turns into @e@.
instance (GCaseProd f c d, GCaseProd g d e) => GCaseProd (f :*: g) c e where
  gCaseProd (lhs :*: rhs) = \k -> gCaseProd @g @d @e rhs (gCaseProd @f @c @d lhs k)
-- | A single field: apply the continuation to the field's value.
instance (ad ~ (a -> d)) => GCaseProd (M1 S w (K1 i a)) ad d where
  gCaseProd (M1 (K1 field)) = ($ field)
-- | Empty product (a constructor with no fields, such as @()@ or 'Nothing'):
-- there is nothing to feed, so the "continuation" is already the result.
-- Hence @gcase :: () -> r -> r@, and in
-- @gcase :: Maybe a -> r -> (a -> r) -> r@ the 'Nothing' branch is the
-- plain @r@ argument.
instance (r ~ r') => GCaseProd U1 r r' where
  gCaseProd unit result = case unit of U1 -> result
-- | Assert that two values are equal. Returns @pure ()@ when they are;
-- otherwise calls 'error' with the message @show x ++ " /= " ++ show y@.
(=?) :: (Eq a, Show a) => a -> a -> IO ()
actual =? expected
  | actual == expected = pure ()
  | otherwise = error (show actual ++ " /= " ++ show expected)

-- Same fixity as (==). Without this declaration the operator defaults to
-- infixl 9, so @a + b =? c@ would parse as @a + (b =? c)@ and fail to typecheck.
infix 4 =?
-- | Example type with three constructors, carrying one, two and three fields.
data Three a
  = Uno a
  | Dos a a
  | Tres a a a
  deriving Generic
-- | Run the examples in order. Prints nothing when every check passes;
-- the first mismatch aborts via the error raised by (=?).
main :: IO ()
main =
  sequence_
    [ gcase (Left (3 :: Int)) id id =? 3
    , gcase (Right (3 :: Int)) id id =? 3
    , gcase (Just (3 :: Int)) 0 id =? 3
    , gcase (40, 2) (+) =? (42 :: Int)
    , gcase () 42 =? (42 :: Int)
    , gcase (Tres 100 20 3) id (+) (\x y z -> x + y + z) =? (123 :: Int)
    ]
-- (GitHub gist page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment