{- Normal form for categories

   GitHub gist sjoerdvisscher/5554910 (last active December 17, 2015).

   Note from the GitHub page: this file may contain bidirectional Unicode
   text; open it in an editor that reveals hidden Unicode characters to
   review it.
-}
{-# LANGUAGE TypeFamilies, ConstraintKinds, GADTs, TypeOperators, RankNTypes, MultiParamTypeClasses, ScopedTypeVariables, UndecidableInstances #-} | |
{-# LANGUAGE CPP #-} | |
import Prelude hiding (id, (.), fst, snd) | |
import Data.VectorSpace | |
import Control.Category | |
import Control.Arrow ((&&&)) | |
{--------------------------------------------------------------------
    Misc
--------------------------------------------------------------------}
-- | Infix spelling of the pair type: @a :* b@ is exactly @(a, b)@.
type (:*) a b = (a, b)
{--------------------------------------------------------------------
    General linear transformations
--------------------------------------------------------------------}
-- | Everything this file needs from a vector space: an inner product,
-- a zero map into it ('HasZero'), scaling maps ('HasScale'), and numeric
-- scalars.
type VS a = (InnerSpace a, HasZero a, HasScale a, Num (Scalar a))

-- | @a :~ b@: the two spaces share the same scalar type.
type (:~) a b = Scalar a ~ Scalar b

-- | @CV b a@: @b@ is a vector space compatible with @a@ (same scalars).
type CV b a = (VS b, b :~ a)

-- | Two, three or four vector spaces, each later one compatible with the
-- first.
type VS2 a b     = (VS a, CV b a)
type VS3 a b c   = (VS a, CV b a, CV c a)
type VS4 a b c d = (VS a, CV b a, CV c a, CV d a)
-- | Spaces that can express "scale by @s@" as a linear transformation
-- (see the pair instance and the 'ScalarType' instances).
class VectorSpace v => HasScale v where
  scale :: Scalar v -> v :-* v

-- | The identity transformation, represented as scaling by one.
idL :: (HasScale v, Num (Scalar v)) => v :-* v
idL = scale 1
-- | Scaling a pair scales each component: the first output row scales the
-- first input component, the second row scales the second.
instance VS2 a b => HasScale (a :* b) where
  scale k = onFst :&& onSnd
    where
      onFst :: a :* b :-* a
      onFst = compFst (scale k)
      onSnd :: a :* b :-* b
      onSnd = compSnd (scale k)
-- | Spaces @z@ that have a zero linear transformation into them from any
-- compatible space.
class HasZero z where
  zeroL :: CV a z => a :-* z

-- | The zero map into a pair is the pair of zero maps into its components.
instance VS2 a b => HasZero (a :* b) where
  zeroL = (:&&) zeroL zeroL
-- Why UndecidableInstances is enabled: without it, GHC rejects the pair
-- instances above (e.g. the instance for `HasScale (a, b)`) with
--   "Variable occurs more often in a constraint than in the instance head
--    in the constraint: VS a"
-- Scalar types are their own one-dimensional spaces: scaling by s is the
-- map Dot s (taking x to s <.> x), and the zero map from a compatible
-- space is Dot zeroV.
#define ScalarType(ty) \
instance HasScale (ty) where { scale = Dot } ; \
instance HasZero (ty) where { zeroL = Dot zeroV }
ScalarType(Int)
ScalarType(Integer)
ScalarType(Float)
ScalarType(Double)
infixr 1 :-*
infix 7 :&&

-- | Linear transformation from @a@ to @b@, in a structural normal form.
--
-- * @'Dot' v@ is the linear functional that takes the inner product with @v@.
--
-- * @f ':&&' g@ runs two transformations with a shared domain side by side,
--   producing a pair.
data a :-* b where
  Dot   :: InnerSpace v
        => v -> v :-* Scalar v
  (:&&) :: VS3 x c d
        => (x :-* c) -> (x :-* d) -> x :-* c :* d
-- | Semantic function: the ordinary function a linear transformation
-- denotes. @'Dot' v@ denotes @'dot' v@; a ':&&' node feeds its input to both
-- halves and pairs the results.
apply :: a :-* b -> a -> b
apply lt = case lt of
  Dot v   -> dot v
  f :&& g -> apply f &&& apply g
-- | Inner product, under a name that matches the 'Dot' constructor.
dot :: InnerSpace b => b -> b -> Scalar b
dot u v = u <.> v
-- | Pointwise addition of linear transformations, done structurally:
-- 'Dot' vectors are added, and paired transformations are added
-- component-wise.
instance VS2 a b => AdditiveGroup (a :-* b) where
  zeroV = zeroL
  -- Negation post-composes scaling by -1.
  negateV m = scale (-1) `compL` m
  (f1 :&& g1) ^+^ (f2 :&& g2) = (f1 ^+^ f2) :&& (g1 ^+^ g2)
  Dot u ^+^ Dot v = Dot (u ^+^ v)
  -- A Dot meeting a pair node cannot arise unless pairs are scalars.
  _ ^+^ _ = error "(^+^) for a :-* b: unexpected combination"
-- | Scalar multiplication, pushed down to the 'Dot' leaves.
instance VS2 a b => VectorSpace (a :-* b) where
  type Scalar (a :-* b) = Scalar b
  k *^ m = case m of
    Dot v   -> Dot (k *^ v)
    f :&& g -> (k *^ f) :&& (k *^ g)

-- Open question from the author: should (a :-* b) get an InnerSpace
-- instance as well?
-- | Precompose with the first projection:
-- @apply (compFst f) == apply f . fst@.
--
-- Uses the identities @dot a . fst = dot (a, 0)@ and
-- @(f &&& g) . fst = f . fst &&& g . fst@.
compFst :: VS3 a b c => a :-* c -> a :* b :-* c
compFst lt = case lt of
  Dot v   -> Dot (v, zeroV)
  f :&& g -> compFst f :&& compFst g
-- | Precompose with the second projection:
-- @apply (compSnd f) == apply f . snd@.
-- A 'Dot' vector is padded with 'zeroV' in the first slot.
compSnd :: VS3 a b c => b :-* c -> a :* b :-* c
compSnd lt = case lt of
  Dot v   -> Dot (zeroV, v)
  f :&& g -> compSnd f :&& compSnd g
-- | First projection as a linear transformation: the identity
-- ('scale' 1) precomposed with @fst@.
fstL :: VS2 a b => a :* b :-* a
fstL = compFst (scale 1)

-- | Second projection as a linear transformation: the identity
-- ('scale' 1) precomposed with @snd@.
sndL :: VS2 a b => a :* b :-* b
sndL = compSnd (scale 1)
-- | Normal form for composition in a category: a chain of @t@-arrows
-- terminated by 'Id'. Each link records the constraint @k@ on its source
-- and target, so folds over the chain can use it.
data NC k t a b where
  -- | The empty chain (identity).
  Id   :: NC k t a a
  -- | Post-compose one more arrow onto a chain.
  Comp :: k b c => t b c -> NC k t a b -> NC k t a c
-- | A chain holding a single arrow.
liftNC :: k a b => t a b -> NC k t a b
liftNC = (`Comp` Id)
-- | Catamorphism for 'NC': replaces 'Id' with @idt@ and every 'Comp' with
-- @comp@. Each @comp@ receives the constraint stored in its link.
foldNC :: forall a b k t r. r a -> (forall c d. k c d => t c d -> r c -> r d) -> NC k t a b -> r b
foldNC idt comp = walk
  where
    -- Needs its own signature: the recursion changes the target index.
    walk :: NC k t a x -> r x
    walk chain = case chain of
      Id             -> idt
      Comp step rest -> comp step (walk rest)
-- | Chains compose by grafting the right-hand chain onto the 'Id' at the
-- end of the left-hand one.
instance Category (NC k t) where
  id = Id
  outer . inner = case outer of
    Id             -> inner
    Comp step rest -> Comp step (rest . inner)
-- | Addition on chains. Negation just post-composes @scale (-1)@; addition
-- first collapses both chains with 'lowerLM' and lifts the sum back.
instance VS2 a b => AdditiveGroup (LM a b) where
  zeroV = liftNC zeroL
  negateV m = scale (-1) `Comp` m
  x ^+^ y = liftNC (lowerLM x ^+^ lowerLM y)
-- | Scalar multiplication on chains post-composes a 'scale' step.
instance VS2 a b => VectorSpace (LM a b) where
  type Scalar (LM a b) = Scalar b
  (*^) k = Comp (scale k)
-- | Class version of the 'VS3' constraint synonym. A type synonym cannot be
-- partially applied, but a class can, so @VS3C a@ can serve as the
-- constraint parameter of 'NC' in 'LM'. Its only instance is the catch-all.
class VS3 a b c => VS3C a b c
instance VS3 a b c => VS3C a b c

-- | Linear maps out of @a@ as chains of ':-*' steps; 'lowerLM' collapses
-- such a chain into a single ':-*'.
type LM a = NC (VS3C a) (:-*) a
-- | Collapse a chain to one linear transformation: the empty chain becomes
-- 'idL', and each step is composed on with 'compL'.
lowerLM :: VS2 a b => LM a b -> a :-* b
lowerLM chain = foldNC idL compL chain
-- | Composition of linear transformations in normal form
-- (@g `compL` f@ applies @f@ first, then @g@).
compL :: VS3 a b c => b :-* c -> a :-* b -> a :-* c
compL lt h = case lt of
  -- A pair on the left distributes over composition; h is not inspected.
  f1 :&& f2 -> (f1 `compL` h) :&& (f2 `compL` h)
  Dot s -> case h of
    Dot b -> Dot (s *^ b) -- s must be scalar
    g1 :&& g2 ->
      -- The lazy let leaves the pair s unforced. GHC 7.4.1 rejected matching
      -- a Dot (x, y) pattern directly in the outer equation.
      let (x, y) = s
      in (Dot x `compL` g1) ^+^ (Dot y `compL` g2)
{- GitHub page footer: Sign up for free to join this conversation on
   GitHub. Already have an account? Sign in to comment. -}