{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE GeneralizedNewtypeDeriving #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE DeriveFoldable #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
module Tensor where | |
import Data.Distributive | |
import Data.Functor.Rep | |
import Data.List | |
import Data.Maybe | |
import Data.Singletons.Prelude | |
import Data.Promotion.Prelude | |
import GHC.Exts | |
import Data.Kind | |
import GHC.Show | |
import GHC.TypeLits | |
import qualified Data.Vector as V | |
-- $setup | |
-- >>> :set -XDataKinds | |
-- >>> :set -XOverloadedLists | |
-- >>> :set -XTypeFamilies | |
-- >>> let a = [1..24] :: Tensor '[2,3,4] Int | |
-- >>> let v = [1,2,3] :: Tensor '[3] Int | |
-- | An n-dimensional array whose shape is fixed at the type level.
--
-- Beyond type safety, the main purpose of this type is to supply the
-- 'Representable' instance with an initial object. Elements live in a single
-- boxed 'Data.Vector.Vector' in row-major order, which allows efficient
-- slicing; this representation may change or become polymorphic in the
-- future.
--
-- >>> a
-- [[[1, 2, 3, 4],
--   [5, 6, 7, 8],
--   [9, 10, 11, 12]],
--  [[13, 14, 15, 16],
--   [17, 18, 19, 20],
--   [21, 22, 23, 24]]]
newtype Tensor (r :: [Nat]) a = Tensor
  { flattenTensor :: V.Vector a -- ^ all elements, row-major
  }
  deriving (Eq, Functor, Foldable)
-- | A tensor whose shape is known only at runtime: the list of dimensions
-- paired with the flat element vector. The constructor does not check that
-- the vector length matches the product of the dimensions.
data SomeTensor a = SomeTensor [Int] (V.Vector a)
  deriving (Eq, Foldable, Functor)
-- | Forget the type-level shape of a 'Tensor'. The shape is kept as a
-- value-level list inside the resulting 'SomeTensor'.
someTensor :: (SingI r) => Tensor (r::[Nat]) a -> SomeTensor a
someTensor t@(Tensor elems) = SomeTensor (shape t) elems
-- | The dimensions of a tensor, read from its type-level index and converted
-- to 'Int's. The tensor value itself is never inspected.
--
-- >>> shape a
-- [2,3,4]
shape :: forall (r::[Nat]) a. (SingI r) => Tensor r a -> [Int]
shape _ = map fromIntegral dims
  where
    dims = fromSing (sing :: Sing r)
-- | Convert an n-dimensional index into a flat, row-major offset for the
-- given shape.
--
-- Each coordinate is weighted by the product of all dimensions to its right.
-- No bounds checking is done. Coordinates are paired with dimensions from the
-- left, so any surplus on either side is ignored.
--
-- >>> ind [2,3,4] [1,1,1]
-- 17
ind :: [Int] -> [Int] -> Int
ind dims idx = sum [i * stride | (i, stride) <- zip idx strides]
  where
    -- e.g. the strides for [2,3,4] are [12,4,1]
    strides = drop 1 (scanr (*) 1 dims)
-- | Mixed-radix decomposition of @x@ using the radices @ns@, most significant
-- first.
--
-- Returns the digits (each computed with 'divMod') together with the carry
-- left over after the most significant digit. For @0 <= x < product ns@ the
-- carry is 0.
--
-- >>> unfoldI [2,3,4] 17
-- ([1,1,1],0)
unfoldI :: forall t. Integral t => [t] -> t -> ([t], t)
unfoldI ns x = foldr step ([], x) ns
  where
    -- 'foldr' visits the rightmost (least significant) radix first, so each
    -- step peels off the next digit and passes the quotient on as the carry.
    step radix (digits, rest) =
      let (carry, digit) = rest `divMod` radix
      in (digit : digits, carry)
-- | Convert a flat, row-major offset back into an n-dimensional index for the
-- given shape. This is the inverse of 'ind' for in-range offsets. Any carry
-- beyond the outermost dimension is dropped.
--
-- >>> unind [2,3,4] 17
-- [1,1,1]
unind :: [Int] -> Int -> [Int]
unind dims = fst . unfoldI dims
-- | Turn a functor of equally-shaped tensors into a tensor of functors.
-- Element @i@ of the result maps every tensor inside the functor to its
-- @i@-th flat element.
instance forall r. (SingI r) => Distributive (Tensor (r::[Nat])) where
  distribute tensors = Tensor (V.generate size pick)
    where
      -- Element count is the product of the type-level dimensions; a rank-0
      -- shape gives product [] == 1.
      size = product (fromIntegral <$> fromSing (sing :: Sing r))
      -- Unchecked read: assumes every tensor in the functor really holds
      -- 'size' elements.
      pick i = fmap (\(Tensor v) -> V.unsafeIndex v i) tensors
-- | Tensors are indexed by n-dimensional coordinates given as @[Int]@.
--
-- 'tabulate' evaluates the function at every coordinate in row-major order.
-- 'index' converts the coordinate to a flat offset with 'ind'. Only that flat
-- offset is bounds-checked (by 'V.!'); individual coordinates are not.
instance forall (r :: [Nat]). (SingI r) => Representable (Tensor r) where
  type Rep (Tensor r) = [Int]
  tabulate f = Tensor (V.generate (product dims) (f . unind dims))
    where
      dims = fromIntegral <$> fromSing (sing :: Sing r)
  index t ix = flattenTensor t V.! ind dims ix
    where
      dims = fromIntegral <$> fromSing (sing :: Sing r)
-- | Build a tensor from a flat, row-major list.
--
-- Exactly as many elements as the shape requires are used. Extra list
-- elements are dropped, and a short list is padded with zeros (hence the
-- 'Num' constraint). 'toList' returns the stored elements in order.
instance (SingI r, Num a) => IsList (Tensor (r::[Nat]) a) where
  type Item (Tensor r a) = a
  fromList xs = Tensor (V.fromList (take size (xs ++ repeat 0)))
    where
      -- product [] == 1, so a rank-0 tensor holds one element
      size = product (fromIntegral <$> fromSing (sing :: Sing r))
  toList = V.toList . flattenTensor
-- | Render a tensor as nested, comma-separated bracketed lists.
--
-- Each sub-tensor of rank two or more starts on its own line. Its indent
-- equals the number of brackets enclosing it, so the brackets line up.
instance (Show a) => Show (SomeTensor a) where
  show top@(SomeTensor topDims _) = render top
    where
      depth = length topDims
      render t@(SomeTensor dims elems) = case length dims of
        -- rank 0: show the vector's first element ('V.head' fails if empty)
        0 -> show (V.head elems)
        -- rank 1: a flat list on one line
        1 -> "[" ++ intercalate ", " (map show (V.toList elems)) ++ "]"
        -- higher rank: one sub-tensor per line
        rank ->
          let sep = ",\n" ++ replicate (depth - rank + 1) ' '
          in "[" ++ intercalate sep (map render (flatten1 t)) ++ "]"
-- | Split off the outermost dimension of a 'SomeTensor'. The result has one
-- sub-tensor per index of that dimension; a rank-0 tensor yields none.
--
-- This uses the bounds-checked 'V.slice' rather than 'V.unsafeSlice'.
-- Neither 'SomeTensor' nor 'Tensor' (e.g. one produced by 'unsafeSlice' with a
-- mismatched result shape) guarantees that the vector is as long as the
-- shape implies. An unchecked slice would then read outside the vector, so a
-- malformed tensor now fails with an error instead.
flatten1 :: SomeTensor a -> [SomeTensor a]
flatten1 (SomeTensor dims v) = case dims of
  [] -> []
  outer : inner ->
    let stride = product inner -- elements per sub-tensor
    in [ SomeTensor inner (V.slice (s * stride) stride v)
       | s <- [0 .. outer - 1]
       ]
-- | Show a 'Tensor' through its 'SomeTensor' form (nested bracketed lists).
instance (Show a, SingI r) => Show (Tensor (r::[Nat]) a) where
  show t = show (someTensor t)
-- ** Operations | |
-- | Inner product: multiply two containers element-wise (via 'liftR2') and
-- sum the products.
--
-- >>> v <.> v
-- 14
(<.>) :: (Num a, Foldable m, Representable m) => m a -> m a -> a
xs <.> ys = sum (liftR2 (*) xs ys)
-- | Outer product. The result's shape is the concatenation of both shapes.
-- Each entry is the product of the left entry at the leading coordinates and
-- the right entry at the remaining ones.
--
-- >>> v >< v
-- [[1, 2, 3],
--  [2, 4, 6],
--  [3, 6, 9]]
(><)
  :: forall (r::[Nat]) (s::[Nat]) a
   . (Num a, SingI r, SingI s, SingI (r :++ s))
  => Tensor r a -> Tensor s a -> Tensor (r :++ s) a
left >< right = tabulate cell
  where
    rankL = length (shape left)
    cell ix =
      let (ixL, ixR) = splitAt rankL ix
      in index left ixL * index right ixR
-- | Matrix multiplication. Entry @[i, j]@ of the result is the inner product
-- of row @i@ of the first matrix and column @j@ of the second.
--
-- >>> let a = [1, 2, 3, 4] :: Tensor '[2, 2] Int
-- >>> let b = [5, 6, 7, 8] :: Tensor '[2, 2] Int
-- >>> a
-- [[1, 2],
--  [3, 4]]
-- >>> b
-- [[5, 6],
--  [7, 8]]
-- >>> mmult a b
-- [[19, 22],
--  [43, 50]]
mmult :: forall m n k a. (Num a, KnownNat m, KnownNat n, KnownNat k) =>
  Tensor '[m,k] a ->
  Tensor '[k,n] a ->
  Tensor '[m,n] a
mmult x y = tabulate entry
  where
    -- 'tabulate' for a '[m,n] tensor always supplies a two-element index.
    entry [i, j] = unsafeRow i x <.> unsafeCol j y
-- | Extract row @i@ of an @m × n@ matrix, with @i@ given as a type-level
-- 'Proxy'. The constraint @(i :< m) ~ 'True@ rejects out-of-range rows at
-- compile time.
row
  :: forall i a m n
   . (KnownNat m, KnownNat n, KnownNat i, (i :< m) ~ 'True)
  => Proxy i
  -> Tensor '[m,n] a
  -> Tensor '[n] a
row p = unsafeRow rowIx
  where
    rowIx = fromIntegral (fromSing (singByProxy p))
-- | Extract row @i@ of an @m × n@ matrix. @i@ is not bounds-checked: the row
-- is taken with 'V.unsafeSlice'.
unsafeRow
  :: forall i a m n
   . (KnownNat m, KnownNat n)
  => Int -> Tensor '[m, n] a -> Tensor '[n] a
unsafeRow i t@(Tensor elems) = Tensor (V.unsafeSlice (i * width) width elems)
  where
    [_, width] = shape t
-- | Extract column @j@ of an @m × n@ matrix, with @j@ given as a type-level
-- 'Proxy'. The constraint @(j :< n) ~ 'True@ rejects out-of-range columns at
-- compile time.
col
  :: forall j a m n
   . (KnownNat m, KnownNat n, KnownNat j, (j :< n) ~ 'True)
  => Proxy j
  -> Tensor '[m,n] a
  -> Tensor '[m] a
col p = unsafeCol colIx
  where
    colIx = fromIntegral (fromSing (singByProxy p))
-- | Extract column @j@ of an @m × n@ matrix without checking that @j < n@.
--
-- Storage is row-major, so column @j@ consists of the @m@ elements at flat
-- offsets @j, j + n, j + 2n, …@. The earlier version generated @n@ elements
-- with stride @m@. That is only correct for square matrices, and for
-- non-square ones it produced wrong values and a vector whose length
-- disagreed with the @'[m]@ result type.
unsafeCol :: forall a m n. (KnownNat m, KnownNat n) =>
  Int ->
  Tensor '[m,n] a ->
  Tensor '[m] a
unsafeCol j t@(Tensor a) = Tensor (V.generate rows (\r -> a V.! (j + r * cols)))
  where
    [rows, cols] = shape t
-- | Value-level counterpart of the 'Slice' type family: the length of each
-- index list.
vslice :: [[Nat]] -> [Nat]
vslice = map (fromIntegral . length)
-- | Look up the element at an n-dimensional index. Individual coordinates
-- are not range-checked; only the resulting flat offset is (by 'V.!').
--
-- >>> unsafeIndex a [0,2,1]
-- 10
unsafeIndex :: SingI r => Tensor r a -> [Int] -> a
unsafeIndex t ix = flattenTensor t V.! ind (shape t) ix
-- | Select a sub-tensor from one list of indices per dimension. The elements
-- taken are the Cartesian product of those lists, in row-major order.
--
-- The result shape @r0@ is not checked. The caller must choose an @r0@ that
-- matches the lengths of the index lists.
--
-- >>> unsafeSlice [[0,1],[2],[1,2]] a :: Tensor '[2,1,2] Int
-- [[[10, 11]],
--  [[22, 23]]]
unsafeSlice :: SingI r => [[Int]] -> Tensor r a -> Tensor r0 a
unsafeSlice s t = Tensor $ V.fromList $ map (unsafeIndex t) (sequence s)
-- | Shape produced by slicing with one index list per dimension: each
-- dimension is the length of its index list (@Slice xs = Map Length xs@).
type family Slice (xss :: [[Nat]]) :: [Nat] where
  Slice idxs = Map LengthSym0 idxs
-- | Defunctionalization symbols for the type-level predicate
-- @AllLT xs n = All (n >) xs@, i.e. every element of @xs@ is below @n@.
-- 'AllLTSym0' is the unapplied symbol; @AllLTSym1 xs@ has already received
-- the list. 'slice' passes them to the promoted 'ZipWith'.
data AllLTSym0 (f :: TyFun [Nat] (TyFun Nat Bool -> Type))
data AllLTSym1 (xs :: [Nat]) (f :: TyFun Nat Bool)
type instance Apply AllLTSym0 xs = AllLTSym1 xs
type instance Apply (AllLTSym1 xs) bound = All ((:>$$) bound) xs
-- | Type-checked slicing. The index lists are given at the type level, and
-- the constraint requires every index in the k-th list to be below the k-th
-- dimension of @r@. The result shape is 'Slice' of the index lists.
--
-- NOTE(review): 'ZipWith' presumably truncates to the shorter list, like the
-- value-level 'zipWith'. If so, the number of index lists is not checked
-- against the rank of @r@; confirm before relying on it.
--
-- >>> slice (Proxy :: Proxy '[ '[0,1],'[2],'[1,2]]) a
-- [[[10, 11]],
--  [[22, 23]]]
slice
  :: forall s r a
   . (SingI s, SingI r, And (ZipWith AllLTSym0 s r) ~ 'True)
  => Proxy s -> Tensor r a -> Tensor (Slice s) a
slice p t = unsafeSlice ranges t
  where
    ranges = map (map fromIntegral) (fromSing (singByProxy p))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment