Skip to content

Instantly share code, notes, and snippets.

@sjoerdvisscher
Last active December 17, 2015 03:59
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sjoerdvisscher/5547235 to your computer and use it in GitHub Desktop.
Save sjoerdvisscher/5547235 to your computer and use it in GitHub Desktop.
Linear map category and products.
{-# LANGUAGE TypeFamilies, ConstraintKinds, GADTs, TypeOperators, UndecidableInstances #-}
{-# LANGUAGE CPP #-}
import Prelude hiding ((.))
import Data.VectorSpace
import Data.Category
import Data.Category.Limit
-- | Infix synonym for the pair type, used for products of vector spaces.
type l :* r = (l, r)
-- | Requirements on a scalar type @s@: it is an inner-product space whose
-- scalars are itself, it has zero and scaling maps, and it is numeric.
type VS0 s = (InnerSpace s, HasZero s, HasScale s, Scalar s ~ s, Num s)

-- | Requirements on a vector type @v@ whose scalar type is @s@.
type VS1 s v = (InnerSpace v, HasZero v, HasScale v, Scalar v ~ s, VS0 s)

-- | Two vector types sharing the scalar type @s@.
type VS2 s u v = (VS1 s u, VS1 s v)

-- | Three vector types sharing the scalar type @s@.
type VS3 s u v w = (VS1 s u, VS1 s v, VS1 s w)
-- | Vector types that can build the linear map which multiplies by a scalar.
class VectorSpace v => HasScale v where
  -- | @scale k@ is the linear map on @v@ that scales its input by @k@.
  scale :: Scalar v ~ k => k -> LM k v v
-- | The identity linear map on @v@, expressed as scaling by one.
idL :: VS1 s v => LM s v v
idL = scale one
  where
    one = 1
-- | Scaling a pair scales each component independently.
instance VS2 s a b => HasScale (a :* b) where
  scale k = scaleFst *** scaleSnd
    where
      scaleFst = scale k
      scaleSnd = scale k
-- | Vector types that can be the target of a zero linear map.
class HasZero z where
  -- | The linear map from any vector type @a@ into @z@ that is constantly zero.
  zeroL :: (VS1 s a, VS1 s z) => LM s a z
-- | The zero map into a pair pairs up the zero maps into each component.
instance VS2 s a b => HasZero (a :* b) where
  zeroL = (&&&) zeroL zeroL
-- Base scalar types. For a scalar type t, the zero map into t is represented
-- as Dot zeroV, and scaling by k is represented as Dot k.
#define ScalarType(t) \
instance HasZero (t) where { zeroL = Dot zeroV } ; \
instance HasScale (t) where { scale = Dot }
ScalarType(Int)
ScalarType(Integer)
ScalarType(Float)
ScalarType(Double)
infix 7 :&&

-- | First-order representation of linear maps from @a@ to @b@ over the
-- scalar type @s@.
data LM s a b where
  -- | Dot product with a fixed vector; the target is the scalar type itself.
  Dot :: VS1 s b => b -> LM s b s
  -- | Pairing of two maps out of the same source, landing in a product.
  (:&&) :: VS3 s a c d => LM s a c -> LM s a d -> LM s a (c :* d)
-- | Linear maps from @a@ to @b@ form an additive group, with sums computed
-- structurally on the representation.
instance VS2 s a b => AdditiveGroup (LM s a b) where
  -- The zero map is supplied by the target type via HasZero.
  zeroV = zeroL
  -- Negation post-composes with scaling the target by -1.
  negateV f = scale (-1) . f
  Dot b ^+^ Dot c = Dot (b ^+^ c)
  (f :&& g) ^+^ (h :&& k) = (f ^+^ h) &&& (g ^+^ k)
  -- Mixing Dot (scalar target) with :&& (pair target) cannot arise unless
  -- pair types are used as scalars.
  _ ^+^ _ = error "(^+^) for LM s a b: unexpected combination"
-- | Scaling a linear map scales every vector stored in its Dot leaves.
instance VS2 s a b => VectorSpace (LM s a b) where
  type Scalar (LM s a b) = s
  k *^ m = case m of
    Dot v   -> Dot (k *^ v)
    f :&& g -> (k *^ f) &&& (k *^ g)
-- | Linear maps over a fixed scalar type form a category. Identity objects
-- are built with idL. The constructors are matched in src and tgt so that
-- the constraints they carry are in scope for idL.
instance Category (LM s) where
  src Dot{}   = idL
  src (:&&){} = idL
  tgt Dot{}   = idL
  tgt (:&&){} = idL
  -- Composing after a pairing distributes over both components.
  (f :&& g) . h = (f . h) &&& (g . h)
  -- The outer Dot has the scalar type as its domain, so k is a scalar.
  Dot k . Dot v = Dot (k *^ v)
  -- A dot product with a pair splits into the sum of dot products with each
  -- component, composed with the corresponding half of the pairing.
  Dot p . (f :&& g) = (Dot (fst p) . f) ^+^ (Dot (snd p) . g)
-- | Binary products of linear maps are ordinary pairs. In every method all
-- branches return the same expression. The matches expose the constraints
-- carried by the constructors, which compFst, compSnd and :&& need.
instance HasBinaryProducts (LM s) where
  type BinaryProduct (LM s) a b = a :* b

  proj1 x y = case y of
    Dot{}   -> compFst x
    (:&&){} -> compFst x

  proj2 x y = case x of
    Dot{}   -> compSnd y
    (:&&){} -> compSnd y

  l &&& r = case l of
    Dot{} -> case r of
      Dot{}   -> l :&& r
      (:&&){} -> l :&& r
    (:&&){} -> case r of
      Dot{}   -> l :&& r
      (:&&){} -> l :&& r
-- | Precompose a linear map with the first projection. The result takes a
-- pair, feeds its first component to the given map and ignores the second
-- (Dot leaves are padded with zeroV in the second slot).
--
-- Intended law: @apply (compFst f) == apply f . fst@. The apply function is
-- not defined in this file; presumably it interprets a map as a function.
compFst :: VS1 s b => LM s a c -> LM s (a :* b) c
compFst m = case m of
  Dot v   -> Dot (v, zeroV)
  f :&& g -> compFst f &&& compFst g
-- | Precompose a linear map with the second projection. The result takes a
-- pair, feeds its second component to the given map and ignores the first
-- (Dot leaves are padded with zeroV in the first slot).
--
-- Intended law: @apply (compSnd f) == apply f . snd@. The apply function is
-- not defined in this file; presumably it interprets a map as a function.
compSnd :: VS1 s a => LM s b c -> LM s (a :* b) c
compSnd m = case m of
  Dot v   -> Dot (zeroV, v)
  f :&& g -> compSnd f &&& compSnd g
-- | The first projection out of a pair, as a linear map.
fstL :: VS2 s a b => LM s (a :* b) a
fstL = idL `proj1` idL
-- | The second projection out of a pair, as a linear map.
sndL :: VS2 s a b => LM s (a :* b) b
sndL = idL `proj2` idL
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment