Created
April 26, 2017 00:04
-
-
Save Lysxia/bf4fa5adf879de309bca780a5abceff8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE GeneralizedNewtypeDeriving #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE DeriveFoldable #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
module Tensor where | |
import Data.Distributive | |
import Data.Functor.Rep | |
import Data.List | |
import Data.Maybe | |
import Data.Singletons.Prelude | |
import Data.Promotion.Prelude | |
import GHC.Exts | |
import Data.Kind | |
import GHC.Show | |
import GHC.TypeLits | |
import qualified Data.Vector as V | |
-- $setup | |
-- >>> :set -XDataKinds | |
-- >>> :set -XOverloadedLists | |
-- >>> :set -XTypeFamilies | |
-- >>> let a = [1..24] :: Tensor '[2,3,4] Int | |
-- >>> let v = [1,2,3] :: Tensor '[3] Int | |
-- | An n-dimensional array whose shape is specified at the type level. The
-- main purpose of this, beyond safe typing, is to supply the 'Representable'
-- instance with an initial object. A single boxed 'Data.Vector.Vector' is
-- used underneath for efficient slicing, but this may change or become
-- polymorphic in the future.
--
-- Elements are stored flat in row-major order (the last index varies
-- fastest).
--
-- >>> a
-- [[[1, 2, 3, 4],
--   [5, 6, 7, 8],
--   [9, 10, 11, 12]],
--  [[13, 14, 15, 16],
--   [17, 18, 19, 20],
--   [21, 22, 23, 24]]]
newtype Tensor (r::[Nat]) a = Tensor { flattenTensor :: V.Vector a }
  deriving (Eq, Foldable, Functor)
-- | A tensor whose shape is only known at runtime: a list of dimension sizes
-- paired with the flat backing vector.
data SomeTensor a = SomeTensor [Int] (V.Vector a)
  deriving (Eq, Foldable, Functor)
-- | Forget the type-level shape of a 'Tensor', recording it at the value
-- level in a 'SomeTensor' instead.
someTensor :: (SingI r) => Tensor (r::[Nat]) a -> SomeTensor a
someTensor t@(Tensor v) = SomeTensor (shape t) v
-- | Read a tensor's shape from its type. The tensor value itself is never
-- inspected.
--
-- >>> shape a
-- [2,3,4]
shape :: forall (r::[Nat]) a. (SingI r) => Tensor r a -> [Int]
shape _ = map fromIntegral (fromSing (sing :: Sing r))
-- | Convert a multi-index into a flat, row-major offset.
--
-- Each component is multiplied by the product of all dimensions after it.
-- No per-axis bounds check is done. Components beyond the rank are ignored,
-- and missing trailing components count as 0.
--
-- >>> ind [2,3,4] [1,1,1]
-- 17
ind :: [Int] -> [Int] -> Int
ind ns xs = sum [x * stride | (x, stride) <- zip xs strides]
  where
    -- strides for [2,3,4] are [12,4,1]
    strides = drop 1 (scanr (*) 1 ns)
-- | Decompose @x@ into mixed-radix digits, using @ns@ as the radices (most
-- significant first).
--
-- Returns the digits together with the quotient left over after the most
-- significant radix. For non-negative @x@ and positive radices, that
-- leftover is @x `div` product ns@.
unfoldI :: forall t. Integral t => [t] -> t -> ([t], t)
unfoldI ns x = foldr step ([], x) ns
  where
    -- Peel off the least significant remaining digit for this radix.
    step radix (digits, rest) =
      let (q, d) = rest `divMod` radix
      in (d : digits, q)
-- | Convert a flat row-major offset back into a multi-index for shape @ns@.
-- This inverts 'ind' for in-range offsets. Any overflow past the most
-- significant dimension is silently discarded.
--
-- >>> unind [2,3,4] 17
-- [1,1,1]
unind :: [Int] -> Int -> [Int]
unind ns = fst . unfoldI ns
-- | Turn a functor of tensors into a tensor of functors. Position @i@ of the
-- result collects flat element @i@ from every inner tensor.
instance forall r. (SingI r) => Distributive (Tensor (r::[Nat])) where
  distribute outer = Tensor (V.generate size pickAt)
    where
      pickAt i = fmap (\(Tensor v) -> V.unsafeIndex v i) outer
      -- Total element count; a rank-0 shape gives product [] == 1.
      size = product (fromIntegral <$> fromSing (sing :: Sing r))
-- | Positions in a @Tensor r@ are full multi-indices, one 'Int' per
-- dimension. They map onto the backing vector via 'ind' and 'unind'.
instance forall (r :: [Nat]). (SingI r) => Representable (Tensor r) where
  type Rep (Tensor r) = [Int]

  tabulate f = Tensor (V.generate (product dims) (f . unind dims))
    where
      dims = fromIntegral <$> fromSing (sing :: Sing r)

  index (Tensor xs) rs = xs V.! ind dims rs
    where
      dims = fromIntegral <$> fromSing (sing :: Sing r)
-- | Build a tensor from a flat, row-major list. The list is truncated to the
-- tensor's element count, or padded with zeros if it is too short (hence the
-- @Num a@ constraint).
instance (SingI r, Num a) => IsList (Tensor (r::[Nat]) a) where
  type Item (Tensor r a) = a
  fromList elems = Tensor (V.fromListN size (elems ++ repeat 0))
    where
      -- Element count; a rank-0 shape gives product [] == 1.
      size = product (fromIntegral <$> fromSing (sing :: Sing r))
  toList = V.toList . flattenTensor
-- | Render a tensor as nested bracketed lists.
--
-- Rank 0 shows the single element. Rank 1 shows a comma-separated row.
-- Higher ranks put each sub-tensor after the first on its own line. The
-- indent grows by one space per nesting level, so sibling sub-tensors line
-- up.
instance (Show a) => Show (SomeTensor a) where
  show whole@(SomeTensor topDims _) = render whole
    where
      rank = length topDims
      render sub@(SomeTensor dims vals) =
        case length dims of
          0 -> show (V.head vals)
          1 -> "[" ++ intercalate ", " (map show (V.toList vals)) ++ "]"
          d ->
            let sep = ",\n" ++ replicate (rank - d + 1) ' '
            in "[" ++ intercalate sep (map render (flatten1 sub)) ++ "]"
-- | Split a 'SomeTensor' along its outermost dimension. A tensor of shape
-- @(d : ds)@ becomes @d@ tensors of shape @ds@, each a contiguous slice of
-- the backing vector. A rank-0 tensor yields the empty list.
--
-- Slicing uses 'V.unsafeSlice', so the vector is assumed to hold at least
-- @product (d : ds)@ elements.
flatten1 :: SomeTensor a -> [SomeTensor a]
flatten1 (SomeTensor dims v) =
  case dims of
    [] -> []
    outer : inner ->
      let stride = product inner
      in [ SomeTensor inner (V.unsafeSlice (k * stride) stride v)
         | k <- [0 .. outer - 1]
         ]
-- | Show a typed tensor using the nested rendering of its 'SomeTensor' form.
instance (Show a, SingI r) => Show (Tensor (r::[Nat]) a) where
  show t = show (someTensor t)
-- ** Operations

-- | Inner (dot) product: multiply elementwise, then sum.
--
-- >>> v <.> v
-- 14
(<.>) :: (Num a, Foldable m, Representable m) => m a -> m a -> a
a <.> b = sum (liftR2 (*) a b)
-- | Outer product. The result's shape is the concatenation of both input
-- shapes. Each result element is the product of the element of @m@ indexed
-- by the leading components and the element of @n@ indexed by the rest.
--
-- >>> v >< v
-- [[1, 2, 3],
--  [2, 4, 6],
--  [3, 6, 9]]
(><)
  :: forall (r::[Nat]) (s::[Nat]) a
   . (Num a, SingI r, SingI s, SingI (r :++ s))
  => Tensor r a -> Tensor s a -> Tensor (r :++ s) a
m >< n = tabulate cell
  where
    rankM = length (shape m)
    cell idx =
      let (left, right) = splitAt rankM idx
      in index m left * index n right
-- | Matrix multiplication of an @m×k@ matrix by a @k×n@ matrix. Entry
-- @(i, j)@ of the result is the dot product of row @i@ of @x@ with column
-- @j@ of @y@.
--
-- >>> let a = [1, 2, 3, 4] :: Tensor '[2, 2] Int
-- >>> let b = [5, 6, 7, 8] :: Tensor '[2, 2] Int
-- >>> a
-- [[1, 2],
--  [3, 4]]
-- >>> b
-- [[5, 6],
--  [7, 8]]
-- >>> mmult a b
-- [[19, 22],
--  [43, 50]]
mmult :: forall m n k a. (Num a, KnownNat m, KnownNat n, KnownNat k) =>
  Tensor '[m,k] a ->
  Tensor '[k,n] a ->
  Tensor '[m,n] a
mmult x y = tabulate entry
  where
    -- tabulate on a rank-2 tensor always supplies a two-element index.
    entry [i, j] = unsafeRow i x <.> unsafeCol j y
-- | Extract row @i@ of an @m×n@ matrix. The row index is given at the type
-- level, and the constraint @(i :< m) ~ 'True@ rejects out-of-range rows at
-- compile time.
row
  :: forall i a m n
   . (KnownNat m, KnownNat n, KnownNat i, (i :< m) ~ 'True)
  => Proxy i
  -> Tensor '[m,n] a
  -> Tensor '[n] a
row rowProxy = unsafeRow (fromIntegral (fromSing (singByProxy rowProxy)))
-- | Extract row @r@ of an @m×n@ matrix without checking that @r < m@. The
-- row is a contiguous slice of the row-major backing vector, taken with
-- 'V.unsafeSlice'.
unsafeRow
  :: forall i a m n
   . (KnownNat m, KnownNat n)
  => Int -> Tensor '[m, n] a -> Tensor '[n] a
unsafeRow r t = Tensor (V.unsafeSlice (r * cols) cols (flattenTensor t))
  where
    [_, cols] = shape t
-- | Extract column @j@ of an @m×n@ matrix. The column index is given at the
-- type level, and the constraint @(j :< n) ~ 'True@ rejects out-of-range
-- columns at compile time.
col
  :: forall j a m n
   . (KnownNat m, KnownNat n, KnownNat j, (j :< n) ~ 'True)
  => Proxy j
  -> Tensor '[m,n] a
  -> Tensor '[m] a
col colProxy = unsafeCol (fromIntegral (fromSing (singByProxy colProxy)))
-- | Extract column @j@ of an @m×n@ matrix as a length-@m@ vector, without
-- checking that @j < n@.
--
-- Storage is row-major, so element @i@ of column @j@ lives at flat offset
-- @i * cols + j@, and a column has @rows@ entries. An earlier version used
-- @j + i * rows@ and @cols@ entries, which is only correct for square
-- matrices. Lookups go through the bounds-checked 'V.!'. An out-of-range
-- @j@ may still land inside the vector and read an element of the next row.
unsafeCol :: forall a m n. (KnownNat m, KnownNat n) =>
  Int ->
  Tensor '[m,n] a ->
  Tensor '[m] a
unsafeCol j t@(Tensor a) = Tensor (V.generate rows (\i -> a V.! (i * cols + j)))
  where
    [rows, cols] = shape t
-- | The length of each index list, i.e. the shape of a slice taken with these
-- per-axis index lists. This is the value-level counterpart of the 'Slice'
-- type family.
vslice :: [[Nat]] -> [Nat]
vslice = map (fromIntegral . length)
-- | Look up an element by multi-index. Only the resulting flat offset is
-- bounds-checked (via 'V.!'); individual components are not checked against
-- their dimensions.
--
-- >>> unsafeIndex a [0,2,1]
-- 10
unsafeIndex :: SingI r => Tensor r a -> [Int] -> a
unsafeIndex t i = flattenTensor t V.! ind (shape t) i
-- | Select the sub-tensor made of every combination of the given per-axis
-- indices, in row-major order. The result shape @r0@ is not checked against
-- the index lists; the caller must choose it correctly.
--
-- >>> unsafeSlice [[0,1],[2],[1,2]] a :: Tensor '[2,1,2] Int
-- [[[10, 11]],
--  [[22, 23]]]
unsafeSlice :: SingI r => [[Int]] -> Tensor r a -> Tensor r0 a
unsafeSlice s t = Tensor (V.fromList (map (unsafeIndex t) (sequence s)))
-- | The shape produced by slicing with the given per-axis index lists, i.e.
-- the length of each list: @Slice xss = Map Length xss@.
type family Slice (xss :: [[Nat]]) :: [Nat] where
  Slice idxLists = Map LengthSym0 idxLists
-- | Defunctionalisation symbols for the type-level predicate
-- @AllLT xs n = All (n >) xs@, i.e. every element of @xs@ is below @n@.
-- They let the predicate be passed to higher-order type families such as
-- 'ZipWith'. 'AllLTSym0' takes the list; 'AllLTSym1' holds the list and
-- takes the bound.
data AllLTSym0 (f :: TyFun [Nat] (TyFun Nat Bool -> Type))
data AllLTSym1 (xs :: [Nat]) (f :: TyFun Nat Bool)
type instance Apply AllLTSym0 xs = AllLTSym1 xs
type instance Apply (AllLTSym1 xs) n = All ((:>$$) n) xs
-- | Type-safe slicing. The per-axis index lists are given at the type level.
-- The constraint @And (ZipWith AllLTSym0 s r) ~ 'True@ rejects any index
-- that is not below its dimension, and the result shape is @Slice s@.
--
-- >>> slice (Proxy :: Proxy '[ '[0,1],'[2],'[1,2]]) a
-- [[[10, 11]],
--  [[22, 23]]]
slice
  :: forall s r a
   . (SingI s, SingI r, And (ZipWith AllLTSym0 s r) ~ 'True)
  => Proxy s -> Tensor r a -> Tensor (Slice s) a
slice sp = unsafeSlice (map (map fromInteger) (fromSing (singByProxy sp)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment