Skip to content

Instantly share code, notes, and snippets.

@JakobBruenker
Created December 4, 2018 23:07
Show Gist options
  • Save JakobBruenker/7b31d56b200c00856d15cf94e3eff899 to your computer and use it in GitHub Desktop.
Save JakobBruenker/7b31d56b200c00856d15cf94e3eff899 to your computer and use it in GitHub Desktop.
An interface to accelerate that includes matrix sizes in the types. Needs the accelerate and singletons libraries.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE DataKinds #-}
-- {-# LANGUAGE IncoherentInstances #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE UndecidableInstances #-}
module LinearTypesafe where
import TypesafeAccelerate
import Prelude hiding (zipWith, Num, replicate, (++))
import Data.Array.Accelerate (Num, unindex1, index2, constant)
import Data.Proxy
import Data.Singletons.Prelude
import Data.Singletons.TypeLits
-- import qualified Data.Semigroup as S
-- inner product
infixr 8 <.>
-- | Inner (dot) product of two vectors of the same type-level length: the
-- element-wise products are summed into a scalar.
(<.>) :: Num e => Vector n e -> Vector n e -> Scalar e
(<.>) v u = fold (+) 0 (zipWith (*) v u)
--outer product
infixr 8 ><
-- | Outer product: the first vector is viewed as an @n x 1@ column, the
-- second as a @1 x m@ row, and the two are combined with the matrix
-- product '<>' of this module.
(><) :: forall n m e. (KnownNat n, KnownNat m, Num e)
     => Vector n e -> Vector m e -> Matrix n m e
a >< b = column <> row
  where
    column :: Matrix n 1 e
    column = reshape a
    row :: Matrix 1 m e
    row = reshape b
-- -- outer product
-- infixr 8 ><
-- (><) :: Acc (Vector Float) -> Acc (Vector Float) -> Acc (Matrix Float)
-- a >< b = reshape (index2 (length a) 1) a <>
-- reshape (index2 1 (length b)) b
-- matrix vector product
infixr 8 #>
-- | Matrix-vector product. The vector is stacked @n@ times into an
-- @n x m@ matrix, multiplied element-wise with the matrix, and every row
-- is summed.
(#>) :: forall n m e. (Num e, KnownNat m, KnownNat n)
     => Matrix n m e -> Vector m e -> Vector n e
a #> v = fold (+) 0 (zipWith (*) a stacked)
  where
    stacked :: Matrix n m e
    stacked = replicate (Proxy :: Proxy [SN n, SAll]) v
-- vector matrix product
infixr 8 <#
-- | Vector-matrix product. The vector is stacked @m@ times into an
-- @m x n@ matrix, multiplied element-wise with the transposed matrix, and
-- every row is summed.
(<#) :: forall n m e. (Num e, KnownNat m, KnownNat n)
     => Vector n e -> Matrix n m e -> Vector m e
v <# a = fold (+) 0 (zipWith (*) stacked (transpose a))
  where
    stacked :: Matrix m n e
    stacked = replicate (Proxy :: Proxy [SN m, SAll]) v
-- TODO make foldseg typesafe?
-- FIXME did not expect this but `td <> tc` results in "ssFromSsd: Illegal
-- combination of slice and shape"
infixr 8 <>
-- | Matrix product of an @n x m@ and an @m x o@ matrix, giving @n x o@.
--
-- Both operands are expanded to @n x (o * m)@ matrices laid out so that
-- position @(i, j*m + k)@ holds @a(i,k)@ and @b(k,j)@ respectively. They
-- are multiplied element-wise, and each run of @m@ consecutive entries in a
-- row is summed with 'foldSeg'.
--
-- NOTE(review): on GHC >= 8.4 the Prelude also exports a (<>) (Semigroup).
-- This module's Prelude import does not hide it, so unqualified uses of
-- (<>) may be ambiguous there. Confirm against the target GHC.
(<>) :: forall n m o e.
(KnownNat n, KnownNat m, KnownNat o, KnownNat (o :* m), Num e)
=> Matrix n m e -> Matrix m o e -> Matrix n o e
a <> b = foldSeg (+) 0 mul seg
-- The type signatures are not actually necessary, but nonetheless helpful
-- for documentation
-- mul(i, j*m + k) = a(i, k) * b(k, j)
where mul :: Matrix n (o :* m) e
mul = zipWith (*) repa repb
-- Each row of a is repeated o times: shape [n, o, m], reshaped to
-- [n, o * m].
repa :: Matrix n (o :* m) e
repa = reshape $ replicate (Proxy :: Proxy [SAll, SN o, SAll]) a
-- The flattened transpose of b (length o * m), repeated once per row of a.
repb :: Matrix n (o :* m) e
repb = replicate (Proxy :: Proxy [SN n, SAll]) . flatten $ transpose b
-- o segments, each of length m (the shared inner dimension taken from a).
seg :: Vector o Int
seg = fill . constant . snd $ matrixShape a
-- -- matrix matrix product
-- -- if columns of a aren't equal in number to rows of b, the larger
-- -- array is cropped so they match
-- infixr 8 <>
-- (<>) :: Acc (Matrix Float) -> Acc (Matrix Float) -> Acc (Matrix Float)
-- a <> b = foldSeg (+) 0 mul seg
-- where (ra, ca) = let rca = unindex2 $ shape a in (fst rca, snd rca)
-- (rb, cb) = let rcb = unindex2 $ shape b in (fst rcb, snd rcb)
-- len = min rb ca
-- a' = take len a
-- b' = transpose . take len $ transpose b
-- mul = zipWith (*) repa repb
-- repb = replicate (lift $ Z :. ra :. All) . flatten $ transpose b'
-- repa = reshape (index2 ra $ len * cb) $
-- replicate (lift $ Z :. All :. cb :. All) a'
-- seg = fill (index1 cb) len :: Acc (Segments Int)
-- XXX Do we want this? If so, this is an orphan instance, so either newtype
-- it or rethink the modules
-- type Square n e = Matrix n n e
-- Not sure if these instances are useful
-- instance (KnownNat n, KnownNat (n :* n), Num e) => S.Semigroup (Square n e) where
-- (<>) = (<>)
-- instance (KnownNat n, KnownNat (n :* n), Num e) => Monoid (Square n e) where
-- mempty = identity
-- mappend = (S.<>)
-- | Builds an @n x m@ matrix whose main diagonal holds the entries of the
-- given vector (of length @Min n m@). Every other entry is zero.
diagonal :: (Num e, KnownNat n, KnownNat m)
         => Vector (Min n m) e -> Matrix n m e
diagonal v = permute const zeros toDiagonal v
  where
    -- element i of the vector goes to position (i, i)
    toDiagonal (unindex1 -> i) = index2 i i
-- | Synonym for 'identity': ones on the main diagonal, zeros elsewhere.
eye :: (Num e, KnownNat n, KnownNat m, KnownNat (Min n m)) => Matrix n m e
eye = diagonal ones
-- | Identity matrix: ones on the main diagonal (which has length
-- @Min n m@ for non-square shapes), zeros elsewhere.
identity :: (Num e, KnownNat n, KnownNat m, KnownNat (Min n m))
         => Matrix n m e
identity = diagonal (fill 1)
-- | A tensor of any statically known shape with every element set to 0.
zeros :: (Num e, KnownShape dims, ShapeLike dims) => Tensor dims e
zeros = generate (const 0)
-- | A tensor of any statically known shape with every element set to 1.
ones :: (Num e, KnownShape dims, ShapeLike dims) => Tensor dims e
ones = generate (const 1)
-- TODO Right now these are defined for matrices, will have to generalize to
-- tensors
-- if the new dimension is larger by an uneven amount, the padding will be
-- larger on the upper/left side
-- zeroPad :: forall n m n' m' e. Matrix n m e -> Matrix n' m' e
-- zeroPad :: forall n m n' m' e. Matrix n m e -> Matrix n m' e
-- zeroPad a = (zeros :: Matrix n (Div2 (m' :- m)) e) ++ a ++ zeros
-- zeroPad a = (zeros :: Matrix n (Minus m' m) e) ++ a
-- | Prepends @l@ columns of zeros on the left of a matrix.
zeroPadL :: forall n m l e. (KnownNat l, KnownNat n, Num e)
         => Matrix n m e -> Matrix n (l :+ m) e
zeroPadL a = padding ++ a
  where
    padding :: Matrix n l e
    padding = zeros
-- | Appends @r@ columns of zeros on the right of a matrix.
zeroPadR :: forall n m r e. (KnownNat r, KnownNat n, Num e)
         => Matrix n m e -> Matrix n (m :+ r) e
zeroPadR a = a ++ padding
  where
    padding :: Matrix n r e
    padding = zeros
-- | Prepends @t@ rows of zeros on top of a matrix. It does this by padding
-- the transpose on the left and transposing back.
zeroPadT :: forall n m t e. (KnownNat t, KnownNat m, Num e)
         => Matrix n m e -> Matrix (t :+ n) m e
zeroPadT a = transpose (padding ++ transpose a)
  where
    padding :: Matrix m t e
    padding = zeros
-- | Appends @b@ rows of zeros at the bottom of a matrix. It does this by
-- padding the transpose on the right and transposing back.
zeroPadB :: forall n m b e. (KnownNat b, KnownNat m, Num e)
         => Matrix n m e -> Matrix (n :+ b) m e
zeroPadB a = transpose (transpose a ++ padding)
  where
    padding :: Matrix m b e
    padding = zeros
-- | Surrounds a matrix with a border of @k@ zero rows/columns on every side.
-- Padding is applied left, right, top, then bottom.
zeroPadAll :: forall n m k e.
              (KnownNat n, KnownNat k, KnownNat (k :+ m :+ k), Num e)
           => Proxy k -> Matrix n m e
           -> Matrix (k :+ n :+ k) (k :+ m :+ k) e
zeroPadAll _ a = padded
  where
    leftPadded :: Matrix n (k :+ m) e
    leftPadded = zeroPadL a
    sidePadded :: Matrix n (k :+ m :+ k) e
    sidePadded = zeroPadR leftPadded
    topPadded :: Matrix (k :+ n) (k :+ m :+ k) e
    topPadded = zeroPadT sidePadded
    padded :: Matrix (k :+ n :+ k) (k :+ m :+ k) e
    padded = zeroPadB topPadded
-- TODO
-- im2col
-- XXX padding?
-- XXX using generate appears to be a very bad implementation, since you
-- probably cannot parallelize it
-- conv2d :: Matrix n m e -> Matrix h w e -> Matrix n m e
-- conv2d = reshape $ kernelMat <> imMat
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE PolyKinds #-}
-- {-# LANGUAGE TypeInType #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module TypesafeAccelerate where
import Prelude (Show, ($), show, (<$>), error, otherwise)
import qualified Prelude as P
import Data.Kind
import Data.Singletons
import Data.Singletons.Prelude hiding ((:.), Reverse)
import Data.Singletons.Prelude.Enum
import Data.Singletons.TypeLits
import GHC.TypeLits
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Array, Acc, Exp, Double, DIM2, Elt, Z(..), Shape, Int, Bool(..), (:.)(..), Num, (.), (+), fromList, use, constant, fromInteger, IsIntegral, FromIntegral, Slice, (?|), Arrays, const, SliceShape, FullShape, Ord)
-- TODO To work with -XRebindableSyntax, this module has to export these
-- functions, or at the very least the ones that aren't identical to their
-- Prelude versions:
-- - fromInteger (identical to Prelude)
-- - fromRational (identical to Prelude)
-- - (==)
-- - (-) (identical to Prelude)
-- - (>=)
-- - negate (identical to Prelude)
-- - ifThenElse
-- TODO think about (re-)exports in general
-- | One entry of a type-level slice/replicate specification. 'SlicerShape'
-- maps these to Accelerate's specifiers:
--
-- * 'SN' carries a size/index (its argument may be of any kind; the
--   singleton 'SSN' requires it to be a 'KnownNat') and becomes an 'Int'.
-- * 'SAll' keeps an existing axis and becomes 'A.All'.
-- * 'SAny' stands for an arbitrary run of axes and becomes 'A.Any'. It is
--   only given special treatment as the first element (see 'Chopping').
data Slicer :: Type where
SN :: (n :: Type) -> Slicer
SAll :: Slicer
SAny :: Slicer
-- | Singletons for 'Slicer'. 'SSN' additionally carries 'KnownNat' evidence
-- for its index, so the index can be recovered at the term level.
data instance Sing (a :: Slicer) where
SSN :: KnownNat n => Sing n -> Sing (SN n)
SSAll :: Sing SAll
SSAny :: Sing SAny
-- | Convenience name for the 'Slicer' singleton type.
type SSlicer (a :: Slicer) = Sing a
-- 'SingI' instances let 'sing' produce the singletons above implicitly.
instance SingI SAny where
sing = SSAny
instance SingI SAll where
sing = SSAll
-- instance SingI n => SingI (SN n) where
instance KnownNat n => SingI (SN n) where
sing = SSN sing
-- | Whether a 'Slicer' list describes a slice (indices remove dimensions)
-- or a replication (sizes add dimensions).
data SliceMode = Sliced | Replicated
-- | Type-level test for 'Replicated'. It is currently only referenced by
-- commented-out equations in 'Cutting' and 'Cleaving'.
type family IsReplicated (mode :: SliceMode) where
IsReplicated Replicated = True
IsReplicated Sliced = False
-- Mode-specialised names for the validity constraint ('Chopping') and the
-- resulting shape ('ChoppedShape') of a slicer applied to @dims@.
type Slicing ss dims = Chopping Sliced ss dims
type Replicating ss dims = Chopping Replicated ss dims
type SlicedShape ss dims = ChoppedShape Sliced ss dims
type ReplicatedShape ss dims = ChoppedShape Replicated ss dims
-- | The dimension list's Accelerate shape type is a valid 'Shape'.
type ShapeLike dims = Shape (ShapeOf dims)
-- TODO: add nice errors, but be careful because cutting is also used in
-- ChoppedShape
-- For any Slicer
-- | Constraint stating that slicer @ss@ is valid for dimensions @dims@ in
-- the given mode. Two empty lists are trivially valid. A leading 'SAny' is
-- stripped and the rest is checked with 'Cleaving'. Anything else is
-- checked with 'Cutting'.
type family Chopping (mode :: SliceMode) (ss :: [Slicer]) (dims :: [Nat])
:: Constraint
where
Chopping _ '[] '[] = ()
Chopping mode (SAny : ss) ds = (Cleaving mode ss ds ~ True)
Chopping mode ss ds = (Cutting mode ss ds ~ True)
-- For Slicers not containing SAny
-- | 'True' iff the slicer and the dimensions are consumed together:
--
-- * Sliced mode: 'SN n' consumes one dimension @d@ and requires @n < d@.
-- * Replicated mode: 'SN n' consumes no dimension (it adds a new one).
-- * Either mode: 'SAll' consumes one dimension.
--
-- Any leftover on either side yields 'False'.
type family Cutting (mode :: SliceMode) (ss :: [Slicer]) (dims :: [Nat])
where
Cutting _ '[] '[] = True
-- Cutting mode (SN n : ss) (d : ds) =
-- (IsReplicated mode :|| n :< d) :&& Cutting mode ss ds
Cutting Sliced (SN n : ss) (d : ds) = n :< d :&& Cutting Sliced ss ds
Cutting Replicated (SN n : ss) ds = Cutting Replicated ss ds
Cutting mode (SAll : ss) (_ : ds) = Cutting mode ss ds
Cutting _ _ _ = False
-- For Slicers containing Any
-- | Validity check for the slicer that follows a leading 'SAny' (see
-- 'Chopping'). It is 'True' iff @ss@ 'Cutting'-matches the dimensions
-- after dropping some number of leading dimensions, which the 'SAny'
-- absorbs. An empty remaining slicer matches anything, because 'SAny'
-- takes every dimension.
type family Cleaving (mode :: SliceMode) (ss :: [Slicer]) (dims :: [Nat])
where
Cleaving _ '[] _ = True
Cleaving mode ss '[] = Cutting mode ss '[]
-- Cleaving mode (SN n : ss) (d : ds) =
-- (IsReplicated mode :|| n :< d) :&& Cutting mode ss ds :||
-- Cleaving mode (SN n : ss) ds
-- Each case below: either the slicer matches starting here, or the
-- current dimension is absorbed by SAny and we retry on the rest.
Cleaving Sliced (SN n : ss) (d : ds) =
n :< d :&& Cutting Sliced ss ds :|| Cleaving Sliced (SN n : ss) ds
Cleaving Replicated (SN n : ss) (d : ds) =
Cutting Replicated ss (d : ds) :|| Cleaving Replicated (SN n : ss) ds
Cleaving mode (SAll : ss) (_ : ds) =
Cutting mode ss ds :|| Cleaving mode (SAll : ss) ds
Cleaving _ _ _ = False
-- For any Slicer
-- NB: it might make sense to add a catch-all pattern, but it doesn't seem
-- strictly necessary, especially considering there really *isn't* a valid
-- 'ChoppedShape' if none of these patterns match
-- TODO Probably add a type error in that place though
-- | Dimensions that result from applying slicer @ss@ to @dims@ in the given
-- mode. This mirrors 'Chopping': a leading 'SAny' is handled by
-- 'CleftShape', and everything else by 'CutShape'.
type family ChoppedShape
(mode :: SliceMode) (ss :: [Slicer]) (dims :: [Nat])
where
ChoppedShape mode '[] '[] = '[]
ChoppedShape mode (SAny : ss) ds = CleftShape mode ss ds
ChoppedShape mode ss ds = CutShape mode ss ds
-- | Result dimensions for a slicer without 'SAny' (companion to 'Cutting'):
--
-- * Sliced 'SN' removes the matching dimension.
-- * Replicated 'SN n' inserts a new dimension @n@.
-- * 'SAll' keeps the matching dimension.
--
-- Mismatched inputs have no equation, so the family stays stuck.
type family CutShape (mode :: SliceMode) (ss :: [Slicer]) (dims :: [Nat])
where
CutShape _ '[] '[] = '[]
CutShape Sliced (SN _ : ss) (_ : ds) = CutShape Sliced ss ds
-- CutShape Replicated (SN n : ss) (d : ds) =
-- (n :* d) : CutShape Replicated ss ds
CutShape Replicated (SN n : ss) ds =
n : CutShape Replicated ss ds
CutShape mode (SAll : ss) (d : ds) = d : CutShape mode ss ds
-- | Result dimensions for the slicer that follows a leading 'SAny'
-- (companion to 'Cleaving').
--
-- Leading dimensions are kept, as the 'SAny' absorbs them, until the rest of
-- the slicer 'Cutting'-matches the remaining dimensions. From that point
-- 'CutShape' takes over.
--
-- When the remaining slicer is empty, every remaining dimension belongs to
-- the 'SAny' and is kept. This matches Accelerate, where an @Any sh@
-- specifier keeps @sh@ in both the slice shape and the full shape.
-- Previously this equation returned '[] and dropped those dimensions, so
-- e.g. slicing/replicating with @'[SAny]@ could never typecheck.
type family CleftShape
    (mode :: SliceMode) (ss :: [Slicer]) (dims :: [Nat])
  where
    CleftShape _ '[] ds = ds
    -- TODO is this for Sliced too or just for replicated
    CleftShape mode ss '[] = CutShape mode ss '[]
    CleftShape mode (SN n : ss) (d : ds) =
      If (Cutting mode (SN n : ss) (d : ds))
         (CutShape mode (SN n : ss) (d : ds))
         (d : CleftShape mode (SN n : ss) ds)
    CleftShape mode (SAll : ss) (d : ds) =
      If (Cutting mode (SAll : ss) (d : ds))
         (CutShape mode (SAll : ss) (d : ds))
         (d : CleftShape mode (SAll : ss) ds)
-- type level reverse
-- this is used instead of the version provided by singletons so that in type
-- signatures, it shows up as "Reverse xs", and not as something like
-- "Data.Singletons.Prelude.List.Let6989586621679748992Rev xs xs '[]"
-- Similarly, case splitting prevents ghc from applying function on unknown
-- list
type family Reverse (xs :: [a]) = (rs :: [a]) where
Reverse '[] = '[]
Reverse xs = Rev xs '[]
-- | Accumulator-based worker for 'Reverse'.
type family Rev (xs :: [a]) (acc :: [a]) = (rs :: [a]) where
Rev '[] acc = acc
Rev (x:xs) acc = Rev xs (x:acc)
-- XXX maybe have this be ATensor or AccTensor and introduce second type
-- Tensor, which has Arrays that are not in Acc
-- | An Accelerate array computation whose extent is tracked in the type as
-- a list of dimensions, written outermost first (e.g. @'[rows, cols]@).
data Tensor :: [Nat] -> Type -> Type where
Tensor :: Acc (Array (ShapeOf dims) e) -> Tensor dims e
-- Common ranks.
type Scalar = Tensor '[]
type Vector n = Tensor '[n]
type Matrix n m = Tensor '[n, m]
-- | Unwraps a 'Tensor' into its underlying Accelerate array. This is a
-- standalone function rather than a record field so that its type can be
-- written out explicitly.
unTensor :: Tensor dims e -> Acc (Array (ShapeOf dims) e)
unTensor t = case t of
  Tensor arr -> arr
-- | 'unsafeMakeTensor' wraps an Accelerate array computation as a 'Tensor'
-- without checking that the array's extent matches the dimensions in the
-- 'Tensor' type. Only the dimensionality (number of axes) is enforced, and
-- that is done by the type.
unsafeMakeTensor :: Acc (Array (ShapeOf dims) e) -> Tensor dims e
unsafeMakeTensor arr = Tensor arr
-- | 'unsafeUseArray' embeds a host array as a 'Tensor'. The array's shape is
-- compared with the shape implied by @dims@, and a mismatch raises a
-- runtime 'error'.
unsafeUseArray :: forall dims e sh.
                  (KnownShape dims, sh ~ ShapeOf dims,
                   Shape sh, Elt e, P.Eq sh)
               => Array (ShapeOf dims) e -> Tensor dims e
unsafeUseArray array =
  if arrayShape P.== tensorShape
    then Tensor (A.use array)
    else error (P.concat [ "Couldn't match expected shape "
                         , show tensorShape
                         , " with actual shape "
                         , show arrayShape ])
  where
    arrayShape = A.arrayShape array
    tensorShape = shFromDims (Proxy :: Proxy dims)
-- | Accelerate shape type with one 'Int' per list element. The head of the
-- list becomes the last (innermost) @:. Int@, which is why 'ShapeOf'
-- reverses the outermost-first dimension list first.
type family IntShape (dims :: [Nat]) where
IntShape '[] = Z
IntShape (_:ds) = IntShape ds :. Int
-- | Accelerate shape type for an outermost-first dimension list.
type ShapeOf (dims :: [Nat]) = IntShape (Reverse dims)
-- | The dimension list is known at runtime (in the reversed order used by
-- 'ShapeOf').
type KnownShape dims = SingI (Reverse dims)
-- | The slicer list is known at runtime (reversed, like 'KnownShape').
type KnownSlicer slcr = SingI (Reverse slcr)
-- IntShapeData is a type equivalent to the following data family, if it were
-- closed:
--
-- data family IntShapeData (dims :: [Nat])
-- data instance IntShapeData '[] = RISDNil Z
-- data instance IntShapeData (_:ds) = RISDCons (IntShapeData ds :. Int)
--
-- This makes the similarities to the IntShape type family more obvious.
-- It is needed because IntShape isn't injective, whereas shFromDims
-- relies on this type's injectivity.
data IntShapeData :: [Nat] -> Type where
RISDNil :: Z -> IntShapeData '[]
RISDCons :: (IntShapeData ds :. Int) -> IntShapeData (d:ds)
-- | Accelerate slice-specifier type for a slicer list. It is applied to the
-- reversed (innermost-first) slicer and dims by 'SliceOf'.
-- 'SN' becomes @:. Int@, 'SAll' becomes @:. A.All@, and a final 'SAny'
-- becomes @A.Any@ over the remaining dims.
-- NOTE(review): 'SN' does not consume an entry of @dims@ here (unlike
-- 'SAll'). That is what replication needs, but for slicing combined with
-- 'SAny' it may let 'A.Any' cover too many dimensions. Verify.
type family SlicerShape (slcr :: [Slicer]) (dims :: [Nat]) where
SlicerShape '[] _ = Z
SlicerShape '[SAny] ds = A.Any (IntShape ds)
SlicerShape (SN _ : ss) ds = SlicerShape ss ds :. Int
SlicerShape (SAll : ss) (_ : ds) = SlicerShape ss ds :. A.All
-- TODO convert to GADT
-- | Term-level witness with one instance per 'SlicerShape' equation, wrapping
-- the corresponding Accelerate specifier value. Unlike 'SlicerShape', it is
-- injective, which 'ssFromSsd' relies on.
data family SlicerShapeData (slcr :: [Slicer]) (dims :: [Nat])
data instance SlicerShapeData '[] _ = SSDNil Z
data instance SlicerShapeData '[SAny] ds = SSDAny (A.Any (IntShape ds))
data instance SlicerShapeData (SN _ : ss) ds =
SSDConsN (SlicerShapeData ss ds :. Int)
data instance SlicerShapeData (SAll : ss) (_ : ds) =
SSDConsAll (SlicerShapeData ss ds :. A.All)
-- | Debugging 'Show' instances for 'SlicerShapeData'. Each prints its
-- constructor followed by the wrapped Accelerate specifier.
instance Show (SlicerShapeData '[] ns) where
  show (SSDNil Z) = "SSDNil Z"

instance Show (SlicerShapeData '[SAny] ns) where
  show (SSDAny A.Any) = "SSDAny A.Any"

instance (Show (SlicerShapeData ss ns))
    => Show (SlicerShapeData (SAll : ss) (n : ns)) where
  show (SSDConsAll (inner :. A.All)) =
    P.concat ["SSDConsAll (", show inner, " :. A.All)"]

instance (Show (SlicerShapeData ss ns))
    => Show (SlicerShapeData (SN n : ss) ns) where
  show (SSDConsN (inner :. idx)) =
    P.concat ["SSDConsN (", show inner, " :. ", show idx, ")"]
-- | Accelerate slice-specifier type for an outermost-first slicer and
-- dimension list. Both are reversed to match Accelerate's @:.@ nesting.
type SliceOf (slcr :: [Slicer]) (dims :: [Nat]) =
SlicerShape (Reverse slcr) (Reverse dims)
-- | Builds the 'IntShapeData' for the reversed dimension list of @dims@.
-- The reversal is needed because 'IntShape' puts the head of its list in
-- the innermost position, while @dims@ is written outermost first.
--
-- This delegates to 'risdFromDims' instead of repeating the same
-- singleton recursion inline.
isdFromDims :: forall dims. KnownShape dims
            => Proxy dims -> IntShapeData (Reverse dims)
isdFromDims _ = risdFromDims
-- | Builds the 'IntShapeData' for a dimension list known via 'SingI',
-- converting every dimension to an 'Int'. The head of the list ends up in
-- the outermost 'RISDCons'.
risdFromDims :: forall dims. SingI dims => IntShapeData dims
risdFromDims =
  case (sing :: SList dims) of
    SNil -> RISDNil Z
    SCons headSing tailSing ->
      case singInstance tailSing of
        SingInstance ->
          RISDCons (risdFromDims :. fromInteger (fromSing headSing))
-- | Converts the injective 'IntShapeData' witness into the corresponding
-- 'IntShape' value, walking the singleton for @dims@ in lockstep.
isFromIsd :: forall dims. SingI dims
=> IntShapeData dims -> IntShape dims
isFromIsd shd = case (sing :: SList dims, shd) of
(SNil, RISDNil Z) -> Z
-- Bring SingI for the tail into scope for the recursive call.
(SCons _ (singInstance -> SingInstance),
RISDCons (risd :. d)) ->
isFromIsd risd :. d
-- | Concrete Accelerate shape for a statically known, outermost-first
-- dimension list; e.g. @'[2, 3]@ yields @Z :. 2 :. 3@.
shFromDims :: KnownShape dims => Proxy dims -> ShapeOf dims
shFromDims p = isFromIsd (isdFromDims p)
-- | Converts a 'SlicerShapeData' witness into the Accelerate slice specifier
-- value ('Z', 'A.Any', @:. A.All@, @:. Int@). It walks the singletons of
-- @slcr@ and @dims@ together with the witness.
--
-- The fallback is expected to be unreachable when called through
-- 'shFromSlicer', which requires 'Chopping'. NOTE(review): a FIXME on the
-- matrix product in LinearTypesafe reports this error being hit, so that
-- expectation is not fully verified.
ssFromSsd :: forall dims slcr. (SingI dims, SingI slcr)
          => SlicerShapeData slcr dims -> SlicerShape slcr dims
ssFromSsd sld =
  case (sing :: SList slcr, sing :: SList dims, sld) of
    (SNil, _, SSDNil Z) -> Z
    (SCons SSAny SNil, _, SSDAny A.Any) -> A.Any
    (SCons SSAll (singInstance -> SingInstance),
     SCons _ (singInstance -> SingInstance),
     SSDConsAll (ssd :. A.All)) ->
      ssFromSsd ssd :. A.All
    -- 'SN' does not consume a dimension (see 'SlicerShapeData')
    (SCons (SSN _) (singInstance -> SingInstance), _, SSDConsN (rssd :. d)) ->
      ssFromSsd rssd :. d
    _ -> error "TypesafeAccelerate.ssFromSsd: Illegal combination of slice \
               \and shape"
-- | Builds the 'SlicerShapeData' witness for the reversed slicer and
-- dimension lists (the innermost-first order used by 'SliceOf').
-- Sub-lists are delegated to 'ssdFromSlicer'. At the top level, an empty
-- slicer is only accepted together with empty dimensions.
--
-- The fallback is expected to be unreachable when called through
-- 'shFromSlicer', which requires 'Chopping'.
rssdFromSlicer :: forall slcr dims. (KnownSlicer slcr, KnownShape dims)
               => Proxy slcr -> Proxy dims
               -> SlicerShapeData (Reverse slcr) (Reverse dims)
rssdFromSlicer _ _ =
  case (sing :: SList (Reverse slcr), sing :: SList (Reverse dims)) of
    (SNil, SNil) -> SSDNil Z
    (SCons SSAny SNil, _) -> SSDAny A.Any
    (SCons SSAll (singInstance -> SingInstance),
     SCons _ (singInstance -> SingInstance)) ->
      SSDConsAll (ssdFromSlicer Proxy Proxy :. A.All)
    -- 'SN' does not consume a dimension (see 'SlicerShapeData')
    (SCons (SSN (fromSing -> d)) (singInstance -> SingInstance), _) ->
      SSDConsN (ssdFromSlicer Proxy Proxy :. fromInteger d)
    _ -> error "TypesafeAccelerate.rssdFromSlicer: Illegal combination of \
               \slice and shape"
-- | Builds the 'SlicerShapeData' witness for a slicer/dimension pair whose
-- singletons are known. The lists are taken in the order given (callers
-- pass them innermost first). An empty slicer is accepted regardless of
-- the remaining dimensions.
--
-- The fallback is expected to be unreachable when reached through
-- 'shFromSlicer', which requires 'Chopping'.
ssdFromSlicer :: forall slcr dims. (SingI slcr, SingI dims)
              => Proxy slcr -> Proxy dims -> SlicerShapeData slcr dims
ssdFromSlicer _ _ =
  case (sing :: SList slcr, sing :: SList dims) of
    (SNil, _) -> SSDNil Z
    (SCons SSAny SNil, _) -> SSDAny A.Any
    (SCons SSAll (singInstance -> SingInstance),
     SCons _ (singInstance -> SingInstance)) ->
      SSDConsAll (ssdFromSlicer Proxy Proxy :. A.All)
    -- 'SN' does not consume a dimension (see 'SlicerShapeData')
    (SCons (SSN (fromSing -> d)) (singInstance -> SingInstance), _) ->
      SSDConsN (ssdFromSlicer Proxy Proxy :. fromInteger d)
    _ -> error "TypesafeAccelerate.ssdFromSlicer: Illegal combination of \
               \slice and shape"
-- | Concrete Accelerate slice specifier for a statically known slicer and
-- dimension list. The 'Chopping' constraint ensures the slicer is valid for
-- @dims@ in the given mode.
shFromSlicer :: (KnownShape dims, KnownSlicer slcr,
                 Chopping mode slcr dims)
             => Proxy mode -> Proxy slcr -> Proxy dims -> SliceOf slcr dims
shFromSlicer _ slicerProxy dimsProxy =
  ssFromSsd (rssdFromSlicer slicerProxy dimsProxy)
-- | Shows the type-level dimensions followed by the wrapped Accelerate
-- computation.
instance (Show e, SingI dims, ShapeLike dims, Elt e)
    => Show (Tensor dims e) where
  show t = case t of
    Tensor a -> P.concat ["Tensor ", show (shapeInts t), " (", show a, ")"]
-- | Type-level length of a list.
type family Length (list :: [a]) :: Nat where
Length '[] = 0
Length (_:xs) = Succ (Length xs)
-- | Constraint that @idx@ and @dims@ have the same length and that every
-- index is strictly less than the corresponding dimension. On failure it
-- reports a custom 'TypeError'.
type family AllLess (idx :: [Nat]) (dims :: [Nat]) :: Constraint where
AllLess is ds = If (Length is :== Length ds) (AllLess' is ds is ds)
(TypeError (ShowType is :<>: Text " and " :<>: ShowType ds :<>:
Text " must have the same length, but" :$$:
ShowType is :<>: Text " has length " :<>: ShowType (Length is) :$$:
ShowType ds :<>: Text " has length " :<>: ShowType (Length ds)))
-- | Element-wise worker for 'AllLess'. The last two arguments carry the
-- original lists, used only in the error message.
type family AllLess' (idx :: [Nat]) (dims :: [Nat]) idx' dims' :: Constraint
where
AllLess' '[] '[] _ _ = ()
AllLess' (i:is) (d:ds) is' ds' = If (i :< d) (AllLess' is ds is' ds')
(TypeError
(ShowType i :<>: Text " is not less than " :<>: ShowType d :$$:
Text "While comparing " :<>: ShowType is' :<>:
Text " and " :<>: ShowType ds'))
AllLess' is ds is' ds' = TypeError (Text "Unexpected case in AllLess'")
-- | Type-level product of a list of naturals, i.e. the element count of a
-- tensor with those dimensions. The empty list gives 1.
type family Product (ns :: [Nat]) :: Nat where
Product '[] = 1
Product (n:ns) = n :* Product ns
-- | Indexes a tensor at a statically known index @idx@ (a type-level list,
-- outermost first). 'AllLess' rejects out-of-range indices at compile time.
(!) :: (ShapeOf dims ~ ShapeOf idx, ShapeLike idx,
        AllLess idx dims, KnownShape idx, Elt e, Elt (ShapeOf idx))
    => Tensor dims e -> Proxy idx -> Exp e
(!) (Tensor arr) p = arr A.! constant (shFromDims p)
-- | Indexes with an index computed in 'Exp'. The index is not checked
-- against the type-level dimensions.
unsafeIndex :: (Elt e, ShapeLike dims)
            => Tensor dims e -> Exp (ShapeOf dims) -> Exp e
unsafeIndex (Tensor arr) ix = arr A.! ix
-- | Linear indexing (as Accelerate's 'A.!!') at a statically known position
-- @idx@. The position must be smaller than the total element count.
(!!) :: forall idx dims e.
        (SingI idx, (Product dims :> idx) ~ True,
         ShapeLike dims, Elt e)
     => Tensor dims e -> Proxy idx -> Exp e
(!!) (Tensor arr) _ = arr A.!! constant position
  where
    position = fromInteger (fromSing (sing :: SNat idx))
-- | Linear indexing with a position computed in 'Exp'. The position is not
-- checked against the type-level element count.
unsafeLinearIndex :: (ShapeLike dims, Elt e)
                  => Tensor dims e -> Exp Int -> Exp e
unsafeLinearIndex (Tensor arr) i = arr A.!! i
-- | The single element of a scalar tensor.
the :: Elt e => Scalar e -> Exp e
the s = case s of
  Tensor arr -> A.the arr
-- | 'True' iff the tensor's type-level element count is zero. This is
-- decided from the type alone.
null :: forall dims e. SingI (Product dims) => Tensor dims e -> Bool
null _ = elementCount P.== 0
  where
    elementCount = fromSing (sing :: SNat (Product dims))
-- | Number of elements of a vector, read from its type-level length.
length :: SingI n => Vector n e -> Int
length v = size v
-- | Concrete Accelerate shape of a tensor, derived from its type.
shape :: forall dims e. KnownShape dims => Tensor dims e -> ShapeOf dims
shape _ = shFromDims dimsProxy
  where
    dimsProxy = Proxy :: Proxy dims
-- | Type-level dimensions of a tensor as a list of 'Int's, outermost first.
shapeInts :: forall dims e. SingI dims => Tensor dims e -> [Int]
shapeInts _ = P.map fromInteger (fromSing (sing :: SList dims))
-- | (rows, columns) of a matrix as a pair. It is provided so callers can
-- use total functions such as 'fst'/'snd' instead of indexing a list.
matrixShape :: forall n m e. (SingI n, SingI m) => Matrix n m e -> (Int, Int)
matrixShape _ =
  ( fromInteger (fromSing (sing :: SNat n))
  , fromInteger (fromSing (sing :: SNat m)) )
-- | Total number of elements of a tensor, read from its type.
size :: forall dims e. SingI (Product dims) => Tensor dims e -> Int
size _ = shapeSize dimsProxy
  where
    dimsProxy = Proxy :: Proxy dims
-- | Element count of a dimension list: the product of its entries.
shapeSize :: forall dims. SingI (Product dims) => Proxy dims -> Int
shapeSize _ = fromInteger (fromSing (sing :: SNat (Product dims)))
-- | Wraps a single expression as a scalar tensor.
-- XXX use? not sure how to handle this
unit :: Elt e => Exp e -> Scalar e
unit x = Tensor (A.unit x)
-- | Builds a tensor of the statically known shape by applying a function to
-- every index.
generate :: forall dims e sh.
            (sh ~ ShapeOf dims, Shape sh, Elt e, KnownShape dims)
         => (Exp sh -> Exp e) -> Tensor dims e
generate f = Tensor (A.generate extent f)
  where
    extent = constant (shFromDims (Proxy :: Proxy dims))
-- | A tensor of the statically known shape with every element equal to the
-- given value.
fill :: forall dims e. (ShapeLike dims, KnownShape dims, Elt e)
     => Exp e -> Tensor dims e
fill x = generate (\_ -> x)
-- | Consecutive values starting at the given one, filling the statically
-- known shape (wraps Accelerate's 'A.enumFromN').
enumFromN :: forall dims e.
             (ShapeLike dims, KnownShape dims,
              FromIntegral Int e, Num e)
          => Exp e -> Tensor dims e
enumFromN start = Tensor (A.enumFromN extent start)
  where
    extent = constant (shFromDims (Proxy :: Proxy dims))
-- | Values starting at the first argument and advancing by the second,
-- filling the statically known shape (wraps 'A.enumFromStepN').
enumFromStepN :: forall dims e.
                 (ShapeLike dims, KnownShape dims,
                  FromIntegral Int e, Num e)
              => Exp e -> Exp e -> Tensor dims e
enumFromStepN start step = Tensor (A.enumFromStepN extent start step)
  where
    extent = constant (shFromDims (Proxy :: Proxy dims))
infixr 5 ++
-- | Concatenates two tensors along their last (innermost) dimension. All
-- other dimensions must agree, and the last dimensions add up.
(++) :: (ShapeOf das ~ (sh :. Int), ShapeOf dbs ~ (sh :. Int),
         ShapeOf dcs ~ (sh :. Int), Init das ~ Init dbs,
         dcs ~ (Init das :++ '[Last das :+ Last dbs]),
         Slice sh, Shape sh, Elt e)
     => Tensor das e -> Tensor dbs e -> Tensor dcs e
(++) (Tensor front) (Tensor back) = Tensor (front A.++ back)
-- TODO: once using accelerate 1.2, add concatOn and the other lens
-- functions
-- | Overloaded conditional for use with -XRebindableSyntax (see the TODO
-- near the top of this module). It lets @if@ branch on plain 'Bool' as well
-- as on Accelerate @Exp Bool@ for tensors and 'Acc' values.
class IfThenElse bool a where
  ifThenElse :: bool -> a -> a -> a

instance IfThenElse Bool a where
  ifThenElse cond whenTrue whenFalse = case cond of
    True -> whenTrue
    False -> whenFalse

instance (ShapeLike dims, Elt a)
    => IfThenElse (Exp Bool) (Tensor dims a) where
  ifThenElse cond (Tensor whenTrue) (Tensor whenFalse) =
    Tensor $ cond ?| (whenTrue, whenFalse)

instance Arrays a => IfThenElse (Exp Bool) (Acc a) where
  ifThenElse cond whenTrue whenFalse = cond ?| (whenTrue, whenFalse)
-- XXX The lifting functions (liftT, liftT2, liftT3) are UNSAFE: the result
-- is labelled with whatever dimensions @drs@ the caller's type asks for,
-- regardless of the extent the wrapped Accelerate function really
-- produces. Only lift functions whose output extent matches @drs@. The
-- unlift direction is believed to be safe, even though an unlifted function
-- can afterwards be fed arrays of any size. A possible fix is to add
-- 'KnownShape' constraints and check the size at run time. It is not yet
-- clear whether that is worth it.

-- | Lifts a unary Accelerate array function to tensors (unchecked, see
-- above).
liftT :: (ShapeLike das, ShapeLike drs)
      => (Acc (Array (ShapeOf das) a) -> Acc (Array (ShapeOf drs) r))
      -> Tensor das a -> Tensor drs r
liftT f (Tensor t) = Tensor (f t)
-- | Binary version of 'liftT' (unchecked result dimensions).
liftT2 :: (ShapeLike das, ShapeLike dbs, ShapeLike drs)
       => (Acc (Array (ShapeOf das) a) -> Acc (Array (ShapeOf dbs) b)
           -> Acc (Array (ShapeOf drs) r))
       -> Tensor das a -> Tensor dbs b -> Tensor drs r
liftT2 f (Tensor s) (Tensor t) = Tensor (f s t)
-- | Ternary version of 'liftT' (unchecked result dimensions).
liftT3 :: (ShapeLike das, ShapeLike dbs, ShapeLike dcs, ShapeLike drs)
       => (Acc (Array (ShapeOf das) a) -> Acc (Array (ShapeOf dbs) b)
           -> Acc (Array (ShapeOf dcs) c) -> Acc (Array (ShapeOf drs) r))
       -> Tensor das a -> Tensor dbs b -> Tensor dcs c -> Tensor drs r
liftT3 f (Tensor s) (Tensor t) (Tensor u) = Tensor (f s t u)
-- | Turns a tensor function back into a plain Accelerate array function.
unliftT :: (ShapeLike das, ShapeLike drs, Elt a, Elt r)
        => (Tensor das a -> Tensor drs r)
        -> Acc (Array (ShapeOf das) a) -> Acc (Array (ShapeOf drs) r)
unliftT f a = unTensor (f (Tensor a))
-- | Binary version of 'unliftT'.
unliftT2 :: (ShapeLike das, ShapeLike dbs, ShapeLike drs, Elt a, Elt b, Elt r)
         => (Tensor das a -> Tensor dbs b -> Tensor drs r)
         -> Acc (Array (ShapeOf das) a) -> Acc (Array (ShapeOf dbs) b)
         -> Acc (Array (ShapeOf drs) r)
unliftT2 f a b = unTensor (f (Tensor a) (Tensor b))
-- | Ternary version of 'unliftT'.
unliftT3 :: (ShapeLike das, ShapeLike dbs, ShapeLike dcs, ShapeLike drs,
             Elt a, Elt b, Elt c, Elt r)
         => (Tensor das a -> Tensor dbs b -> Tensor dcs c -> Tensor drs r)
         -> Acc (Array (ShapeOf das) a) -> Acc (Array (ShapeOf dbs) b)
         -> Acc (Array (ShapeOf dcs) c) -> Acc (Array (ShapeOf drs) r)
unliftT3 f a b c = unTensor (f (Tensor a) (Tensor b) (Tensor c))
-- | Pipelines two tensor functions using Accelerate's 'A.>->' stage
-- composition.
(>->) :: (ShapeLike das, ShapeLike dbs, ShapeLike dcs, Elt a, Elt b, Elt c)
      => (Tensor das a -> Tensor dbs b) -> (Tensor dbs b -> Tensor dcs c)
      -> Tensor das a -> Tensor dcs c
(>->) f g (Tensor t) = Tensor ((unliftT f A.>-> unliftT g) t)
-- | Wraps Accelerate's 'A.compute'. Dimensions are unchanged.
compute :: (ShapeLike dims, Elt e) => Tensor dims e -> Tensor dims e
compute t = case t of
  Tensor arr -> Tensor (A.compute arr)
-- | Pairs every element with its index (wraps 'A.indexed').
indexed :: (ShapeLike dims, Elt e)
        => Tensor dims e -> Tensor dims (ShapeOf dims, e)
indexed t = case t of
  Tensor arr -> Tensor (A.indexed arr)
-- | Applies a function to every element. Dimensions are unchanged.
map :: (Elt b, ShapeLike dims, Elt a)
    => (Exp a -> Exp b) -> Tensor dims a -> Tensor dims b
map f = liftT (A.map f)
-- | Like 'map', but the function also receives the element's index.
imap :: (sh ~ ShapeOf dims, Elt a, Elt b, Shape sh)
     => (Exp sh -> Exp a -> Exp b) -> Tensor dims a -> Tensor dims b
imap f = liftT (A.imap f)
-- | Combines two equally shaped tensors element-wise.
zipWith :: (Elt a, Elt b, Elt r, ShapeLike dims)
        => (Exp a -> Exp b -> Exp r)
        -> Tensor dims a -> Tensor dims b -> Tensor dims r
-- Since both shapes are equal by construction, a 'generate'-based version
-- without shape intersection would also work, e.g.
--   zipWith f (Tensor s) (Tensor t) = Tensor $
--     generate (shape s) (\idx -> f (s!idx) (t!idx))
-- For now this simply wraps the Accelerate function.
zipWith f = liftT2 (A.zipWith f)
-- | Combines three equally shaped tensors element-wise.
zipWith3 :: (Elt a, Elt b, Elt c, Elt r, ShapeLike dims)
         => (Exp a -> Exp b -> Exp c -> Exp r)
         -> Tensor dims a -> Tensor dims b -> Tensor dims c -> Tensor dims r
zipWith3 f = liftT3 (A.zipWith3 f)
-- | Like 'zipWith', but the function also receives the element's index.
--
-- The constraint list no longer contains @Elt c@. @c@ does not occur
-- anywhere in the type, so that constraint could never be resolved at a
-- call site and made every use of 'izipWith' ambiguous.
izipWith :: (sh ~ ShapeOf dims, Shape sh, Elt a, Elt b, Elt r)
         => (Exp sh -> Exp a -> Exp b -> Exp r)
         -> Tensor dims a -> Tensor dims b -> Tensor dims r
izipWith f = liftT2 (A.izipWith f)
-- | Like 'zipWith3', but the function also receives the element's index.
izipWith3 :: (sh ~ ShapeOf dims, Shape sh, Elt a, Elt b, Elt c, Elt r)
          => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp r)
          -> Tensor dims a -> Tensor dims b -> Tensor dims c -> Tensor dims r
izipWith3 f = liftT3 (A.izipWith3 f)
-- | Pairs up the elements of two equally shaped tensors.
zip :: (ShapeLike dims, Elt a, Elt b)
    => Tensor dims a -> Tensor dims b -> Tensor dims (a, b)
zip s t = liftT2 A.zip s t
-- | Groups the elements of three equally shaped tensors into triples.
zip3 :: (ShapeLike dims, Elt a, Elt b, Elt c)
     => Tensor dims a -> Tensor dims b -> Tensor dims c
     -> Tensor dims (a, b, c)
zip3 s t u = liftT3 A.zip3 s t u
-- | Splits a tensor of pairs into two tensors of the same dimensions.
unzip :: (ShapeLike dims, (Elt a, Elt b))
      => Tensor dims (a, b)
      -> (Tensor dims a, Tensor dims b)
unzip (Tensor pairs) = case A.unzip pairs of
  (firsts, seconds) -> (Tensor firsts, Tensor seconds)
-- | Splits a tensor of triples into three tensors of the same dimensions.
unzip3 :: (ShapeLike dims, Elt a, Elt b, Elt c)
       => Tensor dims (a, b, c)
       -> (Tensor dims a, Tensor dims b, Tensor dims c)
unzip3 (Tensor triples) = case A.unzip3 triples of
  (firsts, seconds, thirds) -> (Tensor firsts, Tensor seconds, Tensor thirds)
-- | Reinterprets a tensor with new dimensions that have the same element
-- count ('Product').
reshape :: forall dims dims' e.
           (Product dims ~ Product dims', ShapeLike dims, ShapeLike dims',
            KnownShape dims', Elt e)
        => Tensor dims e -> Tensor dims' e
reshape = liftT (A.reshape target)
  where
    target = constant (shFromDims (Proxy :: Proxy dims'))
-- | Flattens a tensor into a vector of all its elements.
flatten :: (ShapeLike dims, Elt e) => Tensor dims e -> Vector (Product dims) e
flatten t = liftT A.flatten t
-- | Replicates a tensor according to the slicer @ss@: each 'SN n' entry adds
-- a new dimension of size @n@, and each 'SAll' keeps an existing dimension.
-- The resulting dimensions are @ReplicatedShape ss dims@.
-- XXX it *might* be possible to get rid of the Proxy here, although I doubt it
replicate :: forall sl sh ss dims e.
             (sl ~ SliceOf ss dims, sh ~ ShapeOf (ReplicatedShape ss dims),
              ShapeOf dims ~ SliceShape sl, sh ~ FullShape sl,
              KnownShape dims, KnownSlicer ss, Replicating ss dims,
              Shape sh, Slice sl, Elt e)
          => Proxy ss -> Tensor dims e -> Tensor (ReplicatedShape ss dims) e
replicate _ = liftT (A.replicate (constant spec))
  where
    spec = shFromSlicer (Proxy :: Proxy Replicated)
                        (Proxy :: Proxy ss)
                        (Proxy :: Proxy dims)
-- | Slices a tensor according to the slicer @ss@: each 'SN i' entry fixes a
-- dimension at index @i@, which must be in range, and each 'SAll' keeps
-- one. The resulting dimensions are @SlicedShape ss dims@.
slice :: forall sl ss dims dims' e.
         (sl ~ SliceOf ss dims, dims' ~ SlicedShape ss dims,
          SliceShape sl ~ ShapeOf dims', ShapeOf dims ~ FullShape sl,
          KnownShape dims, KnownSlicer ss, Slicing ss dims,
          Slice sl, Elt e)
      => Tensor dims e -> Proxy ss -> Tensor dims' e
slice (Tensor t) _ = Tensor (A.slice t (constant spec))
  where
    spec = shFromSlicer (Proxy :: Proxy Sliced)
                        (Proxy :: Proxy ss)
                        (Proxy :: Proxy dims)
-- | Removes the last element along the innermost dimension, so the last
-- dimension shrinks by one.
-- XXX Should we explicitly make sure that Last dims is at least 1? Probably
-- not strictly necessary because it won't typecheck if that's not the case
-- anyway.
init :: (dims' ~ (Init dims :++ '[Pred (Last dims)]),
         ShapeOf dims' ~ (sh :. Int),
         ShapeOf dims ~ (sh :. Int), Slice sh, Shape sh, Elt e)
     => Tensor dims e -> Tensor dims' e
init t = liftT A.init t
-- | Removes the first element along the innermost dimension, so the last
-- dimension shrinks by one.
tail :: (dims' ~ (Init dims :++ '[Pred (Last dims)]),
         ShapeOf dims' ~ (sh :. Int),
         ShapeOf dims ~ (sh :. Int), Slice sh, Shape sh, Elt e)
     => Tensor dims e -> Tensor dims' e
tail t = liftT A.tail t
-- | Keeps the first @n@ elements along the innermost dimension. @n@ must not
-- exceed the last dimension.
-- TODO better type error if n is too high (also for drop)
take :: forall n dims dims' e sh.
        (ShapeOf dims' ~ (sh :. Int), ShapeOf dims ~ (sh :. Int),
         dims' ~ (Init dims :++ '[n]), (n :<= Last dims) ~ True,
         Slice sh, Shape sh, SingI n, Elt e)
     => Tensor dims e -> Tensor dims' e
take = liftT (A.take count)
  where
    count = constant (fromInteger (fromSing (sing :: SNat n)))
-- | Drops the first @n@ elements along the innermost dimension. @n@ must not
-- exceed the last dimension.
--
-- The result keeps @Last dims - n@ elements in the last dimension, which is
-- what Accelerate's 'A.drop' produces. This is not @n@, which is what
-- 'take' yields.
drop :: forall n dims dims' e sh.
        (ShapeOf dims' ~ (sh :. Int), ShapeOf dims ~ (sh :. Int),
         dims' ~ (Init dims :++ '[Last dims :- n]), (n :<= Last dims) ~ True,
         Slice sh, Shape sh, SingI n, Elt e)
     => Tensor dims e -> Tensor dims' e
drop = liftT $ A.drop (constant . fromInteger $ fromSing (sing :: SNat n))
-- | Takes @len@ elements starting at index @idx@ along the innermost
-- dimension. The range must fit within the last dimension.
slit :: forall idx len dims dims' e sh.
        (ShapeOf dims' ~ (sh :. Int), ShapeOf dims ~ (sh :. Int),
         dims' ~ (Init dims :++ '[len]), (idx + len :<= Last dims) ~ True,
         Slice sh, Shape sh, SingI idx, SingI len, Elt e)
     => Proxy idx -> Tensor dims e -> Tensor dims' e
slit _ = liftT (A.slit start count)
  where
    start = constant (fromInteger (fromSing (sing :: SNat idx)))
    count = constant (fromInteger (fromSing (sing :: SNat len)))
-- | Forward permutation (wraps Accelerate's 'A.permute'). Its arguments
-- are: the combination function, the tensor of default values, the index
-- mapping, and the source tensor.
permute :: (sh ~ ShapeOf dims, sh' ~ ShapeOf dims',
            Shape sh, Shape sh', Elt e)
        => (Exp e -> Exp e -> Exp e) -> Tensor dims' e -> (Exp sh -> Exp sh')
        -> Tensor dims e -> Tensor dims' e
permute combine (Tensor defaults) project =
  liftT (A.permute combine defaults project)
-- | Wraps Accelerate's 'A.scatter'. Its arguments are: the destination
-- indices, the default vector (which fixes the result length), and the
-- source values.
scatter :: Elt e
        => Vector dst Int -> Vector def e -> Vector src e -> Vector def e
scatter targets defaults input = liftT3 A.scatter targets defaults input
-- | Backward permutation into the statically known result shape. The
-- function maps each result index to a source index.
backpermute :: forall dims dims' sh' sh e.
               (sh ~ ShapeOf dims, sh' ~ ShapeOf dims',
                KnownShape dims, KnownShape dims', Shape sh', Elt e, Shape sh)
            => (Exp (ShapeOf dims') -> Exp (ShapeOf dims))
            -> Tensor dims e -> Tensor dims' e
backpermute f = liftT (A.backpermute extent f)
  where
    extent = constant (shFromDims (Proxy :: Proxy dims'))
-- | Looks up each index of the first tensor in the source vector (wraps
-- 'A.gather'). The result has the dimensions of the index tensor.
gather :: (ShapeLike idx, Elt e)
       => Tensor idx Int -> Vector src e -> Tensor idx e
gather indices source = liftT2 A.gather indices source
-- | Reverses a vector.
reverse :: Elt e => Vector n e -> Vector n e
reverse v = liftT A.reverse v
-- | Transposes a matrix, swapping its two dimensions.
transpose :: Elt e => Matrix n m e -> Matrix m n e
transpose m = liftT A.transpose m
-- XXX use unlift to bring Acc of tuple into tuple of Accs
-- TODO If we actually want to return a Vector of results here, we don't know
-- the length at runtime, which means we need to have SomeVector to store the
-- result
-- filter :: (Exp e -> Exp Bool)
-- -> Tensor dims e -> (Vector e, Tensor (Init dims) Int)
-- filter f (Tensor t) = (filtered, nums)
-- where (filtered, nums) = A.unlift (A.filter f t)
-- | Reduces along the innermost dimension using a seed value, removing the
-- last dimension.
fold :: (dims' ~ Init dims, sh ~ ShapeOf dims, sh' ~ ShapeOf dims',
         sh ~ (sh' :. Int), Shape sh, Shape sh', Elt e)
     => (Exp e -> Exp e -> Exp e) -> Exp e
     -> Tensor dims e -> Tensor (Init dims) e
fold f z = liftT (A.fold f z)
-- | Seedless reduction along the innermost dimension, removing the last
-- dimension. The last dimension must be non-zero.
-- TODO nice error message for empty Tensor
fold1 :: (dims' ~ Init dims, sh ~ ShapeOf dims, sh' ~ ShapeOf dims',
          sh ~ (sh' :. Int), (Last dims :> 0) ~ True,
          Shape sh, Shape sh', Elt e)
      => (Exp e -> Exp e -> Exp e)
      -> Tensor dims e -> Tensor (Init dims) e
fold1 f = liftT (A.fold1 f)
-- | Reduces every element of a tensor to a scalar using a seed value.
foldAll :: (ShapeLike dims, Elt e)
        => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e -> Scalar e
foldAll f z = liftT (A.foldAll f z)
-- | Reduce every element of a non-empty tensor to a single 'Scalar' without a
-- seed value, via 'A.fold1All'.
--
-- The @(Product dims :> 0) ~ True@ constraint rules out empty tensors at the
-- type level, which is what makes a seedless reduction well-defined.
--
-- The 'Exp' argument is deliberately ignored. It is kept only so that
-- existing call sites that pass a seed still compile. Folding it in, as
-- 'A.foldAll' does, would make this function a duplicate of 'foldAll' and
-- contradict the seedless contract above.
--
-- TODO nice error message for empty Tensor. Also for scans.
fold1All :: (ShapeLike dims, Elt e, (Product dims :> 0) ~ True)
         => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e -> Scalar e
fold1All f _ignoredSeed = liftT $ A.fold1All f
-- | Segmented reduction along the innermost dimension with a seed value, via
-- 'A.foldSeg'.
--
-- The segment-descriptor vector has type-level length @n@. That length
-- replaces the innermost dimension in the result's dims, @'Init' dims :++ '[n]@.
--
-- TODO The sum of elements of the Vector must be less than the last
-- dimension of dims. Is this something we should/could do on the type level?
-- Also for scans.
foldSeg :: (sh ~ ShapeOf dims, dims' ~ (Init dims :++ '[n]),
            sh' ~ ShapeOf dims', sh ~ (ish :. Int), sh' ~ (ish :. Int),
            IsIntegral i, Elt e, Elt i, Shape ish)
        => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e -> Vector n i
        -> Tensor dims' e
foldSeg f x t segs = liftT2 (A.foldSeg f x) t segs
-- | Seedless segmented reduction along the innermost dimension, via
-- 'A.fold1Seg'.
--
-- As with 'foldSeg', the segment vector's length @n@ becomes the result's
-- innermost dimension.
--
-- TODO think about what the restriction here should be to make sure every
-- segment is non-empty. Also for scans.
fold1Seg :: (sh ~ ShapeOf dims, dims' ~ (Init dims :++ '[n]),
             sh' ~ ShapeOf dims', sh ~ (ish :. Int), sh' ~ (ish :. Int),
             IsIntegral i, Elt e, Elt i, Shape ish)
         => (Exp e -> Exp e -> Exp e) -> Tensor dims e -> Vector n i
         -> Tensor dims' e
fold1Seg f t segs = liftT2 (A.fold1Seg f) t segs
-- | Test a predicate on every element along the innermost dimension
-- ('A.all'). The result drops that dimension (dims @'Init' dims@).
all :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
        Shape sh', Elt e)
    => (Exp e -> Exp Bool) -> Tensor dims e -> Tensor (Init dims) Bool
all p t = liftT (A.all p) t
-- | Test whether any element along the innermost dimension satisfies a
-- predicate ('A.any'). The result drops that dimension (dims @'Init' dims@).
any :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
        Shape sh', Elt e)
    => (Exp e -> Exp Bool) -> Tensor dims e -> Tensor (Init dims) Bool
any p t = liftT (A.any p) t
-- | Conjunction of the booleans along the innermost dimension ('A.and'). The
-- result has dims @'Init' dims@.
and :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
        Shape sh')
    => Tensor dims Bool -> Tensor (Init dims) Bool
and t = liftT A.and t
-- | Disjunction of the booleans along the innermost dimension ('A.or'). The
-- result has dims @'Init' dims@.
or :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
       Shape sh')
   => Tensor dims Bool -> Tensor (Init dims) Bool
or t = liftT A.or t
-- | Sum along the innermost dimension ('A.sum'). The result has dims
-- @'Init' dims@.
sum :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
        Shape sh', Num e)
    => Tensor dims e -> Tensor (Init dims) e
sum t = liftT A.sum t
-- | Product along the innermost dimension ('A.product'). The result has dims
-- @'Init' dims@.
product :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
            Shape sh', Num e)
        => Tensor dims e -> Tensor (Init dims) e
product t = liftT A.product t
-- | Minimum along the innermost dimension ('A.minimum'). The result has dims
-- @'Init' dims@.
minimum :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
            Shape sh', Ord e)
        => Tensor dims e -> Tensor (Init dims) e
minimum t = liftT A.minimum t
-- | Maximum along the innermost dimension ('A.maximum'). The result has dims
-- @'Init' dims@.
maximum :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
            Shape sh', Ord e)
        => Tensor dims e -> Tensor (Init dims) e
maximum t = liftT A.maximum t
-- | Left-to-right scan along the innermost dimension with a seed value
-- ('A.scanl').
--
-- The type records that the innermost dimension of the result is one longer
-- than the input's (@Last dims + 1@).
scanl :: (dims' ~ (Init dims :++ '[Last dims + 1]),
          ShapeOf dims ~ (sh :. Int), ShapeOf dims' ~ (sh :. Int),
          Shape sh, Elt e)
      => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e -> Tensor dims' e
scanl f x t = liftT (A.scanl f x) t
-- | Seedless left-to-right scan along the innermost dimension ('A.scanl1').
--
-- The dims are unchanged. The innermost dimension must be non-empty at the
-- type level (@(Last dims :> 0) ~ True@).
scanl1 :: (ShapeOf dims ~ (sh :. Int), (Last dims :> 0) ~ True,
           Shape sh, Elt e)
       => (Exp e -> Exp e -> Exp e)
       -> Tensor dims e -> Tensor dims e
scanl1 f t = liftT (A.scanl1 f) t
-- | Variant of 'scanl' built on 'A.scanl''. It returns the scan (same dims as
-- the input) paired with the per-row final reductions (dims @'Init' dims@).
--
-- The accelerate result is a single 'Acc' of a pair. 'A.unlift' splits it
-- into a pair of 'Acc's so each half can be wrapped as its own 'Tensor'.
scanl' :: (sh ~ ShapeOf (Init dims), ShapeOf dims ~ (sh :. Int),
           Shape sh, Elt e)
       => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e
       -> (Tensor dims e, Tensor (Init dims) e)
scanl' f x (Tensor t) =
  let (scanned, folded) = A.unlift (A.scanl' f x t)
  in  (Tensor scanned, Tensor folded)
-- | Right-to-left scan along the innermost dimension with a seed value
-- ('A.scanr').
--
-- The type records that the innermost dimension of the result is one longer
-- than the input's (@Last dims + 1@).
scanr :: (dims' ~ (Init dims :++ '[Last dims + 1]),
          ShapeOf dims ~ (sh :. Int), ShapeOf dims' ~ (sh :. Int),
          Shape sh, Elt e)
      => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e -> Tensor dims' e
scanr f x t = liftT (A.scanr f x) t
-- | Seedless right-to-left scan along the innermost dimension ('A.scanr1').
--
-- The dims are unchanged. The innermost dimension must be non-empty at the
-- type level (@(Last dims :> 0) ~ True@).
scanr1 :: (ShapeOf dims ~ (sh :. Int), (Last dims :> 0) ~ True,
           Shape sh, Elt e)
       => (Exp e -> Exp e -> Exp e)
       -> Tensor dims e -> Tensor dims e
scanr1 f t = liftT (A.scanr1 f) t
-- | Variant of 'scanr' built on 'A.scanr''. It returns the scan (same dims as
-- the input) paired with the per-row final reductions (dims @'Init' dims@).
--
-- 'A.unlift' splits the single 'Acc' of a pair into a pair of 'Acc's so each
-- half can be wrapped as its own 'Tensor'.
scanr' :: (sh ~ ShapeOf (Init dims), ShapeOf dims ~ (sh :. Int),
           Shape sh, Elt e)
       => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e
       -> (Tensor dims e, Tensor (Init dims) e)
scanr' f x (Tensor t) =
  let (scanned, folded) = A.unlift (A.scanr' f x t)
  in  (Tensor scanned, Tensor folded)
-- | Left-to-right pre-scan (exclusive scan) with a seed value along the
-- innermost dimension ('A.prescanl'). The dims are unchanged.
prescanl :: (ShapeOf dims ~ (sh :. Int), Shape sh, Elt e)
         => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e
         -> Tensor dims e
prescanl f x t = liftT (A.prescanl f x) t
-- | Left-to-right post-scan with a seed value along the innermost dimension
-- ('A.postscanl'). The dims are unchanged.
postscanl :: (ShapeOf dims ~ (sh :. Int), Shape sh, Elt e)
          => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e
          -> Tensor dims e
postscanl f x t = liftT (A.postscanl f x) t
-- | Right-to-left pre-scan (exclusive scan) with a seed value along the
-- innermost dimension ('A.prescanr'). The dims are unchanged.
prescanr :: (ShapeOf dims ~ (sh :. Int), Shape sh, Elt e)
         => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e
         -> Tensor dims e
prescanr f x t = liftT (A.prescanr f x) t
-- | Right-to-left post-scan with a seed value along the innermost dimension
-- ('A.postscanr'). The dims are unchanged.
postscanr :: (ShapeOf dims ~ (sh :. Int), Shape sh, Elt e)
          => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e
          -> Tensor dims e
postscanr f x t = liftT (A.postscanr f x) t
-- TODO scans with segments
-- TODO stencils
-- | Number of dimensions of a tensor, read from the length of its type-level
-- dims list. The tensor argument itself is never inspected.
rank :: forall dims e. SingI (Length dims) => Tensor dims e -> Int
rank _ = fromInteger dimCount
  where
    -- Demote the type-level length to a term-level value.
    dimCount = fromSing (sing :: SNat (Length dims))
-- XXX accelerate expects you to use runN, so are these really a good idea?
-- We could instead provide some kind of runN-like function that takes a
-- function from Tensors to a Tensor, together with plain (non-Acc) arrays as
-- arguments, and returns an array.
-- | Build a tensor by applying a host-side function to every index.
--
-- The extent comes from the type-level @dims@ via 'shFromDims'. The array is
-- built with 'A.fromFunction' and embedded into the computation with 'use'.
fromFunction :: forall dims sh e.
                (sh ~ ShapeOf dims, KnownShape dims, Shape sh, Elt e)
             => (sh -> e) -> Tensor dims e
fromFunction gen = Tensor (use (A.fromFunction extent gen))
  where
    extent = shFromDims (Proxy :: Proxy dims)
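-- TODO: A.fromFunctionM only exists in accelerate >= 1.2. To enable the sketch
-- below, its result type must become IO (Tensor dims e), the array must be
-- wrapped with `use` (as in fromFunction), and `dims` must be bound with
-- forall so the Proxy annotation can refer to it.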
-- fromFunctionM :: (sh ~ ShapeOf dims)
-- => (sh -> IO e) -> Tensor dims e
-- fromFunctionM f = do
-- array <- A.fromFunctionM (shFromDims (Proxy :: Proxy dims)) f
-- pure $ Tensor array
-- | Build a tensor from a list of elements passed to accelerate's 'fromList'.
-- The extent is taken from the type-level @dims@ via 'shFromDims'.
--
-- NOTE(review): presumably "unsafe" because the list length is never checked
-- against @dims@ here, so a list that is too short is left for 'fromList' to
-- deal with at runtime. Confirm.
unsafeFromList :: forall dims e. (KnownShape dims, ShapeLike dims, Elt e)
               => [e] -> Tensor dims e
unsafeFromList xs = Tensor (use (fromList extent xs))
  where
    extent = shFromDims (Proxy :: Proxy dims)
-- examples
-- | Example 3×2 tensor filled from @[1 ..]@ by 'fromList'.
ta :: Tensor [3, 2] Double
ta = Tensor (use (fromList extent [1 ..]))
  where
    extent = Z :. 3 :. 2 :: DIM2
-- | Example 3×2 tensor filled from @[2 ..]@ by 'fromList'.
tb :: Tensor [3, 2] Double
tb = Tensor (use (fromList extent [2 ..]))
  where
    extent = Z :. 3 :. 2 :: DIM2
-- | Example 3×3 tensor filled from @[1 ..]@ by 'fromList'.
tc :: Tensor [3, 3] Double
tc = Tensor (use (fromList extent [1 ..]))
  where
    extent = Z :. 3 :. 3 :: DIM2
-- | Example 2×3 tensor filled from @[2 ..]@ by 'fromList'.
td :: Tensor [2, 3] Double
td = Tensor (use (fromList extent [2 ..]))
  where
    extent = Z :. 2 :. 3 :: DIM2
-- | Example 3×5 tensor built by 'enumFromN' with start value 2 (presumably
-- consecutive values counting up from 2; confirm against 'enumFromN').
te :: Tensor [3, 5] Double
te = enumFromN start
  where
    start = constant 2
-- | Example: element-wise sum of 'ta' and 'tb'.
zipped :: Tensor [3, 2] Double
zipped = zipWith (\x y -> x + y) ta tb
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment