Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE UndecidableInstances #-}
module Tensor where
import Data.Distributive
import Data.Functor.Rep
import Data.List
import Data.Maybe
import Data.Singletons.Prelude
import Data.Promotion.Prelude
import GHC.Exts
import Data.Kind
import GHC.Show
import GHC.TypeLits
import qualified Data.Vector as V
-- $setup
-- >>> :set -XDataKinds
-- >>> :set -XOverloadedLists
-- >>> :set -XTypeFamilies
-- >>> let a = [1..24] :: Tensor '[2,3,4] Int
-- >>> let v = [1,2,3] :: Tensor '[3] Int
-- | An n-dimensional array whose shape is fixed at the type level.
--
-- Beyond safe typing, the main purpose of the type-level shape is to give the
-- 'Representable' instance an initial object. The elements live in a single
-- boxed 'Data.Vector.Vector' in row-major order (the last index varies
-- fastest), which keeps slicing cheap; the storage may change or become
-- polymorphic in the future.
--
-- >>> a
-- [[[1, 2, 3, 4],
--   [5, 6, 7, 8],
--   [9, 10, 11, 12]],
--  [[13, 14, 15, 16],
--   [17, 18, 19, 20],
--   [21, 22, 23, 24]]]
newtype Tensor (r :: [Nat]) a = Tensor
  { flattenTensor :: V.Vector a
    -- ^ the flat element storage
  }
  deriving (Functor, Eq, Foldable)
-- | A tensor whose shape is only known at runtime: a list of dimensions paired
-- with the flat element storage.
data SomeTensor a
  = SomeTensor [Int] (V.Vector a)
  deriving (Functor, Eq, Foldable)
-- | Forget the type-level shape of a 'Tensor'. The shape is recorded as a
-- runtime dimension list in a 'SomeTensor', and the element storage is
-- reused unchanged.
someTensor :: (SingI r) => Tensor (r::[Nat]) a -> SomeTensor a
someTensor t@(Tensor elems) = SomeTensor dims elems
  where
    dims = shape t
-- | Read the dimensions of a 'Tensor' off its type. The argument itself is
-- never inspected.
--
-- >>> shape a
-- [2,3,4]
shape :: forall (r::[Nat]) a. (SingI r) => Tensor r a -> [Int]
shape _ = map fromIntegral dims
  where
    dims = fromSing (sing :: Sing r)
-- | Convert a multi-dimensional index into an offset in the flat, row-major
-- storage. Each coordinate is weighted by the product of the dimensions to
-- its right.
--
-- No bounds checking is done. If the two lists differ in length, the extra
-- elements of the longer one are ignored.
--
-- >>> ind [2,3,4] [1,1,1]
-- 17
ind :: [Int] -> [Int] -> Int
ind ns xs = sum [x * stride | (x, stride) <- zip xs strides]
  where
    -- e.g. [2,3,4]: scanr gives [24,12,4,1]; dropping the total leaves [12,4,1]
    strides = drop 1 (scanr (*) 1 ns)
-- | Decompose a flat offset into per-dimension coordinates (mixed-radix
-- digits), working through the dimensions from right to left.
--
-- Returns the coordinates together with the quotient left over after the
-- leftmost dimension; that quotient is 0 when @0 <= x < product ns@. A zero
-- dimension makes 'divMod' throw.
--
-- >>> unfoldI [2,3,4] 17
-- ([1,1,1],0)
unfoldI :: forall t. Integral t => [t] -> t -> ([t], t)
unfoldI ns x = foldr step ([], x) ns
  where
    -- Peel off one digit: the remainder is this dimension's coordinate, and
    -- the quotient is carried on to the next dimension to the left.
    step dim (digits, rest) =
      let (carry, digit) = rest `divMod` dim
      in (digit : digits, carry)
-- | Convert a flat storage offset back into a multi-dimensional index. For
-- offsets within the shape, this is the inverse of 'ind'.
--
-- >>> unind [2,3,4] 17
-- [1,1,1]
unind :: [Int] -> Int -> [Int]
unind ns = fst . unfoldI ns
-- | Distribute a functor over a 'Tensor'. For every flat position, build the
-- collection of elements found at that position in each input tensor.
--
-- Every tensor inside @f@ is assumed to hold exactly as many elements as the
-- shape @r@ describes (the product of its dimensions), because
-- 'V.unsafeIndex' does no bounds checking.
instance forall r. (SingI r) => Distributive (Tensor (r::[Nat])) where
  distribute f = Tensor $ V.generate n $ \i ->
    fmap (\(Tensor v) -> V.unsafeIndex v i) f
    where
      -- Number of elements: the product of the dimensions. The empty shape
      -- '[] (a scalar) gives product [] = 1. 'fromIntegral' works whether the
      -- singletons version demotes Nat to Integer or to Natural.
      n = product (fromIntegral <$> fromSing (sing :: Sing r))
-- | Positions in a 'Tensor' are multi-dimensional indices ('Rep' is @[Int]@).
-- They are mapped to and from flat storage with 'ind' and 'unind'.
instance forall (r :: [Nat]). (SingI r) => Representable (Tensor r) where
  type Rep (Tensor r) = [Int]

  -- Evaluate f at every index, in flat-storage order.
  tabulate f = Tensor (V.generate total (\k -> f (unind dims k)))
    where
      dims = map fromIntegral (fromSing (sing :: Sing r))
      total = product dims

  -- Look up an index. 'V.!' throws if the flat offset is out of range.
  index (Tensor store) ix = store V.! ind dims ix
    where
      dims = map fromIntegral (fromSing (sing :: Sing r))
-- | Build a 'Tensor' from a flat, row-major list of elements.
--
-- The list is cut to the number of elements the shape requires. If it is too
-- short, the remaining positions are filled with 0 (hence the 'Num'
-- constraint).
instance (SingI r, Num a) => IsList (Tensor (r::[Nat]) a) where
  type Item (Tensor r a) = a
  fromList xs = Tensor (V.fromListN size (xs ++ repeat 0))
    where
      -- product of the dimensions; the empty shape '[] (a scalar) gives 1
      size = product (map fromIntegral (fromSing (sing :: Sing r)))
  toList = V.toList . flattenTensor
-- | Render a 'SomeTensor' as nested bracketed lists, one innermost row per
-- line. Continuation lines are indented so that the brackets line up. A
-- rank-0 tensor is shown as its first element.
instance (Show a) => Show (SomeTensor a) where
  show t@(SomeTensor dims _) = render t
    where
      rank = length dims
      render sub@(SomeTensor subDims elems)
        | subRank == 0 = show (V.head elems)
        | subRank == 1 =
            "[" ++ intercalate ", " (map show (V.toList elems)) ++ "]"
        | otherwise =
            "[" ++ intercalate sep (map render (flatten1 sub)) ++ "]"
        where
          subRank = length subDims
          -- indent by how many enclosing brackets are already open, plus one
          sep = ",\n" ++ replicate (rank - subRank + 1) ' '
-- | Split a 'SomeTensor' along its first dimension into that many
-- sub-tensors. Each sub-tensor is a slice of the shared storage with the
-- leading dimension dropped. A rank-0 tensor has no leading dimension and
-- yields no sub-tensors.
flatten1 :: SomeTensor a -> [SomeTensor a]
flatten1 (SomeTensor rep v) =
  [ SomeTensor inner (V.unsafeSlice (k * chunk) chunk v) | k <- [0 .. count - 1] ]
  where
    inner = drop 1 rep
    -- number of sub-tensors and the number of elements in each one
    (count, chunk) = case rep of
      [] -> (0, 1)
      d : ds -> (d, product ds)
-- | Show a 'Tensor' by reading its shape off the type and delegating to the
-- 'SomeTensor' rendering.
instance (Show a, SingI r) => Show (Tensor (r::[Nat]) a) where
  show t = show (someTensor t)
-- ** Operations
-- | Inner (dot) product: multiply the elements at matching positions and sum
-- the results. Works for any 'Foldable' 'Representable' container, not just
-- 'Tensor'.
--
-- >>> v <.> v
-- 14
(<.>) :: (Num a, Foldable m, Representable m) => m a -> m a -> a
a <.> b = sum (liftR2 (*) a b)
-- | Outer product. The result's shape is the two shapes concatenated, and
-- each element is the product of the corresponding elements of @m@ and @n@.
--
-- >>> v >< v
-- [[1, 2, 3],
--  [2, 4, 6],
--  [3, 6, 9]]
(><)
  :: forall (r::[Nat]) (s::[Nat]) a
   . (Num a, SingI r, SingI s, SingI (r :++ s))
  => Tensor r a -> Tensor s a -> Tensor (r :++ s) a
m >< n = tabulate $ \ix ->
  let (left, right) = splitAt rankM ix
  in index m left * index n right
  where
    -- the leading rankM index components address m; the rest address n
    rankM = length (shape m)
-- | Matrix product of an @m×k@ matrix and a @k×n@ matrix. Element @[i, j]@
-- of the result is @unsafeRow i x <.> unsafeCol j y@, i.e. row @i@ of @x@
-- dotted with column @j@ of @y@.
--
-- >>> let a = [1, 2, 3, 4] :: Tensor '[2, 2] Int
-- >>> let b = [5, 6, 7, 8] :: Tensor '[2, 2] Int
-- >>> a
-- [[1, 2],
--  [3, 4]]
-- >>> b
-- [[5, 6],
--  [7, 8]]
-- >>> mmult a b
-- [[19, 22],
--  [43, 50]]
mmult :: forall m n k a. (Num a, KnownNat m, KnownNat n, KnownNat k) =>
  Tensor '[m,k] a ->
  Tensor '[k,n] a ->
  Tensor '[m,n] a
mmult x y = tabulate entry
  where
    -- 'tabulate' for a '[m,n] tensor only ever supplies rank-2 indices; the
    -- second equation makes the match total instead of a partial lambda.
    entry [i, j] = unsafeRow i x <.> unsafeCol j y
    entry ix = error ("Tensor.mmult: expected a rank-2 index, got " ++ show ix)
-- | Extract row @i@ of a matrix. The index @i@ is given at the type level and
-- checked against the row count @m@ at compile time.
row
  :: forall i a m n
   . (KnownNat m, KnownNat n, KnownNat i, (i :< m) ~ 'True)
  => Proxy i
  -> Tensor '[m,n] a
  -> Tensor '[n] a
row p = unsafeRow rowIx
  where
    -- demote the type-level index to a runtime Int
    rowIx = fromIntegral (fromSing (singByProxy p))
-- | Extract row @i@ of an @m×n@ matrix as an O(1) slice of the underlying
-- storage. The row index is not checked, and 'V.unsafeSlice' does no bounds
-- checking, so the caller must ensure @0 <= i < m@.
--
-- NOTE(review): the type variable @i@ in the signature is unused. It is kept
-- so that the signature stays unchanged for existing callers.
unsafeRow
  :: forall i a m n
   . (KnownNat m, KnownNat n)
  => Int -> Tensor '[m, n] a -> Tensor '[n] a
unsafeRow i (Tensor a) = Tensor $ V.unsafeSlice (i * cols) cols a
  where
    -- row length (number of columns), read straight from the type rather
    -- than through the partial pattern [m, n] = shape t
    cols = fromIntegral (fromSing (sing :: Sing n))
-- | Extract column @j@ of a matrix. The index @j@ is given at the type level
-- and checked against the column count @n@ at compile time.
col
  :: forall j a m n
   . (KnownNat m, KnownNat n, KnownNat j, (j :< n) ~ 'True)
  => Proxy j
  -> Tensor '[m,n] a
  -> Tensor '[m] a
col p = unsafeCol colIx
  where
    -- demote the type-level index to a runtime Int
    colIx = fromIntegral (fromSing (singByProxy p))
-- | Extract column @j@ of an @m×n@ matrix. The storage is row-major, so the
-- column's @m@ elements sit at flat offsets @j, j + n, j + 2n, ...@.
--
-- @j@ is not validated and must satisfy @0 <= j < n@. An out-of-range @j@ may
-- read elements from a neighbouring column, or make 'V.!' throw once an
-- offset runs past the storage.
unsafeCol :: forall a m n. (KnownNat m, KnownNat n) =>
  Int ->
  Tensor '[m,n] a ->
  Tensor '[m] a
unsafeCol j (Tensor a) = Tensor $ V.generate rows (\r -> a V.! (j + r * cols))
  where
    -- One element per row; each step down a row skips a full row of `cols`
    -- elements. (The previous version produced `cols` elements at stride
    -- `rows`, which is only correct for square matrices.)
    rows = fromIntegral (fromSing (sing :: Sing m))
    cols = fromIntegral (fromSing (sing :: Sing n))
-- | Term-level counterpart of the 'Slice' type family: for each dimension,
-- the number of indices selected along it.
vslice :: [[Nat]] -> [Nat]
vslice = map (fromIntegral . length)
-- | Read the element at a multi-dimensional index. The index is not checked
-- against the shape (see 'ind'), but 'V.!' still throws if the resulting flat
-- offset falls outside the storage.
--
-- >>> unsafeIndex a [0,2,1]
-- 10
unsafeIndex :: SingI r => Tensor r a -> [Int] -> a
unsafeIndex t ix = flattenTensor t V.! offset
  where
    offset = ind (shape t) ix
-- | Select a sub-tensor by giving, for each dimension, the list of indices to
-- keep. The result holds every combination of those indices, with the last
-- dimension varying fastest.
--
-- Neither the indices nor the result shape @r0@ are checked. The caller must
-- choose an @r0@ that matches the lengths of the index lists.
--
-- >>> unsafeSlice [[0,1],[2],[1,2]] a :: Tensor '[2,1,2] Int
-- [[[10, 11]],
--  [[22, 23]]]
unsafeSlice :: SingI r => [[Int]] -> Tensor r a -> Tensor r0 a
unsafeSlice s t = Tensor (V.fromList (map (unsafeIndex t) (sequence s)))
-- | Type-level shape of a slice: for each dimension, the number of indices
-- selected along it (@Slice xss = Map Length xss@).
type family Slice (xss :: [[Nat]]) :: [Nat] where
  Slice xss = Map LengthSym0 xss
-- | Defunctionalization symbols for @AllLT xs n = All (n >) xs@, which is
-- 'True when every element of @xs@ is strictly less than @n@. 'slice' uses it
-- through 'ZipWith' to check each dimension's index list against that
-- dimension's size.
data AllLTSym0 (a :: TyFun [Nat] (TyFun Nat Bool -> Type))
-- | 'AllLTSym0' partially applied to the index list @l@.
data AllLTSym1 (l :: [Nat]) (a :: TyFun Nat Bool)
type instance Apply AllLTSym0 l = AllLTSym1 l
-- (:>$$) n is the partial application (n :>), so each element x is tested as n > x
type instance Apply (AllLTSym1 l) n = All ((:>$$) n) l
-- | Type-checked slice. @s@ lists, for each dimension, the indices to keep.
-- The constraint requires every index to be smaller than the size of its
-- dimension, and the result shape is the number of indices kept per
-- dimension.
--
-- NOTE(review): 'ZipWith' stops at the shorter list, so the constraint does
-- not reject an @s@ with fewer entries than the rank of @r@.
--
-- >>> slice (Proxy :: Proxy '[ '[0,1],'[2],'[1,2]]) a
-- [[[10, 11]],
--  [[22, 23]]]
slice
  :: forall s r a
   . (SingI s, SingI r, And (ZipWith AllLTSym0 s r) ~ 'True)
  => Proxy s -> Tensor r a -> Tensor (Slice s) a
slice s_ t = unsafeSlice s t
  where
    -- Demote the type-level index lists to [[Int]]. 'fromIntegral' (rather
    -- than 'fromInteger') works whether the singletons version demotes Nat to
    -- Integer or to Natural.
    s = ((fmap . fmap) fromIntegral . fromSing . singByProxy) s_
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment