-- Generic Scott encoding of algebraic data types.
-- Gist Lysxia/f84d066ff0b32e1f888e9f689f4c8426 (last active September 8, 2024 12:58).
-- Context: https://old.reddit.com/r/haskell/comments/1fbl003/update_a_function_to_replace_all_case_expressions/
-- Note (from GitHub): this file may contain bidirectional Unicode text that can be
-- interpreted or compiled differently than it appears; review it in an editor that
-- reveals hidden Unicode characters.
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}

import GHC.Generics
-- Generic Scott-style eliminator: `gcase v h1 h2 ... hn` applies the handler
-- of the constructor that built v to that constructor's fields. Handlers are
-- taken in constructor declaration order; a constructor with k fields takes a
-- k-argument handler, and a nullary constructor takes a plain value of type r.
-- The outermost continuation returns the case result unchanged.
gcase :: forall t b r. (Generic t, GCaseSum (Rep t) r b r) => t -> b
gcase value = gCaseSum @(Rep t) @r @b @r rep finish
  where
    rep = from value
    finish result = result
-- Two classes: GCaseSum to decompose sums, GCaseProd to decompose products.
--
-- Sums: f :+: g -> (f -> r) -> (g -> r) -> r
--   - to process f, we must consume (f -> r) and ignore (g -> r)
--   - to process g, we must ignore (f -> r) and consume (g -> r)
--
-- In GCaseSum g a b r, r is the result type shared by all handlers, a is the
-- type of whatever follows g's handlers (the later handlers, ending in r),
-- and b is the type of g's own handlers followed by a.
--
-- gCaseSum consumes the arguments that belong to a sum type g, and
-- gCaseSumSkip ignores them:
--   - gCaseSum gets passed a continuation (r -> a) that expects the final
--     result r after processing g; the continuation ignores the extra
--     arguments (the handlers after g) and returns r.
--   - gCaseSumSkip ignores the arguments corresponding to g and returns the
--     value of type a, which processes the remaining arguments.
--
-- For example, for (Either x y) the types are:
--   - At the toplevel, GCaseSum (Either x y) a b with a = r and
--     b = (x -> r) -> (y -> r) -> r:
--       gCaseSum :: Either x y -> (r -> r) -> (x -> r) -> (y -> r) -> r
--     (gCaseSumSkip is not used here.)
--   - When processing the Left side, GCaseSum x a b with a = (y -> r) -> r
--     and b = (x -> r) -> (y -> r) -> r:
--       gCaseSum     :: x -> (r -> (y -> r) -> r) -> (x -> r) -> (y -> r) -> r
--       gCaseSumSkip :: ((y -> r) -> r) -> (x -> r) -> (y -> r) -> r
--     (gCaseSumSkip is used on a Right value, to ignore the x handler.)
--   - When processing the Right side, GCaseSum y a b with a = r and
--     b = (y -> r) -> r:
--       gCaseSum     :: y -> (r -> r) -> (y -> r) -> r
--       gCaseSumSkip :: r -> (y -> r) -> r
--     (gCaseSumSkip is used here too: on a Left value, to ignore the y handler.)
class GCaseSum g a b r where
  -- Process a value of the sum g: apply the matching handler among those in
  -- b, then pass its result r to the continuation.
  gCaseSum :: forall x. g x -> (r -> a) -> b
  -- Discard all of g's handlers and continue with the given a.
  gCaseSumSkip :: a -> b
-- Sum of two alternatives: the handlers for f come first, then those for g.
instance (GCaseSum f b c r, GCaseSum g a b r) => GCaseSum (f :+: g) a c r where
  -- The value is in f: run f's handler, and make the continuation also
  -- discard the handlers for g that follow.
  gCaseSum (L1 inl) k = gCaseSum @f @b @c inl (gCaseSumSkip @g @a @b @r . k)
  -- The value is in g: f's handlers, which come first, are discarded before
  -- g's handler runs.
  gCaseSum (R1 inr) k = gCaseSumSkip @f @b @c @r (gCaseSum @g @a @b inr k)
  -- Some other alternative matched: discard the handlers of both f and g.
  gCaseSumSkip rest = gCaseSumSkip @f @b @c @r (gCaseSumSkip @g @a @b @r rest)
-- Empty sum, a type with no constructors such as Void (gcase :: Void -> r).
-- There are no handlers to consume or skip, so b is just a.
instance (a ~ b) => GCaseSum V1 a b r where
  gCaseSum impossible _ = case impossible of {}
  gCaseSumSkip rest = rest
-- Datatype metadata wrapper (M1 D) at the top of the representation:
-- strip it and decompose the sum of constructors underneath.
instance GCaseSum f a b r => GCaseSum (M1 D w f) a b r where
  gCaseSum = gCaseSum @f @a @b @r . unM1
  gCaseSumSkip = gCaseSumSkip @f @a @b @r
-- Constructor metadata wrapper (M1 C): one alternative of the sum.
-- Its handler c is the next argument; GCaseProd feeds the constructor's
-- fields to that handler to get r, which is passed on to the continuation.
instance (ca ~ (c -> a), GCaseProd f c r) => GCaseSum (M1 C w f) a ca r where
  gCaseSum (M1 fields) k = \handler -> k (gCaseProd fields handler)
  -- Another constructor matched: drop this constructor's handler.
  gCaseSumSkip = const
-- Products: (x, y) -> (x -> y -> r) -> r
--
-- GCaseProd g c d feeds the fields of the product g, left to right, to a
-- handler of type c; d is what remains of the handler afterwards.
--
-- For example, for a constructor with fields x and y:
--   - Toplevel instance (GCaseProd (x, y) c d), c = x -> y -> r, d = r:
--       gCaseProd :: (x, y) -> (x -> y -> r) -> r
--   - When processing x (GCaseProd x c d), c = x -> y -> r, d = y -> r:
--       gCaseProd :: x -> (x -> y -> r) -> y -> r
--   - When processing y (GCaseProd y c d), c = y -> r, d = r:
--       gCaseProd :: y -> (y -> r) -> r
--
-- In both leaf cases (the instance for M1 S), gCaseProd = \a k -> k a
-- (= flip ($)); the instance for (:*:) chains the two steps.
class GCaseProd g c d where
  gCaseProd :: g x -> c -> d
-- Product of fields: the fields of f are fed to the handler first, and what
-- is left of the handler (d) then receives the fields of g.
instance (GCaseProd f c d, GCaseProd g d e) => GCaseProd (f :*: g) c e where
  gCaseProd (lhs :*: rhs) handler = gCaseProd rhs (gCaseProd @f @c @d lhs handler)
-- A single field (M1 S wrapping K1): apply the handler to the field's value.
instance (ad ~ (a -> d)) => GCaseProd (M1 S w (K1 i a)) ad d where
  gCaseProd field handler = handler (unK1 (unM1 field))
-- Empty product, a constructor without fields: its handler already is the
-- result and is returned unchanged. For example:
--   gcase :: () -> r -> r
--   gcase :: Maybe a -> r -> (a -> r) -> r   -- the first r handles Nothing
instance (r ~ r') => GCaseProd U1 r r' where
  gCaseProd U1 = id
--
-- Assert equality: succeed silently when both values are equal; otherwise
-- throw an ErrorCall (via error) with the message "<actual> /= <expected>".
(=?) :: (Eq a, Show a) => a -> a -> IO ()
actual =? expected
  | actual == expected = pure ()
  | otherwise = error (show actual ++ " /= " ++ show expected)

-- Same precedence as (==): without this declaration the operator would be
-- infixl 9, binding tighter than (+), so `x + y =? z` would misparse.
infix 4 =?
-- Example type with three constructors carrying one, two and three fields.
data Three a
  = Uno a
  | Dos a a
  | Tres a a a
  deriving Generic
-- Should print nothing if the tests pass; each (=?) throws on a mismatch.
main :: IO ()
main = do
  -- Either: two constructors with one field each
  gcase (Left (3 :: Int)) id id =? 3
  gcase (Right (3 :: Int)) id id =? 3
  -- Maybe: a nullary constructor (its handler is a plain value) and a unary one
  gcase (Just (3 :: Int)) 0 id =? 3
  gcase (Nothing :: Maybe Int) 0 id =? (0 :: Int)
  -- Bool: two nullary constructors; the handler for False comes first
  gcase True "False" "True" =? "True"
  gcase False "False" "True" =? "False"
  -- Single-constructor types: a pair and unit
  gcase (40, 2) (+) =? (42 :: Int)
  gcase () 42 =? (42 :: Int)
  -- Three: every branch of a nested sum, so both skip paths are exercised
  gcase (Uno 100) id (+) (\x y z -> x + y + z) =? (100 :: Int)
  gcase (Dos 100 20) id (+) (\x y z -> x + y + z) =? (120 :: Int)
  gcase (Tres 100 20 3) id (+) (\x y z -> x + y + z) =? (123 :: Int)
{- GitHub page footer: Sign up for free to join this conversation on GitHub.
   Already have an account? Sign in to comment. -}