Skip to content

Instantly share code, notes, and snippets.

@schrammc
Last active December 31, 2019 13:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save schrammc/57ab4bc114355563494c8ce1dfdcad30 to your computer and use it in GitHub Desktop.
Save schrammc/57ab4bc114355563494c8ce1dfdcad30 to your computer and use it in GitHub Desktop.
Haskell Open Unions without overlapping instances.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
import Unsafe.Coerce
import Data.Ord (Ordering(..))
import GHC.TypeLits (TypeError (..), ErrorMessage(..))
import Data.Kind (Constraint)
--------------------------------------------------------------------------------
--
-- This construct computes, at compile time, the index of a type in a
-- type-level list of types; 'NatToInt' turns that index into a runtime 'Int'.
-- | Index, as a type-level 'Nat' counting from 'Z', of the first occurrence of
-- a type in a type-level list. When the type does not occur at all, this
-- reduces to a custom type error that names the missing type (otherwise the
-- user gets no hint which type was not found).
type family ElemIndexF (x :: a) (xs :: [a]) :: Nat where
  ElemIndexF a '[] =
    TypeError ('Text "Not a member of list: " ':<>: 'ShowType a)
  ElemIndexF a (a ': xs) = 'Z
  ElemIndexF a (b ': xs) = 'S (ElemIndexF a xs)
-- | A plain runtime 'Int' carrying a phantom type parameter @a@. The parameter
-- has no runtime representation; it only records which type the 'Int' was
-- computed from.
newtype N a = N Int
  deriving (Show)
-- | Type-level naturals that can be reflected down to a runtime 'Int'. The
-- result is wrapped in 'N' so the instance is selected by the requested result
-- type rather than by a value argument.
class NatToInt (a :: Nat) where
  reifyNat :: N a

-- | Zero reflects to @0@.
instance NatToInt 'Z where
  reifyNat = N 0

-- | A successor reflects to one more than its predecessor.
instance NatToInt a => NatToInt ('S a) where
  reifyNat = case (reifyNat :: N a) of
    N m -> N (m + 1)
--------------------------------------------------------------------------------
--
-- A union type that wraps a value together with an Int: the index of the
-- value's type in the union's type-level list, which serves as the runtime
-- representation of that type.
-- | An open union over the type-level list @ts@. The stored value's type does
-- not appear in the result type (it is existentially hidden); the 'Int' field
-- is the only runtime record of which type it is.
data Union (ts :: [k]) where
  Union :: Int -> v -> Union (ts :: [k])
-- | @Member t ts@: @t@ occurs in @ts@, and its (first) index there can be
-- reflected to a runtime 'Int'.
type Member t ts = NatToInt (ElemIndexF t ts)
-- | Wrap a value in an open union. The union is tagged with the position of
-- the value's type (its first occurrence) in the union's type list.
inject :: forall a b . Member a b => a -> Union (b :: [*])
inject value = Union position value
  where
    N position = reifyNat :: N (ElemIndexF a b)
-- | Get a value of a concrete type from an open union. The result is 'Nothing'
-- if the value in the union is not actually of type @a@.
--
-- The 'unsafeCoerce' relies on the stored tag being the index of the payload's
-- real type, as established by 'inject'.
project :: forall a b . Member a b => Union b -> Maybe a
project (Union tag payload)
  | tag == wanted = Just (unsafeCoerce payload)
  | otherwise     = Nothing
  where
    N wanted = reifyNat :: N (ElemIndexF a b)
-- | @Members xs ys@ holds when every type listed in @xs@ is a 'Member' of @ys@.
type family Members xs ys :: Constraint where
  Members '[]       us = ()
  Members (t ': ts) us = (Member t us, Members ts us)
-- | Remove every occurrence of a type from a type-level list. The remaining
-- elements keep their relative order.
type family Without (a :: k) (as :: [k]) :: [k] where
  Without x '[]         = '[]
  Without x (x ': rest) = Without x rest
  Without x (y ': rest) = y ': Without x rest
-- | Either get a value of a given type out of the union or give back a union,
-- asserting that the given type is not a member of this union.
--
-- 'Without' deletes @a@ from the type list, so every type after @a@ moves one
-- position to the left. The stored index is shifted to match. This lets a
-- later 'project' or 'matchPartial' on the residual union see the correct
-- type. The residual union is rebuilt with the 'Union' constructor, so no
-- coercion is involved.
--
-- NOTE(review): the shift assumes @a@ occurs exactly once in @as@. 'Without'
-- removes every occurrence, so when @a@ is repeated, values whose type sits
-- after the second occurrence still end up with a stale index. TODO: forbid
-- duplicate @a@ or account for every removed occurrence.
decompose
  :: forall a as . Member a as
  => Union as -> Either (Union (Without a as)) a
decompose unionValue@(Union k v) =
  case project unionValue of
    Just x  -> Right x
    Nothing ->
      let N n = reifyNat :: N (ElemIndexF a as)
          -- 'project' failed, so k /= n. Indices past @a@'s slot drop by
          -- one once @a@ is removed; those before it are unchanged.
          k'  = if k > n then k - 1 else k
      in Left (Union k' v)
--------------------------------------------------------------------------------
--
-- Pattern matching on open unions (partial and total)
-- | Constraint satisfied by any non-empty type-level list. For the empty list
-- it reduces to a custom type error.
type family IsNonEmpty xs :: Constraint where
  IsNonEmpty '[]       = TypeError (Text "Type list is empty!")
  IsNonEmpty (y ': ys) = ()
-- | Both lists are non-empty, and each one's element types are all 'Members'
-- of the other. Order and multiplicity are ignored.
type family SetEquality xs ys :: Constraint where
  SetEquality ls rs =
    ( IsNonEmpty ls
    , IsNonEmpty rs
    , Members ls rs
    , Members rs ls
    )
infixr 6 :->

-- | A sequence of handlers, one per type in @ts@ (in order), all producing an
-- @r@. Handlers are chained with ':->' and the sequence ends with ':|'.
data Matches r (ts :: [*]) where
  (:->) :: (t -> r) -> Matches r ts -> Matches r (t ': ts)
  (:|)  :: Matches r '[]
-- | Position, within the union's type list @ys@, of the type handled by the
-- first handler in the sequence. Both arguments are used only for their types
-- and are never evaluated.
indexIn :: forall a b xs ys . (Member a ys)
        => Matches b (a ': xs) -> Union ys -> Int
indexIn _ _ = case (reifyNat :: N (ElemIndexF a ys)) of
  N i -> i
-- | Run the handler that matches the union's value. 'SetEquality' requires the
-- handled types and the union's types to contain each other. Under that
-- constraint 'matchPartial' is expected to always find a handler; the
-- 'error' branch is treated as unreachable.
matchTotal :: (SetEquality as xs) => Union as -> Matches b xs -> b
matchTotal value handlers = maybe unreachable id (matchPartial value handlers)
  where
    unreachable = error "This case is precluded by the type system"
-- | Try the handlers front to back. The first handler whose type's index in
-- @as@ equals the union's stored index is applied to the value. The result is
-- 'Nothing' when no handler matches.
matchPartial :: (Members xs as) => Union as -> Matches b xs -> Maybe b
matchPartial _ (:|) = Nothing
matchPartial value@(Union tag payload) handlers@(handler :-> rest)
  -- The index equality is what justifies the coercion to the handler's type.
  | indexIn handlers value == tag = Just (handler (unsafeCoerce payload))
  | otherwise                     = matchPartial value rest
-- | Example of a total match over a union whose types are exactly
-- @A Int@ and @B String@. NOTE(review): 'A' and 'B' are not defined in the
-- visible code; presumably they are simple wrappers defined elsewhere.
demoMatchTotal :: (SetEquality '[A Int, B String] as) => Union as -> Int
demoMatchTotal u = matchTotal u handlers
  where
    handlers = onA :-> onB :-> (:|)
    onA (A n) = n + 1
    onB (B s) = length (s :: String)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment