Skip to content

Instantly share code, notes, and snippets.

@rampion
Created August 8, 2020 20:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rampion/58b5e35948615295abba20a5e3f56c4b to your computer and use it in GitHub Desktop.
Save rampion/58b5e35948615295abba20a5e3f56c4b to your computer and use it in GitHub Desktop.
{-# OPTIONS_GHC -Wno-name-shadowing #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module BlockMatrix
( Block(..)
, at, Fin(..)
, height, width
) where
import GHC.TypeLits
import Data.Proxy
import Data.Type.Equality
import Unsafe.Coerce
import Control.Applicative (liftA2)
-- | A matrix with @h@ rows and @w@ columns, stored as a tree of sub-blocks
-- rather than as a flat grid of entries.
--
-- The split point of every join is known at runtime ('KnownNat' on the upper
-- block's height / the left block's width). This is what lets element lookup and
-- block splitting find the boundary.
--
-- NOTE: the derived 'Foldable' / 'Traversable' instances follow the tree
-- structure. A 'Constant' block contributes its single stored value once,
-- regardless of its dimensions.
data Block (h :: Nat) (w :: Nat) a where
  -- | Every entry of the block is the same value.
  Constant :: a -> Block h w a
  -- | An upper block stacked on top of a lower block of the same width.
  Vertical :: KnownNat t => Block t w a -> Block b w a -> Block (t + b) w a
  -- | A left block placed beside a right block of the same height.
  Horizontal :: KnownNat l => Block h l a -> Block h r a -> Block h (l + r) a

deriving instance Functor (Block h w)
deriving instance Foldable (Block h w)
deriving instance Traversable (Block h w)
-- | Entry-wise application.
--
-- A 'Constant' function block is simply mapped over the argument. Otherwise the
-- argument block is cut along the same boundary as the function block, so the
-- two trees line up and can be combined piece by piece.
instance Applicative (Block height width) where
  pure = Constant

  fs <*> xs = case fs of
    Constant g -> g <$> xs
    Vertical upper lower ->
      -- Strict match on the split, mirroring the original view pattern.
      case vsplit (height upper) (height lower) xs of
        (xsUpper, xsLower) -> Vertical (upper <*> xsUpper) (lower <*> xsLower)
    Horizontal left right ->
      case hsplit (width left) (width right) xs of
        (xsLeft, xsRight) -> Horizontal (left <*> xsLeft) (right <*> xsRight)
-- | Element-wise arithmetic: every operation is applied entry by entry.
--
-- Consequently '*' is the entry-wise (Hadamard) product, not matrix
-- multiplication. 'fromInteger' produces a 'Constant' block.
instance Num a => Num (Block height width a) where
  xs + ys = (+) <$> xs <*> ys
  xs - ys = (-) <$> xs <*> ys
  xs * ys = (*) <$> xs <*> ys
  negate xs = negate <$> xs
  abs xs = abs <$> xs
  signum xs = signum <$> xs
  fromInteger n = Constant (fromInteger n)
-- | Look up the entry at a (row, column) position.
--
-- Both coordinates are 'Fin' values, so every index carries a proof that it is
-- in range. The lookup descends the block tree: at each join the relevant
-- coordinate is routed to one side, and re-based when it lands in the second
-- half.
at :: Block height width a -> (Fin height, Fin width) -> a
at (Constant x) _ = x
at (Vertical upper lower) (r, c) =
  either
    (\r' -> at upper (r', c))
    (\r' -> at lower (r', c))
    (finiteSplit r (height upper))
at (Horizontal left right) (r, c) =
  either
    (\c' -> at left (r, c'))
    (\c' -> at right (r, c'))
    (finiteSplit c (width left))
-- | An index strictly below the bound @n@.
--
-- The index itself lives at the type level (its value is recoverable through
-- 'KnownNat'). It is packaged with the evidence @CmpNat i n ~ 'LT@ that it is
-- smaller than the bound.
data Fin (n :: Nat) where
  Fin :: (CmpNat i n ~ 'LT, KnownNat i) => Proxy i -> Fin n
-- | The row count of a block, as a type-level proxy.
-- The block argument is never evaluated.
height :: Block height width a -> Proxy height
height = const Proxy
-- | The column count of a block, as a type-level proxy.
-- The block argument is never evaluated.
width :: Block height width a -> Proxy width
width = const Proxy
-- | Cut a block of height @m + n@ into its top @m@ rows and its bottom @n@ rows.
--
-- Only @m@ must be known at runtime. The split point is compared against each
-- existing vertical boundary, and existing joins are re-associated so that no
-- more of the tree is rebuilt than necessary.
vsplit
  :: KnownNat m
  => proxy0 m
  -> proxy1 n
  -> Block (m + n) width a
  -> (Block m width a, Block n width a)
vsplit _ _ (Constant x) = (Constant x, Constant x)
vsplit pm pn (Vertical upper lower) =
  case sassociate pm pn (height upper) (height lower) of
    -- The cut falls inside the upper block: cut it, re-attach the lower block
    -- beneath the leftover rows.
    SLessThan gap ->
      let (above, rest) = vsplit pm gap upper
      in (above, Vertical rest lower)
    -- The cut coincides with the existing boundary.
    SEqual -> (upper, lower)
    -- The cut falls inside the lower block: cut it, extend the upper block
    -- with the rows taken from it.
    SGreaterThan gap ->
      let (rest, below) = vsplit gap pn lower
      in (Vertical upper rest, below)
-- A side-by-side join is cut by cutting both sides at the same row.
vsplit pm pn (Horizontal left right) =
  let (leftTop, leftBottom) = vsplit pm pn left
      (rightTop, rightBottom) = vsplit pm pn right
  in (Horizontal leftTop rightTop, Horizontal leftBottom rightBottom)
-- | Cut a block of width @m + n@ into its left @m@ columns and its right @n@
-- columns; the column-wise counterpart of 'vsplit'.
--
-- Only @m@ must be known at runtime.
hsplit
  :: KnownNat m
  => proxy0 m
  -> proxy1 n
  -> Block height (m + n) a
  -> (Block height m a, Block height n a)
hsplit _ _ (Constant x) = (Constant x, Constant x)
-- A stacked join is cut by cutting both layers at the same column.
hsplit pm pn (Vertical upper lower) =
  let (upperLeft, upperRight) = hsplit pm pn upper
      (lowerLeft, lowerRight) = hsplit pm pn lower
  in (Vertical upperLeft lowerLeft, Vertical upperRight lowerRight)
hsplit pm pn (Horizontal left right) =
  case sassociate pm pn (width left) (width right) of
    -- The cut falls inside the left block: cut it, re-attach the right block
    -- after the leftover columns.
    SLessThan gap ->
      let (kept, rest) = hsplit pm gap left
      in (kept, Horizontal rest right)
    -- The cut coincides with the existing boundary.
    SEqual -> (left, right)
    -- The cut falls inside the right block: cut it, extend the left block
    -- with the columns taken from it.
    SGreaterThan gap ->
      let (rest, kept) = hsplit gap pn right
      in (Horizontal left rest, kept)
-- | Relate two decompositions of the same total, @m + n ~ m' + n'@, by
-- comparing @m@ with @m'@ at runtime. Exactly one of these holds:
--
-- * 'SLessThan': some @p > 0@ has @m' = m + p@ and @n = p + n'@;
-- * 'SGreaterThan': some @p > 0@ has @m = m' + p@ and @n' = p + n@;
-- * 'SEqual': @m = m'@ and @n = n'@.
--
-- The witness is built at concrete all-zero-based indices and 'unsafeCoerce'd
-- to the requested ones. Soundness rests on the runtime comparison agreeing
-- with the type-level equation in the signature.
sassociate :: (KnownNat m, KnownNat m', (m + n) ~ (m' + n'))
           => proxy0 m -> proxy1 n -> proxy2 m' -> proxy3 n' -> SAssociate m n m' n'
sassociate pm _ pm' _ = decide (someNatVal (m' - m)) (someNatVal (m - m'))
  where
    m = natVal pm
    m' = natVal pm'
    -- Exactly one difference is negative unless the two are equal, in which
    -- case both are zero and therefore both valid naturals.
    decide (Just (SomeNat gap)) Nothing = unsafeCoerce (sLessThan gap)
    decide Nothing (Just (SomeNat gap)) = unsafeCoerce (sGreaterThan gap)
    decide (Just _) (Just _) = unsafeCoerce sEqual
    decide Nothing Nothing = error "impossible: both (m' - m) and (m - m') are negative"
-- | Evidence relating the decompositions @m + n@ and @m' + n'@ of one number,
-- as produced by 'sassociate'.
data SAssociate (m :: Nat) (n :: Nat) (m' :: Nat) (n' :: Nat) where
  -- | @m' = m + p@ and @n = p + n'@: the first split point is earlier.
  SLessThan :: KnownNat p => Proxy p -> SAssociate m (p + n') (m + p) n'
  -- | @m = m'@ and @n = n'@: the split points coincide.
  SEqual :: SAssociate m n m n
  -- | @m = m' + p@ and @n' = p + n@: the first split point is later.
  SGreaterThan :: KnownNat p => Proxy p -> SAssociate (m' + p) n m' (p + n)
-- | An 'SLessThan' witness at the simplest indices (@m = 0@, @n' = 0@).
-- 'sassociate' coerces it to the indices it actually needs.
sLessThan :: KnownNat n => Proxy n -> SAssociate 0 n n 0
sLessThan gap = SLessThan gap
-- | The 'SEqual' witness at all-zero indices.
-- 'sassociate' coerces it to the indices it actually needs.
sEqual :: SAssociate 0 0 0 0
sEqual = SEqual
-- | An 'SGreaterThan' witness at the simplest indices (@m' = 0@, @n = 0@).
-- 'sassociate' coerces it to the indices it actually needs.
sGreaterThan :: KnownNat m => Proxy m -> SAssociate m 0 0 m
sGreaterThan gap = SGreaterThan gap
-- | Route an index below @m + n@ to one of the two halves.
--
-- Indices below @m@ are returned unchanged as 'Left'. The rest are re-based by
-- subtracting @m@ and returned as 'Right'. Each result is built as a
-- 'Fin' over (index + 1) and its bound is then 'unsafeCoerce'd to the target
-- half; the runtime comparison is what justifies the new bound.
finiteSplit :: KnownNat m => Fin (m + n) -> proxy m -> Either (Fin m) (Fin n)
finiteSplit (Fin pp) pm
  | i < bound = Left (unsafeCoerce (succFin pp))
  | otherwise = case someNatVal (i - bound) of
      Nothing -> error "impossible: pp >= pm, but pp - pm is negative"
      Just (SomeNat pq) -> Right (unsafeCoerce (succFin pq))
  where
    i = natVal pp
    bound = natVal pm
-- | Proof that @m < m + 1@ for every natural @m@.
--
-- GHC's built-in solver does not reduce 'CmpNat' on an abstract @m@, so the
-- fact is asserted with 'unsafeCoerce'. The proxy argument only fixes @m@ and
-- is ignored.
--
-- Only the equality proof is coerced: a genuine @'LT :~: 'LT@ is
-- re-labelled, so callers receive a real 'Refl' value. Coercing 'Refl' itself
-- to the whole function type (as in @succLT = unsafeCoerce Refl@) would make
-- every call apply a data constructor as if it were a function.
succLT :: proxy m -> CmpNat m (m + 1) :~: 'LT
succLT _ = unsafeCoerce (Refl :: 'LT :~: 'LT)
-- | Wrap @m@ as an index below @m + 1@, the tightest bound it is guaranteed to
-- be under. The proof of @m < m + 1@ comes from 'succLT'.
succFin :: KnownNat m => Proxy m -> Fin (m + 1)
succFin p
  | Refl <- succLT p = Fin p
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment