Skip to content

Instantly share code, notes, and snippets.

@harpocrates
Created May 8, 2017 03:20
Show Gist options
  • Save harpocrates/adfccb491f9b1df19426566003c6e1a5 to your computer and use it in GitHub Desktop.
{-# LANGUAGE GADTs, PolyKinds, DataKinds, TypeOperators, FlexibleInstances, FlexibleContexts #-}
import Data.Foldable
import Text.PrettyPrint.HughesPJClass
-- | Peano-style natural numbers: 'Z' is zero and 'S' wraps a predecessor.
--
-- With @DataKinds@ enabled the constructors are also usable at the type
-- level, which is how 'Vector' and 'Tensor' in this file index their sizes.
data Nat
  = Z
  | S Nat
-- Type-level numerals 0 through 9. Each alias is the successor of the one
-- before it; the leading tick marks 'Z' and 'S' as promoted constructors.
type N0 = 'Z
type N1 = 'S N0
type N2 = 'S N1
type N3 = 'S N2
type N4 = 'S N3
type N5 = 'S N4
type N6 = 'S N5
type N7 = 'S N6
type N8 = 'S N7
type N9 = 'S N8
-- | A sequence of @a@ values whose length is tracked by the type index @dim@.
data Vector (dim :: Nat) a where
  -- | The empty vector; its length index is zero.
  Nil :: Vector 'Z a
  -- | Prepend one element, producing a vector one element longer.
  (:-) :: a -> Vector n a -> Vector ('S n) a

-- Right-associative so that @a :- b :- Nil@ parses as @a :- (b :- Nil)@.
infixr 5 :-
-- | A nested array of @a@ values whose shape is the type-level list @dim@,
-- one 'Nat' per axis.
data Tensor (dim :: [Nat]) a where
  -- | A single value; the shape is the empty list.
  Scalar :: a -> Tensor '[] a
  -- | A length-@n@ vector of sub-tensors of shape @d@, giving shape @n ': d@.
  Tensor :: Vector n (Tensor d a) -> Tensor (n ': d) a
-- | Render a vector as its elements, comma-separated, inside curly braces.
--
-- 'pPrintPrec' is defined (instead of only 'pPrint') so that the caller's
-- pretty-printing level is handed on to every element rather than being
-- reset. The braces already delimit the output, so the incoming precedence is
-- ignored and each element is rendered at precedence 0.
instance (Foldable (Vector n), Pretty a) => Pretty (Vector n a) where
  pPrintPrec lvl _ =
    braces . sep . punctuate comma . map (pPrintPrec lvl 0) . toList
-- | A scalar tensor pretty-prints exactly like the value it wraps.
--
-- The level and precedence are forwarded unchanged: defining only 'pPrint'
-- would make the default 'pPrintPrec' discard both, so a wrapped value such
-- as @Just 1@ would lose its parentheses when the tensor is itself printed
-- as a constructor argument.
instance Pretty a => Pretty (Tensor '[] a) where
  pPrintPrec lvl prec (Scalar x) = pPrintPrec lvl prec x
-- | A non-scalar tensor pretty-prints as the vector of its sub-tensors.
--
-- Level and precedence are forwarded to the vector's printer instead of being
-- dropped by the default 'pPrintPrec', so nested tensors see the same
-- settings as the outermost call.
instance (Pretty (Tensor ds a), Pretty a, Foldable (Vector d))
      => Pretty (Tensor (d ': ds) a) where
  pPrintPrec lvl prec (Tensor xs) = pPrintPrec lvl prec xs
-- | A vector is shown as the list of its elements, using the element type's
-- 'showList'. The surrounding precedence is ignored.
instance (Foldable (Vector n), Show a) => Show (Vector n a) where
  showsPrec _ vec = showList (toList vec)
-- | A scalar tensor is shown exactly like the value it wraps.
--
-- 'showsPrec' is defined rather than 'show' so the surrounding precedence
-- reaches the wrapped value: with only 'show' defined, a negative scalar
-- nested in another constructor (for example inside @Just@) would be rendered
-- without its parentheses.
instance Show a => Show (Tensor '[] a) where
  showsPrec prec (Scalar x) = showsPrec prec x
-- | A non-scalar tensor is shown as the vector of its sub-tensors.
--
-- 'showsPrec' is forwarded directly. Defining only 'show' would make the
-- derived 'showsPrec' build a complete 'String' and append to it at every
-- level of nesting instead of composing 'ShowS' functions.
instance (Show (Tensor ds a), Show (Vector d (Tensor ds a)), Show a)
      => Show (Tensor (d ': ds) a) where
  showsPrec prec (Tensor xs) = showsPrec prec xs
-- | Folding the empty vector yields the monoid's identity; the mapping
-- function is never applied.
instance Foldable (Vector 'Z) where
  foldMap _ Nil = mempty
-- | Fold a non-empty vector: map the head, then combine it (on the left) with
-- the fold of the tail.
instance Foldable (Vector n) => Foldable (Vector ('S n)) where
  foldMap g (hd :- tl) = mappend (g hd) (foldMap g tl)
-- | Traversing the empty vector performs no effects and returns 'Nil'.
instance Traversable (Vector 'Z) where
  traverse _ Nil = pure Nil
-- | Traverse a non-empty vector: run the action on the head first, then on
-- the tail, and rebuild the vector from the two results.
instance Traversable (Vector n) => Traversable (Vector ('S n)) where
  traverse act (hd :- tl) = (:-) <$> act hd <*> traverse act tl
-- | Mapping over the empty vector leaves it empty.
instance Functor (Vector 'Z) where
  fmap _ Nil = Nil
-- | Map over a non-empty vector by transforming the head and recursing into
-- the tail.
instance Functor (Vector n) => Functor (Vector ('S n)) where
  fmap g (hd :- tl) = g hd :- fmap g tl
-- | The empty vector has no positions to fill, so 'pure' discards its
-- argument and applying empty to empty is empty.
instance Applicative (Vector 'Z) where
  pure _ = Nil
  Nil <*> Nil = Nil
-- | Position-wise ("zippy") applicative: 'pure' places the value at the head
-- and, recursively, throughout the tail; '<*>' applies each function to the
-- argument at the same position.
instance Applicative (Vector n) => Applicative (Vector ('S n)) where
  pure v = v :- pure v
  (f :- fs) <*> (a :- as) = f a :- (fs <*> as)
-- | Arithmetic on the empty vector: every operation returns 'Nil', and any
-- integer literal is the empty vector.
instance Num a => Num (Vector 'Z a) where
  Nil + Nil     = Nil
  Nil - Nil     = Nil
  Nil * Nil     = Nil
  abs Nil       = Nil
  signum Nil    = Nil
  fromInteger _ = Nil
-- | Element-wise arithmetic on non-empty vectors. Each operation acts on the
-- heads with the element type's 'Num' instance and recurses into the tails.
-- An integer literal fills the head and, via the tail's instance, the rest.
instance (Num a, Num (Vector n a)) => Num (Vector ('S n) a) where
  (a :- as) + (b :- bs) = (a + b) :- (as + bs)
  (a :- as) - (b :- bs) = (a - b) :- (as - bs)
  (a :- as) * (b :- bs) = (a * b) :- (as * bs)
  abs (a :- as)         = abs a :- abs as
  signum (a :- as)      = signum a :- signum as
  fromInteger k         = fromInteger k :- fromInteger k
-- | Folding a scalar tensor maps its single value.
instance Foldable (Tensor '[]) where
  foldMap g (Scalar v) = g v
-- | Fold a non-scalar tensor by folding each sub-tensor and combining the
-- results in the order given by the outer vector's fold.
instance (Foldable (Vector d), Foldable (Tensor ds))
      => Foldable (Tensor (d ': ds)) where
  foldMap g (Tensor rows) = foldMap (foldMap g) rows
-- | Traverse a scalar tensor: run the action on its value and re-wrap the
-- result in 'Scalar'.
instance Traversable (Tensor '[]) where
  traverse act (Scalar v) = fmap Scalar (act v)
-- | Traverse a non-scalar tensor by traversing every sub-tensor through the
-- outer vector's traversal, then re-wrapping the result in 'Tensor'.
instance (Traversable (Vector d), Traversable (Tensor ds))
      => Traversable (Tensor (d ': ds)) where
  traverse act (Tensor rows) = fmap Tensor (traverse (traverse act) rows)
-- | Mapping over a scalar tensor applies the function to its single value.
instance Functor (Tensor '[]) where
  fmap g (Scalar v) = Scalar (g v)
-- | Mapping over a non-scalar tensor maps the function over every sub-tensor
-- held in the outer vector.
instance (Functor (Vector d), Functor (Tensor ds))
      => Functor (Tensor (d ': ds)) where
  fmap g (Tensor rows) = Tensor (fmap (fmap g) rows)
-- | Scalar tensors behave like a plain wrapper: 'pure' wraps a value and
-- '<*>' applies the wrapped function to the wrapped argument.
instance Applicative (Tensor '[]) where
  pure = Scalar
  Scalar f <*> Scalar a = Scalar (f a)
-- | Applicative for non-scalar tensors, built from the outer vector's and the
-- sub-tensors' instances: 'pure' nests one 'pure' inside the other, and '<*>'
-- pairs sub-tensors through the vector's '<*>' and applies each pair with the
-- sub-tensor '<*>'.
instance (Applicative (Vector d), Applicative (Tensor ds))
      => Applicative (Tensor (d ': ds)) where
  pure v = Tensor (pure (pure v))
  Tensor fs <*> Tensor as = Tensor (fmap (<*>) fs <*> as)
-- | Arithmetic on scalar tensors delegates to the wrapped value's 'Num'
-- instance and re-wraps the result.
instance Num a => Num (Tensor '[] a) where
  Scalar a + Scalar b = Scalar (a + b)
  Scalar a - Scalar b = Scalar (a - b)
  Scalar a * Scalar b = Scalar (a * b)
  abs (Scalar a)      = Scalar (abs a)
  signum (Scalar a)   = Scalar (signum a)
  fromInteger k       = Scalar (fromInteger k)
-- | Arithmetic on non-scalar tensors delegates to the 'Num' instance of the
-- underlying vector of sub-tensors and re-wraps the result in 'Tensor'.
instance (Num a, Num (Tensor ds a), Num (Vector d (Tensor ds a)))
      => Num (Tensor (d ': ds) a) where
  Tensor as + Tensor bs = Tensor (as + bs)
  Tensor as - Tensor bs = Tensor (as - bs)
  Tensor as * Tensor bs = Tensor (as * bs)
  abs (Tensor as)       = Tensor (abs as)
  signum (Tensor as)    = Tensor (signum as)
  fromInteger k         = Tensor (fromInteger k)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment