Created
August 8, 2020 20:34
-
-
Save rampion/58b5e35948615295abba20a5e3f56c4b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# OPTIONS_GHC -Wno-name-shadowing #-} | |
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE DeriveFoldable #-} | |
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE DeriveTraversable #-} | |
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE KindSignatures #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE StandaloneDeriving #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE ViewPatterns #-} | |
module BlockMatrix | |
( Block(..) | |
, at, Fin(..) | |
, height, width | |
) where | |
import GHC.TypeLits | |
import Data.Proxy | |
import Data.Type.Equality | |
import Unsafe.Coerce | |
import Control.Applicative (liftA2) | |
-- | A matrix with type-level dimensions, represented as a tree of blocks.
--
-- Leaves are 'Constant' regions in which every cell holds the same value.
-- Inner nodes either stack two sub-blocks vertically ('Vertical') or place
-- them side by side ('Horizontal'). The 'KnownNat' constraint on the first
-- child makes the position of the internal boundary available at runtime.
data Block (rows :: Nat) (cols :: Nat) a where
  -- | A region of any size whose cells all contain the given value.
  Constant   :: a -> Block rows cols a
  -- | The first block stacked above the second; both share the same width.
  Vertical   :: KnownNat top
             => Block top cols a
             -> Block bottom cols a
             -> Block (top + bottom) cols a
  -- | The first block placed left of the second; both share the same height.
  Horizontal :: KnownNat left
             => Block rows left a
             -> Block rows right a
             -> Block rows (left + right) a

deriving instance Functor     (Block rows cols)
deriving instance Foldable    (Block rows cols)
deriving instance Traversable (Block rows cols)
-- | Cell-wise application.
--
-- A 'Constant' function block is simply mapped over the argument. Otherwise
-- the argument is cut along the same boundary as the function block, and
-- each half is combined with the matching half of the function block.
instance Applicative (Block height width) where
  pure = Constant

  fs <*> xs = case fs of
    Constant f -> f <$> xs
    Vertical upper lower ->
      case vsplit (height upper) (height lower) xs of
        (xsUpper, xsLower) -> Vertical (upper <*> xsUpper) (lower <*> xsLower)
    Horizontal left right ->
      case hsplit (width left) (width right) xs of
        (xsLeft, xsRight) -> Horizontal (left <*> xsLeft) (right <*> xsRight)
-- | Element-wise arithmetic: every operation combines corresponding cells.
-- 'fromInteger' builds a block whose every cell holds the literal.
-- Note that '(*)' is the cell-wise (Hadamard) product, not matrix
-- multiplication.
instance Num a => Num (Block height width a) where
  x + y         = (+) <$> x <*> y
  x - y         = (-) <$> x <*> y
  x * y         = (*) <$> x <*> y
  negate x      = negate <$> x
  abs x         = abs <$> x
  signum x      = signum <$> x
  fromInteger n = pure (fromInteger n)
-- | Look up the cell at a (row, column) position.
--
-- At a split node the relevant coordinate is routed to the child that
-- contains it, re-based relative to that child. A 'Constant' leaf answers
-- without inspecting the index pair at all.
at :: Block height width a -> (Fin height, Fin width) -> a
at (Constant x) _ = x
at (Vertical upper lower) (r, c) =
  either (\r' -> at upper (r', c)) (\r' -> at lower (r', c)) $
    finiteSplit r (height upper)
at (Horizontal left right) (r, c) =
  either (\c' -> at left (r, c')) (\c' -> at right (r, c')) $
    finiteSplit c (width left)
-- | An index statically known to be below @bound@.
--
-- The wrapped 'Proxy' names a type-level natural @i@ for which
-- @CmpNat i bound@ is @'LT@. The value of @i@ is recoverable through the
-- packed 'KnownNat' dictionary.
data Fin (bound :: Nat) where
  Fin :: (CmpNat i bound ~ 'LT, KnownNat i) => Proxy i -> Fin bound
-- | A proxy for the block's type-level row count; the block is never inspected.
height :: Block height width a -> Proxy height
height = const Proxy
-- | A proxy for the block's type-level column count; the block is never inspected.
width :: Block height width a -> Proxy width
width = const Proxy
-- | Cut a block of height @m + n@ into its top @m@ rows and its bottom @n@ rows.
--
-- The proxies are never inspected; only their types, and the 'KnownNat'
-- value of @m@, are used. At a 'Vertical' node the cut either reuses the
-- existing boundary or recurses into the child that contains it. A
-- 'Horizontal' node is cut in both of its children.
vsplit :: KnownNat m => proxy0 m -> proxy1 n -> Block (m + n) width a -> (Block m width a, Block n width a)
vsplit m n block = case block of
  Constant x -> (Constant x, Constant x)
  Vertical upper lower ->
    case sassociate m n (height upper) (height lower) of
      SEqual -> (upper, lower)
      SLessThan p ->
        -- the cut falls inside the upper child, p rows above its bottom edge
        let pieces = vsplit m p upper
        in (fst pieces, Vertical (snd pieces) lower)
      SGreaterThan p ->
        -- the cut falls inside the lower child, p rows below its top edge
        let pieces = vsplit p n lower
        in (Vertical upper (fst pieces), snd pieces)
  Horizontal left right ->
    let leftPieces  = vsplit m n left
        rightPieces = vsplit m n right
    in ( Horizontal (fst leftPieces) (fst rightPieces)
       , Horizontal (snd leftPieces) (snd rightPieces) )
-- | Cut a block of width @m + n@ into its leftmost @m@ columns and its
-- rightmost @n@ columns.
--
-- The mirror image of 'vsplit'. At a 'Horizontal' node the cut either reuses
-- the existing boundary or recurses into the child that contains it. A
-- 'Vertical' node is cut in both of its children.
hsplit :: KnownNat m => proxy0 m -> proxy1 n -> Block height (m + n) a -> (Block height m a, Block height n a)
hsplit m n block = case block of
  Constant x -> (Constant x, Constant x)
  Vertical upper lower ->
    let upperPieces = hsplit m n upper
        lowerPieces = hsplit m n lower
    in ( Vertical (fst upperPieces) (fst lowerPieces)
       , Vertical (snd upperPieces) (snd lowerPieces) )
  Horizontal left right ->
    case sassociate m n (width left) (width right) of
      SEqual -> (left, right)
      SLessThan p ->
        -- the cut falls inside the left child, p columns before its right edge
        let pieces = hsplit m p left
        in (fst pieces, Horizontal (snd pieces) right)
      SGreaterThan p ->
        -- the cut falls inside the right child, p columns after its left edge
        let pieces = hsplit p n right
        in (Horizontal left (fst pieces), snd pieces)
-- | Relate two splittings, @m + n@ and @m' + n'@, of the same total.
-- Exactly one of the following holds:
--
--   * @m < m'@: some @p > 0@ has @m' = m + p@ and @n = p + n'@ ('SLessThan')
--   * @m > m'@: some @p > 0@ has @m = m' + p@ and @n' = p + n@ ('SGreaterThan')
--   * @m = m'@, and therefore @n = n'@ ('SEqual')
--
-- Only the values of @m@ and @m'@ are read; the other two proxies are
-- ignored. GHC cannot derive these equalities itself. The evidence is
-- therefore built at all-zero indices ('sLessThan', 'sEqual',
-- 'sGreaterThan') and 'unsafeCoerce'd to the requested indices. This relies
-- on an 'SAssociate' value having the same runtime shape whatever its
-- indices are. The runtime test on @m' - m@, together with
-- @m + n = m' + n'@, is what makes the asserted equalities true.
sassociate :: (KnownNat m, KnownNat m', (m + n) ~ (m' + n'))
           => proxy0 m -> proxy1 n -> proxy2 m' -> proxy3 n' -> SAssociate m n m' n'
sassociate pm _ pm' _ =
  case (someNatVal gap, someNatVal (negate gap)) of
    (Just (SomeNat pp), Nothing) -> unsafeCoerce (sLessThan pp)
    (Nothing, Just (SomeNat pp)) -> unsafeCoerce (sGreaterThan pp)
    (Just _, Just _)             -> unsafeCoerce sEqual
    (Nothing, Nothing)           -> error "impossible: both (m' - m) and (m - m') are negative"
  where
    -- m' - m: positive when the first split point comes earlier
    gap = natVal pm' - natVal pm
-- | Evidence relating two splittings, @m + n@ and @m' + n'@, of the same
-- total, as produced by 'sassociate'. Pattern matching on a constructor
-- brings the corresponding type equalities into scope.
data SAssociate (m :: Nat) (n :: Nat) (m' :: Nat) (n' :: Nat) where
  -- | The first split point is earlier by @p@: @m' = m + p@ and @n = p + n'@.
  SLessThan    :: KnownNat p => Proxy p -> SAssociate m (p + n') (m + p) n'
  -- | The two split points coincide.
  SEqual       :: SAssociate m n m n
  -- | The first split point is later by @p@: @m = m' + p@ and @n' = p + n@.
  SGreaterThan :: KnownNat p => Proxy p -> SAssociate (m' + p) n m' (p + n)
-- | 'SLessThan' at the indices where @m = n' = 0@. GHC can check this
-- instance of the constructor directly; 'sassociate' then coerces it to the
-- indices it needs.
sLessThan :: KnownNat n => Proxy n -> SAssociate 0 n n 0
sLessThan proxy = SLessThan proxy

-- | 'SEqual' at all-zero indices, for 'sassociate' to coerce.
sEqual :: SAssociate 0 0 0 0
sEqual = SEqual

-- | 'SGreaterThan' at the indices where @m' = n = 0@, for 'sassociate' to coerce.
sGreaterThan :: KnownNat m => Proxy m -> SAssociate m 0 0 m
sGreaterThan proxy = SGreaterThan proxy
-- | Split an index below @m + n@ at the boundary @m@.
--
-- An index @i < m@ comes back unchanged as a @'Left' (Fin m)@. Otherwise
-- @i - m@ comes back as a @'Right' (Fin n)@; it is below @n@ because the
-- input index is below @m + n@.
--
-- The bound proof is first manufactured for the trivial bound @i + 1@ via
-- 'succFin'. It is then retargeted to @m@ (or @n@) with 'unsafeCoerce'. The
-- runtime comparison against @m@ is what makes the retargeted bound true.
finiteSplit :: KnownNat m => Fin (m + n) -> proxy m -> Either (Fin m) (Fin n)
finiteSplit (Fin index) boundary =
  case compare indexVal boundaryVal of
    LT -> Left (unsafeCoerce (succFin index))
    _  -> case someNatVal (indexVal - boundaryVal) of
      Just (SomeNat shifted) -> Right (unsafeCoerce (succFin shifted))
      Nothing -> error "impossible: pp >= pm, but pp - pm is negative"
  where
    indexVal    = natVal index
    boundaryVal = natVal boundary
-- | Every natural number is less than its successor.
--
-- GHC's type-level solver cannot prove @CmpNat m (m + 1) ~ 'LT@ for an
-- abstract @m@, so the proof is asserted by coercing a trivial 'Refl'.
--
-- The proxy argument is bound explicitly so that 'unsafeCoerce' only ever
-- converts one equality value into another of the same shape. Coercing
-- 'Refl' straight to the function type @proxy m -> ...@ would make every
-- call apply a data constructor as if it were a function, which is
-- undefined behaviour.
succLT :: proxy m -> CmpNat m (m + 1) :~: 'LT
succLT _ = unsafeCoerce (Refl :: 'LT :~: 'LT)
-- | Package @m@ itself as an index below @m + 1@. The bound proof that 'Fin'
-- requires is supplied by 'succLT'.
succFin :: KnownNat m => Proxy m -> Fin (m + 1)
succFin self = gcastWith (succLT self) (Fin self)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment