Created
May 8, 2017 03:20
-
-
Save harpocrates/adfccb491f9b1df19426566003c6e1a5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE GADTs, PolyKinds, DataKinds, TypeOperators, FlexibleInstances, FlexibleContexts #-} | |
import Data.Foldable | |
import Text.PrettyPrint.HughesPJClass | |
-- | Peano-style natural numbers.  With @DataKinds@ enabled the constructors
-- are also available at the type level, where they index 'Vector' lengths.
data Nat
  = Z     -- ^ Zero.
  | S Nat -- ^ Successor of another 'Nat'.
-- | Type-level numerals @0@ through @9@.  Each one is the successor of the
-- numeral before it; the ticks mark 'Z' and 'S' as promoted constructors.
type N0 = 'Z
type N1 = 'S N0
type N2 = 'S N1
type N3 = 'S N2
type N4 = 'S N3
type N5 = 'S N4
type N6 = 'S N5
type N7 = 'S N6
type N8 = 'S N7
type N9 = 'S N8
-- | A list whose length is tracked in its type by a promoted 'Nat'.
data Vector (len :: Nat) a where
  -- | The empty vector; its length index is zero.
  Nil :: Vector 'Z a
  -- | Prepend one element, giving a vector whose length index is one larger.
  (:-) :: a -> Vector n a -> Vector ('S n) a

-- Right-associative, so @x :- y :- Nil@ needs no parentheses.
infixr 5 :-
-- | A nesting of 'Vector's whose shape is the type-level list @dims@.
data Tensor (dims :: [Nat]) a where
  -- | A single element; the shape list is empty.
  Scalar :: a -> Tensor '[] a
  -- | A vector of @n@ sub-tensors that all share the shape @d@.
  Tensor :: Vector n (Tensor d a) -> Tensor (n ': d) a
-- | Pretty-print the elements between braces, separated by commas.
instance (Foldable (Vector n), Pretty a) => Pretty (Vector n a) where
  pPrint vec = braces (sep (punctuate (text ",") elems))
    where
      -- One document per element, in fold order.
      elems = map pPrint (toList vec)
-- | A scalar tensor is pretty-printed exactly like the value it wraps.
instance Pretty a => Pretty (Tensor '[] a) where
  pPrint (Scalar val) = pPrint val
-- | A non-scalar tensor is pretty-printed as the 'Vector' of its sub-tensors.
instance
  (Pretty (Tensor ds a), Pretty a, Foldable (Vector d)) =>
  Pretty (Tensor (d ': ds) a)
  where
  pPrint (Tensor rows) = pPrint rows
-- | A vector is shown through the element type's 'showList', applied to the
-- vector's elements.  The surrounding precedence is ignored.
instance (Foldable (Vector n), Show a) => Show (Vector n a) where
  showsPrec _ vec = showList (toList vec)
-- | A scalar tensor is shown exactly like the value it wraps.
instance Show a => Show (Tensor '[] a) where
  -- Implement 'showsPrec' rather than 'show' and forward the precedence, so
  -- the wrapped value is parenthesised whenever its own 'Show' instance asks
  -- for it (for example a negative number appearing under 'Just').
  showsPrec prec (Scalar val) = showsPrec prec val
-- | A non-scalar tensor is shown as the 'Vector' of its sub-tensors.
--
-- The context asks for @Foldable (Vector d)@ (as the 'Pretty' instance for
-- the same type does) instead of @Show (Vector d (Tensor ds a))@: the latter
-- is as large as the instance head and so needs @UndecidableInstances@ under
-- GHC's instance termination rules.  The required 'Show' instance for the
-- inner vector is still found from @Foldable (Vector d)@ and
-- @Show (Tensor ds a)@.
instance
  (Show (Tensor ds a), Foldable (Vector d), Show a) =>
  Show (Tensor (d ': ds) a)
  where
  showsPrec prec (Tensor rows) = showsPrec prec rows
-- | Folding the empty vector yields the monoid identity; the mapping
-- function is never applied.
instance Foldable (Vector 'Z) where
  foldMap _ Nil = mempty
-- | Map the head element and combine it (on the left) with the fold of the
-- remaining elements.
instance Foldable (Vector n) => Foldable (Vector ('S n)) where
  foldMap f (hd :- tl) = mappend (f hd) (foldMap f tl)
-- | Traversing the empty vector performs no effects and returns 'Nil'.
instance Traversable (Vector 'Z) where
  traverse _ Nil = pure Nil
-- | Run the action on the head first, then on the tail, and rebuild the
-- vector from the results.
instance Traversable (Vector n) => Traversable (Vector ('S n)) where
  traverse f (hd :- tl) = fmap (:-) (f hd) <*> traverse f tl
-- | Mapping over the empty vector gives the empty vector.
instance Functor (Vector 'Z) where
  fmap _ Nil = Nil
-- | Apply the function to the head and, recursively, to the tail.
instance Functor (Vector n) => Functor (Vector ('S n)) where
  fmap g (hd :- tl) = g hd :- (g <$> tl)
-- | The only vector of length zero is 'Nil', so both operations return it.
instance Applicative (Vector 'Z) where
  pure = const Nil
  (<*>) Nil Nil = Nil
-- | Position-wise ("zipping") applicative: 'pure' fills every position with
-- the same value and '<*>' applies each function to the argument at the same
-- position.
instance Applicative (Vector n) => Applicative (Vector ('S n)) where
  pure val = val :- pure val
  (fn :- fns) <*> (arg :- args) = fn arg :- (fns <*> args)
-- | Arithmetic on the empty vector: every operation returns 'Nil', and any
-- integer literal denotes 'Nil'.
instance Num a => Num (Vector 'Z a) where
  (+) Nil Nil = Nil
  (-) Nil Nil = Nil
  (*) Nil Nil = Nil
  abs Nil = Nil
  signum Nil = Nil
  fromInteger _ = Nil
-- | Element-wise arithmetic: each operation is applied to the heads and,
-- recursively, to the tails.
instance (Num a, Num (Vector n a)) => Num (Vector ('S n) a) where
  (p :- ps) + (q :- qs) = (p + q) :- (ps + qs)
  (p :- ps) - (q :- qs) = (p - q) :- (ps - qs)
  (p :- ps) * (q :- qs) = (p * q) :- (ps * qs)
  abs (p :- ps) = abs p :- abs ps
  signum (p :- ps) = signum p :- signum ps
  -- The same integer is converted for the head and, recursively, the tail.
  fromInteger k = fromInteger k :- fromInteger k
-- | Folding a scalar tensor applies the function to its single element.
instance Foldable (Tensor '[]) where
  foldMap g (Scalar val) = g val
-- | Fold the outer vector, folding each sub-tensor in turn.
instance
  (Foldable (Vector d), Foldable (Tensor ds)) =>
  Foldable (Tensor (d ': ds))
  where
  -- The outer 'foldMap' walks the vector; the inner one walks each sub-tensor.
  foldMap g (Tensor rows) = (foldMap . foldMap) g rows
-- | Run the action on the single element and re-wrap the result.
instance Traversable (Tensor '[]) where
  traverse g (Scalar val) = fmap Scalar (g val)
-- | Traverse the outer vector, traversing each sub-tensor in turn, then
-- re-wrap the resulting vector.
instance
  (Traversable (Vector d), Traversable (Tensor ds)) =>
  Traversable (Tensor (d ': ds))
  where
  traverse g (Tensor rows) = fmap Tensor (traverse (traverse g) rows)
-- | Apply the function to the single element.
instance Functor (Tensor '[]) where
  fmap g (Scalar val) = Scalar (g val)
-- | Map over the outer vector, mapping the function over each sub-tensor.
instance
  (Functor (Vector d), Functor (Tensor ds)) =>
  Functor (Tensor (d ': ds))
  where
  fmap g (Tensor rows) = Tensor (fmap g <$> rows)
-- | 'pure' wraps a value in 'Scalar'; '<*>' applies the wrapped function to
-- the wrapped argument.
instance Applicative (Tensor '[]) where
  pure = Scalar
  Scalar fn <*> Scalar arg = Scalar (fn arg)
-- | Applicative built from the 'Vector' and sub-tensor applicatives.
instance
  (Applicative (Vector d), Applicative (Tensor ds)) =>
  Applicative (Tensor (d ': ds))
  where
  -- Inner 'pure' builds a sub-tensor, outer 'pure' builds the vector of them.
  pure = Tensor . pure . pure
  -- Pair the sub-tensors up through the vector's '<*>', then combine each
  -- pair with the sub-tensor '<*>'.
  Tensor fns <*> Tensor args = Tensor (fmap (<*>) fns <*> args)
-- | Arithmetic on scalar tensors: unwrap, use the element's 'Num' instance,
-- and re-wrap.
instance Num a => Num (Tensor '[] a) where
  (+) (Scalar p) (Scalar q) = Scalar (p + q)
  (-) (Scalar p) (Scalar q) = Scalar (p - q)
  (*) (Scalar p) (Scalar q) = Scalar (p * q)
  abs (Scalar p) = Scalar (abs p)
  signum (Scalar p) = Scalar (signum p)
  fromInteger k = Scalar (fromInteger k)
-- | Element-wise arithmetic on non-scalar tensors.
--
-- The operations are written through the 'Applicative' and 'Functor'
-- instances of 'Tensor' rather than by delegating to
-- @Num (Vector d (Tensor ds a))@.  That constraint is as large as the
-- instance head and so needs @UndecidableInstances@ under GHC's instance
-- termination rules; every constraint in this context is strictly smaller.
-- The 'Applicative' instances pair elements up by position, so each
-- operation still combines the elements at matching positions.
instance
  (Num a, Applicative (Vector d), Applicative (Tensor ds)) =>
  Num (Tensor (d ': ds) a)
  where
  lhs + rhs = (+) <$> lhs <*> rhs
  lhs - rhs = (-) <$> lhs <*> rhs
  lhs * rhs = (*) <$> lhs <*> rhs
  abs = fmap abs
  signum = fmap signum
  -- An integer literal denotes the tensor holding that value at every position.
  fromInteger = pure . fromInteger
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment