-- Prove properties of type-level naturals in Haskell
--
-- Author: @coord-e (GitHub Gist 7dbcae7e980023edf2942035998bc0d4),
-- created August 5, 2020 02:11.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE NoImplicitPrelude #-}
import Prelude ( IO
, pure
)
import Data.Type.Equality
-- | Nothing to execute: every theorem below is a value of an equality type
-- (@:~:@), so the interesting work happens when the type checker accepts the
-- module.
main :: IO ()
main = do
  pure ()
-- helper

-- | Congruence: equal arguments stay equal under any type constructor @f@.
-- Matching the proof against 'Refl' brings @a ~ b@ into scope, after which
-- @f a :~: f b@ is witnessed by 'Refl' itself (equivalent to @apply Refl@).
cong :: a :~: b -> f a :~: f b
cong Refl = Refl
-- Nat

-- | Unary (Peano) naturals. DataKinds promotes the constructors, so they
-- double as type-level naturals.
data Nat where
  Z :: Nat
  S :: Nat -> Nat

-- | Tick-free aliases for the promoted constructors, so signatures can write
-- @Z@ and @S n@ rather than @'Z@ and @'S n@.
type Z = 'Z
type S = 'S
-- singleton

-- | Singleton for 'Nat': a runtime value whose type index is the natural it
-- represents, so pattern matching on it tells the type checker which
-- natural @n@ is.
data SNat (n :: Nat) where
  SZ :: SNat Z
  SS :: SNat n -> SNat (S n)
-- | Build the singleton of a statically known 'Nat' implicitly, by recursion
-- on the instance head.
class SingN (n :: Nat) where
  sing :: SNat n

instance SingN 'Z where
  sing = SZ

instance SingN n => SingN ('S n) where
  sing = SS sing
-- | Addition on singletons. It follows the type family '+' clause by clause,
-- recursing on the left operand.
infixl 6 +
(+) :: SNat n -> SNat m -> SNat (n + m)
(+) SZ m = m
(+) (SS n') m = SS (n' + m)
-- | Multiplication on singletons, mirroring the type family '*': each
-- successor on the left contributes one extra copy of the right operand.
infixl 7 *
(*) :: SNat n -> SNat m -> SNat (n * m)
(*) SZ _ = SZ
(*) (SS n') m = m + (n' * m)
-- operations and axioms

-- | Type-level addition on 'Nat', by recursion on the left operand.
type family (a :: Nat) + (b :: Nat) :: Nat where
  'Z + b = b
  'S a' + b = 'S (a' + b)
-- | Axiom: @Z + n@ reduces to @n@ by the first equation of '+', so 'Refl'
-- suffices. The singleton is ignored.
plus_Z_left :: SNat n -> Z + n :~: n
plus_Z_left = \_ -> Refl
-- | Axiom: @S n + m@ unfolds to @S (n + m)@ by the second equation of '+'.
-- Both singletons are ignored.
plus_Sn_m :: SNat n -> SNat m -> S n + m :~: S (n + m)
plus_Sn_m = \_ _ -> Refl
-- | Type-level multiplication on 'Nat': zero annihilates on the left, and
-- @S a' * b@ unfolds to @b + (a' * b)@.
type family (a :: Nat) * (b :: Nat) :: Nat where
  'Z * b = 'Z
  'S a' * b = b + (a' * b)
-- | Axiom: @Z * n@ reduces to @Z@ by the first equation of '*'. The singleton
-- is ignored.
mult_Z_left :: SNat n -> Z * n :~: Z
mult_Z_left = \_ -> Refl
-- | Axiom: @S n * m@ unfolds to @m + (n * m)@ by the second equation of '*'.
-- Both singletons are ignored.
mult_Sn_m :: SNat n -> SNat m -> S n * m :~: m + (n * m)
mult_Sn_m = \_ _ -> Refl
-- theorems
-- | Right identity of addition, @a + Z :~: a@, by induction on @a@.
plus_Z_right :: SNat a -> a + Z :~: a
plus_Z_right SZ = plus_Z_left SZ
plus_Z_right (SS a') =
  plus_Sn_m a' SZ                  -- S a' + Z :~: S (a' + Z)
    `trans` cong (plus_Z_right a') --          :~: S a'
-- Z + S m :~: S (Z + m)
-- ----------------------- sym (cong (plus_Z_left m)) = S m :~: S (Z + m)
-- Z + S m :~: S m
-- ----------------------- plus_Z_left (SS m) = Z + S m :~: S m
-- S m :~: S m
-- S n' + S m :~: S (S n' + m)
-- ----------------------------- plus_Sn_m n' (SS m) = S n' + S m :~: S (n' + S m)
-- S (n' + S m) :~: S (S n' + m)
-- ----------------------------- sym (cong (plus_Sn_m n' m)) = S (S (n' + m)) :~: S (S n' + m)
-- S (n' + S m) :~: S (S (n' + m))
-- ----------------------------- cong (plus_n_Sm n' m) = S (n' + S m) :~: S (S (n' + m))
-- S (S (n' + m)) :~: S (S (n' + m))
-- | A successor on the right operand can be pulled out of the sum:
-- @n + S m :~: S (n + m)@, by induction on @n@.
plus_n_Sm :: SNat n -> SNat m -> n + S m :~: S (n + m)
plus_n_Sm SZ m =
  plus_Z_left (SS m)                   -- Z + S m :~: S m
    `trans` sym (cong (plus_Z_left m)) --         :~: S (Z + m)
plus_n_Sm (SS n') m =
  plus_Sn_m n' (SS m)                  -- S n' + S m :~: S (n' + S m)
    `trans` cong (plus_n_Sm n' m)      --            :~: S (S (n' + m))
    `trans` sym (cong (plus_Sn_m n' m)) --           :~: S (S n' + m)
-- SZ case (h: c :~: d; matching the first proof as Refl refines b to Z)
-- Z + c :~: Z + d
-- ---------------------------- plus_Z_left c = Z + c :~: c
-- c :~: Z + d
-- ---------------------------- sym (plus_Z_left d) = d :~: Z + d
-- c :~: d
-- SS case (h2: c :~: d; matching the first proof as Refl refines b to S a')
-- S a' + c :~: S a' + d
-- ---------------------------- plus_Sn_m a' c = S a' + c :~: S (a' + c)
-- S (a' + c) :~: S a' + d
-- ---------------------------- sym (plus_Sn_m a' d) = S (a' + d) :~: S a' + d
-- S (a' + c) :~: S (a' + d)
-- ---------------------------- cong (plus_apply a' a' c d Refl h2) = S (a' + c) :~: S (a' + d)

-- | Congruence of '+' in both operands: from @a :~: b@ and @c :~: d@, derive
-- @a + c :~: b + d@ by induction on the singleton for @a@.
--
-- Each equation matches the first proof against 'Refl', which identifies
-- @b@ with @a@. The singleton for @b@ therefore never has to be inspected.
-- The two equations (one per constructor of @a@) cover every input, so no
-- mismatched @SZ@/@SS@ combination is left for @-Wincomplete-patterns@ to
-- report.
plus_apply
  :: SNat a
  -> SNat b
  -> SNat c
  -> SNat d
  -> a :~: b
  -> c :~: d
  -> a + c :~: b + d
plus_apply SZ _ c d Refl h =
  plus_Z_left c                 -- Z + c :~: c
    `trans` h                   --       :~: d
    `trans` sym (plus_Z_left d) --       :~: Z + d
plus_apply (SS a') _ c d Refl h2 =
  plus_Sn_m a' c                                -- S a' + c :~: S (a' + c)
    `trans` cong (plus_apply a' a' c d Refl h2) --          :~: S (a' + d)
    `trans` sym (plus_Sn_m a' d)                --          :~: S a' + d
-- | Congruence of '+' in its left operand: @a :~: b@ yields
-- @a + c :~: b + c@. Delegates to 'plus_apply', using 'Refl' for the
-- unchanged right operand.
cong_plus_right :: SNat a -> SNat b -> SNat c -> a :~: b -> a + c :~: b + c
cong_plus_right x y z eq =
  plus_apply x y z z eq Refl
-- | Congruence of '+' in its right operand: @a :~: b@ yields
-- @c + a :~: c + b@. Delegates to 'plus_apply', using 'Refl' for the
-- unchanged left operand.
cong_plus_left :: SNat a -> SNat b -> SNat c -> a :~: b -> c + a :~: c + b
cong_plus_left x y z eq =
  plus_apply z z x y Refl eq
-- Z + (b + c) :~: (Z + b) + c
-- ---------------------------- plus_Z_left (b + c) = Z + (b + c) :~: b + c
-- b + c :~: (Z + b) + c
-- ---------------------------- sym (cong_plus_right (SZ + b) b c (plus_Z_left b)) = b + c :~: (Z + b) + c
-- S a' + (b + c) :~: (S a' + b) + c
-- ---------------------------- plus_Sn_m a' (b + c) = S a' + (b + c) :~: S (a' + (b + c))
-- S (a' + (b + c)) :~: (S a' + b) + c
-- ---------------------------- cong_plus_right (S (a' + b)) (S a' + b) c (sym (plus_Sn_m a' b)) = S (a' + b) + c :~: (S a' + b) + c
-- S (a' + (b + c)) :~: S (a' + b) + c
-- ---------------------------- sym (plus_Sn_m (a' + b) c) = S ((a' + b) + c) :~: S (a' + b) + c
-- S (a' + (b + c)) :~: S ((a' + b) + c)
-- ---------------------------- cong (plus_assoc a' b c) = S (a' + (b + c)) :~: S ((a' + b) + c)
-- | Associativity of addition, @a + (b + c) :~: (a + b) + c@, by induction
-- on @a@.
plus_assoc :: SNat a -> SNat b -> SNat c -> a + (b + c) :~: (a + b) + c
plus_assoc SZ b c =
  -- Z + (b + c) :~: b + c
  plus_Z_left (b + c)
    -- :~: (Z + b) + c
    `trans` sym (cong_plus_right (SZ + b) b c (plus_Z_left b))
plus_assoc (SS a') b c =
  -- S a' + (b + c) :~: S (a' + (b + c))
  plus_Sn_m a' (b + c)
    -- :~: S ((a' + b) + c)
    `trans` cong (plus_assoc a' b c)
    -- :~: S (a' + b) + c
    `trans` sym (plus_Sn_m (a' + b) c)
    -- :~: (S a' + b) + c
    `trans` cong_plus_right (SS (a' + b)) (SS a' + b) c (sym (plus_Sn_m a' b))
-- Z + m :~: m + Z
-- ---------------------------- plus_Z_left m = Z + m :~: m
-- m :~: m + Z
-- ---------------------------- sym (plus_Z_right m) = m :~: m + Z
-- m :~: m
-- S n' + m :~: m + S n'
-- ---------------------------- plus_Sn_m n' m = S n' + m :~: S (n' + m)
-- S (n' + m) :~: m + S n'
-- ---------------------------- sym (plus_n_Sm m n') = S (m + n') :~: m + S n'
-- S (n' + m) :~: S (m + n')
-- ---------------------------- cong (plus_symm n' m)
-- | Commutativity of addition, @n + m :~: m + n@, by induction on @n@.
plus_symm :: SNat n -> SNat m -> n + m :~: m + n
plus_symm SZ m =
  plus_Z_left m                   -- Z + m :~: m
    `trans` sym (plus_Z_right m)  --       :~: m + Z
plus_symm (SS n') m =
  plus_Sn_m n' m                  -- S n' + m :~: S (n' + m)
    `trans` cong (plus_symm n' m) --          :~: S (m + n')
    `trans` sym (plus_n_Sm m n')  --          :~: m + S n'
-- | Zero annihilates on the right, @n * Z :~: Z@, by induction on @n@.
mult_Z_right :: SNat n -> n * Z :~: Z
mult_Z_right SZ = mult_Z_left SZ
mult_Z_right (SS n') =
  mult_Sn_m n' SZ                 -- S n' * Z :~: Z + (n' * Z)
    `trans` plus_Z_left (n' * SZ) --          :~: n' * Z
    `trans` mult_Z_right n'       --          :~: Z
-- S n + m :~: n + S m
-- ---------------------------- plus_Sn_m n m = S n + m :~: S (n + m)
-- S (n + m) :~: n + S m
-- ---------------------------- sym (plus_n_Sm n m) = S (n + m) :~: n + S m
-- S (n + m) :~: S (n + m)
-- | Move a successor from the left operand to the right one:
-- @S n + m :~: n + S m@. Despite the name, this is not an associativity law.
plus_assoc_S :: SNat n -> SNat m -> S n + m :~: n + S m
plus_assoc_S n m =
  plus_Sn_m n m                 -- S n + m :~: S (n + m)
    `trans` sym (plus_n_Sm n m) --         :~: n + S m
-- Z * S m :~: Z + (Z * m)
-- ---------------------------- mult_Z_left (SS m) = Z * S m :~: Z
-- Z :~: Z + (Z * m)
-- ---------------------------- plus_Z_left (SZ * m) = Z + (Z * m) :~: Z * m
-- Z :~: Z * m
-- ---------------------------- mult_Z_left m = Z * m :~: Z
-- Z :~: Z
-- S n' * S m :~: S n' + (S n' * m)
-- ---------------------------- mult_Sn_m n' (SS m) = S n' * S m :~: S m + (n' * S m)
-- S m + (n' * S m) :~: S n' + (S n' * m)
-- ---------------------------- mult_Sn_m n' m = S n' * m :~: m + (n' * m)
-- S m + (n' * S m) :~: S n' + (m + (n' * m))
-- ---------------------------- plus_assoc (SS n') m (n' * m) = S n' + (m + (n' * m)) :~: (S n' + m) + (n' * m)
-- S m + (n' * S m) :~: (S n' + m) + (n' * m)
-- ---------------------------- plus_assoc_S n' m = S n' + m :~: n' + S m
-- S m + (n' * S m) :~: (n' + S m) + (n' * m)
-- ---------------------------- plus_symm n' (SS m) = n' + S m :~: S m + n'
-- S m + (n' * S m) :~: (S m + n') + (n' * m)
-- ---------------------------- plus_assoc (SS m) n' (n' * m) = S m + (n' + (n' * m)) :~: (S m + n') + (n' * m)
-- S m + (n' * S m) :~: S m + (n' + (n' * m))
-- ---------------------------- cong_plus_left (n' * SS m) (n' + (n' * m)) (SS m) (mult_n_Sm n' m)
-- | A successor on the right factor contributes one extra copy of the left
-- factor: @n * S m :~: n + (n * m)@, by induction on @n@.
mult_n_Sm :: SNat n -> SNat m -> n * S m :~: n + (n * m)
mult_n_Sm SZ m =
  -- Z * S m :~: Z
  mult_Z_left (SS m)
    -- :~: Z + (Z * m)
    `trans` sym (plus_Z_left (SZ * m) `trans` mult_Z_left m)
mult_n_Sm (SS n') m =
  -- S n' * S m :~: S m + (n' * S m)
  mult_Sn_m n' (SS m)
    -- :~: S m + (n' + (n' * m)), by the induction hypothesis
    `trans` cong_plus_left (n' * SS m) (n' + (n' * m)) (SS m) (mult_n_Sm n' m)
    -- :~: (S m + n') + (n' * m)
    `trans` plus_assoc (SS m) n' (n' * m)
    -- :~: (S n' + m) + (n' * m), via S n' + m :~: n' + S m :~: S m + n'
    `trans` sym
      (cong_plus_right (SS n' + m) (SS m + n') (n' * m)
        (plus_assoc_S n' m `trans` plus_symm n' (SS m)))
    -- :~: S n' + (S n' * m), refolding m + (n' * m) into S n' * m
    `trans` sym
      (cong_plus_left (SS n' * m) (m + (n' * m)) (SS n') (mult_Sn_m n' m)
        `trans` plus_assoc (SS n') m (n' * m))
-- Z * m :~: m * Z
-- ---------------------------- mult_Z_left m = Z * m :~: Z
-- Z :~: m * Z
-- ---------------------------- mult_Z_right m = m * Z :~: Z
-- Z :~: Z
-- S n' * m :~: m * S n'
-- ---------------------------- mult_Sn_m n' m = S n' * m :~: m + (n' * m)
-- m + (n' * m) :~: m * S n'
-- ---------------------------- sym (mult_n_Sm m n') = m + (m * n') :~: m * S n'
-- m + (n' * m) :~: m + (m * n')
-- ---------------------------- cong_plus_left (n' * m) (m * n') m (mult_symm n' m) = m + (n' * m) :~: m + (m * n')
-- | Commutativity of multiplication, @n * m :~: m * n@, by induction on @n@.
mult_symm :: SNat n -> SNat m -> n * m :~: m * n
mult_symm SZ m =
  mult_Z_left m                  -- Z * m :~: Z
    `trans` sym (mult_Z_right m) --       :~: m * Z
mult_symm (SS n') m =
  -- S n' * m :~: m + (n' * m)
  mult_Sn_m n' m
    -- :~: m + (m * n'), by the induction hypothesis
    `trans` cong_plus_left (n' * m) (m * n') m (mult_symm n' m)
    -- :~: m * S n'
    `trans` sym (mult_n_Sm m n')
-- h: a :~: b
-- a * Z :~: b * Z
-- ---------------------------- mult_Z_right a = a * Z :~: Z
-- Z :~: b * Z
-- ---------------------------- sym (mult_Z_right b) = Z :~: b * Z
-- Z :~: Z
-- h: a :~: b
-- a * S c' :~: b * S c'
-- ---------------------------- mult_n_Sm a c' = a * S c' :~: a + (a * c')
-- a + (a * c') :~: b * S c'
-- ---------------------------- sym (mult_n_Sm b c') = b + (b * c') :~: b * S c'
-- a + (a * c') :~: b + (b * c')
-- ---------------------------- plus_apply a b (a * c') (b * c') h (cong_mult_right a b c' h) = a + (a * c') :~: b + (b * c')
-- | Congruence of '*' in its left operand: @a :~: b@ yields
-- @a * c :~: b * c@, by induction on the right operand @c@.
cong_mult_right :: SNat a -> SNat b -> SNat c -> a :~: b -> a * c :~: b * c
cong_mult_right a b SZ _ =
  mult_Z_right a                 -- a * Z :~: Z
    `trans` sym (mult_Z_right b) --       :~: b * Z
cong_mult_right a b (SS c') h =
  -- a * S c' :~: a + (a * c')
  mult_n_Sm a c'
    -- :~: b + (b * c'), rewriting both summands
    `trans` plus_apply a b (a * c') (b * c') h (cong_mult_right a b c' h)
    -- :~: b * S c'
    `trans` sym (mult_n_Sm b c')
-- (a + b) * Z :~: (a * Z) + (b * Z)
-- ---------------------------- mult_Z_right (a + b) = (a + b) * Z :~: Z
-- Z :~: (a * Z) + (b * Z)
-- ---------------------------- cong_plus_right (a * SZ) SZ (b * SZ) (mult_Z_right a) = (a * Z) + (b * Z) :~: Z + (b * Z)
-- Z :~: Z + (b * Z)
-- ---------------------------- plus_Z_left (b * SZ) = Z + (b * Z) :~: b * Z
-- Z :~: b * Z
-- ---------------------------- mult_Z_right b = b * Z :~: Z
-- Z :~: Z
-- (a + b) * S c' :~: (a * S c') + (b * S c')
-- ---------------------------- mult_n_Sm (a + b) c' = (a + b) * S c' :~: (a + b) + ((a + b) * c')
-- (a + b) + ((a + b) * c') :~: (a * S c') + (b * S c')
-- ---------------------------- cong_plus_right (a * SS c') (a + (a * c')) (b * SS c') (mult_n_Sm a c') = (a * S c') + (b * S c') :~: (a + (a * c')) + (b * S c')
-- (a + b) + ((a + b) * c') :~: (a + (a * c')) + (b * S c')
-- ---------------------------- cong_plus_left (b * SS c') (b + (b * c')) (a + (a * c')) (mult_n_Sm b c') = (a + (a * c')) + (b * S c') :~: (a + (a * c')) + (b + (b * c'))
-- (a + b) + ((a + b) * c') :~: (a + (a * c')) + (b + (b * c'))
-- ---------------------------- sym (plus_assoc a (a * c') (b + (b * c'))) = (a + (a * c')) + (b + (b * c')) :~: a + ((a * c') + (b + (b * c')))
-- (a + b) + ((a + b) * c') :~: a + ((a * c') + (b + (b * c')))
-- ---------------------------- cong_plus_left ((a * c') + (b + (b * c'))) (((a * c') + b) + (b * c')) a (plus_assoc (a * c') b (b * c')) = a + ((a * c') + (b + (b * c'))) :~: a + (((a * c') + b) + (b * c'))
-- (a + b) + ((a + b) * c') :~: a + (((a * c') + b) + (b * c'))
-- ---------------------------- cong_plus_left (((a * c') + b) + (b * c')) ((b + (a * c')) + (b * c')) a (cong_plus_right ((a * c') + b) (b + (a * c')) (b * c') (plus_symm (a * c') b)) = a + (((a * c') + b) + (b * c')) :~: a + ((b + (a * c')) + (b * c'))
-- (a + b) + ((a + b) * c') :~: a + ((b + (a * c')) + (b * c'))
-- ---------------------------- sym (cong_plus_left (b + ((a * c') + (b * c'))) ((b + (a * c')) + (b * c')) a (plus_assoc b (a * c') (b * c'))) = a + ((b + (a * c')) + (b * c')) :~: a + (b + ((a * c') + (b * c')))
-- (a + b) + ((a + b) * c') :~: a + (b + ((a * c') + (b * c')))
-- ---------------------------- plus_assoc a b ((a * c') + (b * c')) = a + (b + ((a * c') + (b * c'))) :~: (a + b) + ((a * c') + (b * c'))
-- (a + b) + ((a + b) * c') :~: (a + b) + ((a * c') + (b * c'))
-- ---------------------------- cong_plus_left ((a + b) * c') ((a * c') + (b * c')) (a + b) (distr_plus_mult a b c') = (a + b) + ((a + b) * c') :~: (a + b) + ((a * c') + (b * c'))
-- | Right distributivity, @(a + b) * c :~: (a * c) + (b * c)@, by induction
-- on @c@. In each case both sides are rewritten to a common form, and the
-- right-hand rewriting is joined with 'sym'.
distr_plus_mult
  :: SNat a -> SNat b -> SNat c -> (a + b) * c :~: (a * c) + (b * c)
distr_plus_mult a b SZ =
  -- (a + b) * Z :~: Z
  mult_Z_right (a + b)
    -- :~: (a * Z) + (b * Z), reading the chain below backwards
    `trans` sym
      ( cong_plus_right (a * SZ) SZ (b * SZ) (mult_Z_right a)
          -- (a * Z) + (b * Z) :~: Z + (b * Z)
          `trans` plus_Z_left (b * SZ)
          -- :~: b * Z
          `trans` mult_Z_right b
          -- :~: Z
      )
distr_plus_mult a b (SS c') =
  -- (a + b) * S c' :~: (a + b) + ((a + b) * c')
  mult_n_Sm (a + b) c'
    -- :~: (a + b) + ((a * c') + (b * c')), by the induction hypothesis
    `trans` cong_plus_left ((a + b) * c') ((a * c') + (b * c')) (a + b)
              (distr_plus_mult a b c')
    -- :~: (a * S c') + (b * S c'), reading the chain below backwards
    `trans` sym
      ( cong_plus_right (a * SS c') (a + (a * c')) (b * SS c') (mult_n_Sm a c')
          -- (a * S c') + (b * S c') :~: (a + (a * c')) + (b * S c')
          `trans` cong_plus_left (b * SS c') (b + (b * c')) (a + (a * c'))
                    (mult_n_Sm b c')
          -- :~: (a + (a * c')) + (b + (b * c'))
          `trans` sym (plus_assoc a (a * c') (b + (b * c')))
          -- :~: a + ((a * c') + (b + (b * c')))
          `trans` cong_plus_left ((a * c') + (b + (b * c')))
                    (((a * c') + b) + (b * c'))
                    a
                    (plus_assoc (a * c') b (b * c'))
          -- :~: a + (((a * c') + b) + (b * c'))
          `trans` cong_plus_left (((a * c') + b) + (b * c'))
                    ((b + (a * c')) + (b * c'))
                    a
                    (cong_plus_right ((a * c') + b) (b + (a * c')) (b * c')
                      (plus_symm (a * c') b))
          -- :~: a + ((b + (a * c')) + (b * c'))
          `trans` sym (cong_plus_left (b + ((a * c') + (b * c')))
                         ((b + (a * c')) + (b * c'))
                         a
                         (plus_assoc b (a * c') (b * c')))
          -- :~: a + (b + ((a * c') + (b * c')))
          `trans` plus_assoc a b ((a * c') + (b * c'))
          -- :~: (a + b) + ((a * c') + (b * c'))
      )
-- Z * (b * c) :~: (Z * b) * c
-- ---------------------------- mult_Z_left (b * c) = Z * (b * c) :~: Z
-- Z :~: (Z * b) * c
-- ---------------------------- cong_mult_right SZ (SZ * b) c (sym (mult_Z_left b)) = Z * c :~: (Z * b) * c
-- Z :~: Z * c
-- ---------------------------- sym (mult_Z_left c) = Z :~: Z * c
-- S a' * (b * c) :~: (S a' * b) * c
-- ---------------------------- mult_Sn_m a' (b * c) = S a' * (b * c) :~: (b * c) + (a' * (b * c))
-- (b * c) + (a' * (b * c)) :~: (S a' * b) * c
-- ---------------------------- cong_mult_right (S a' * b) (b + (a' * b)) c (mult_Sn_m a' b) = (S a' * b) * c :~: (b + (a' * b)) * c
-- (b * c) + (a' * (b * c)) :~: (b + (a' * b)) * c
-- ---------------------------- distr_plus_mult b (a' * b) c = (b + (a' * b)) * c :~: (b * c) + ((a' * b) * c)
-- (b * c) + (a' * (b * c)) :~: (b * c) + ((a' * b) * c)
-- ---------------------------- cong_plus_left (a' * (b * c)) ((a' * b) * c) (b * c) (mult_assoc a' b c) = (b * c) + (a' * (b * c)) :~: (b * c) + ((a' * b) * c)
-- | Associativity of multiplication, @a * (b * c) :~: (a * b) * c@, by
-- induction on @a@. The successor case relies on 'distr_plus_mult'.
mult_assoc :: SNat a -> SNat b -> SNat c -> a * (b * c) :~: (a * b) * c
mult_assoc SZ b c =
  -- Z * (b * c) :~: Z
  mult_Z_left (b * c)
    -- :~: Z * c
    `trans` sym (mult_Z_left c)
    -- :~: (Z * b) * c
    `trans` cong_mult_right SZ (SZ * b) c (sym (mult_Z_left b))
mult_assoc (SS a') b c =
  -- S a' * (b * c) :~: (b * c) + (a' * (b * c))
  mult_Sn_m a' (b * c)
    -- :~: (b * c) + ((a' * b) * c), by the induction hypothesis
    `trans` cong_plus_left (a' * (b * c)) ((a' * b) * c) (b * c) (mult_assoc a' b c)
    -- :~: (S a' * b) * c, refolding through (b + (a' * b)) * c
    `trans` sym
      (cong_mult_right (SS a' * b) (b + (a' * b)) c (mult_Sn_m a' b)
        `trans` distr_plus_mult b (a' * b) c)
-- Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment