Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
-- | A pointwise pairing of two type-level lists.
--
-- A @ProdMap f xs ys@ holds one @f x y@ value for every aligned pair of
-- elements @x@ (from @xs@) and @y@ (from @ys@). The only constructors are
-- "both empty" and "both cons", so the two lists are forced to have the
-- same length.
data ProdMap (f :: a -> b -> Type) (xs :: [a]) (ys :: [b]) where
  -- | Both lists are empty.
  PMZ :: ProdMap f '[] '[]
  -- | A witness relating the two heads, followed by a map over the tails.
  PMS :: f x y -> ProdMap f xs ys -> ProdMap f (x ': xs) (y ': ys)
-- | A per-dimension slice specification.
--
-- A @Slice n c@ splits a dimension of size @n@ into three singleton-sized
-- parts whose sizes sum to @n@. The resulting dimension has the size of the
-- middle part, @c@. (Presumably the outer two parts are the entries dropped
-- before and after the kept range; confirm against the definition of 'slice'.)
data Slice (n :: Nat) (c :: Nat) where
  Slice :: Sing lhs -> Sing mid -> Sing rhs -> Slice (lhs + mid + rhs) mid
-- | Slice a tensor along all of its dimensions at once.
--
-- The 'ProdMap' supplies one 'Slice' per dimension. Each 'Slice' relates an
-- input dimension size (an element of @ns@) to an output size (the matching
-- element of @ms@), and the result is indexed by @ms@.
-- (Assumes the index of 'Tensor' is its shape; TODO confirm.)
--
-- NOTE(review): no definition follows this signature in this excerpt, and
-- GHC rejects a type signature without an accompanying binding. The body is
-- presumably elided here; confirm.
slice
  :: (SingI ns, SingI ms)
  => ProdMap Slice ns ms -- ^ per-dimension slice specification
  -> Tensor ns           -- ^ input tensor
  -> Tensor ms           -- ^ sliced tensor
-- | Given a type-level list @ns@ giving how many items to take from each
-- dimension, build the @ProdMap Slice ms ns@ that encodes "take the first
-- @n@ items of every dimension". Every 'Slice' starts at offset 0.
--
-- NOTE(review): this is an unfinished sketch. @meh@ is not bound anywhere in
-- this excerpt. Also, @ms@ is chosen by the caller and no term-level witness
-- of it is passed in, so the remainder sizes cannot be computed and this does
-- not typecheck as written. See the note at the end of the definition.
sliceHeads :: Sing ns -> ProdMap Slice ms ns
sliceHeads = \case
  SNil -> PMZ
  s@SNat `SCons` ss -> Slice (SNat @0) s meh `PMS` sliceHeads ss
  -- meh has to be :: Sing (m - n), where m is the matching element of ms,
  -- and it must be non-negative (n <= m). So in real life we'd have to
  -- iterate over ms (e.g. take a Sing ms) as well.
-- | Keep only the leading entries of each dimension of a tensor, where the
-- per-dimension counts are known only at runtime.
--
-- The counts are promoted to a singleton and turned into a slice
-- specification with 'sliceHeads'. The shape of the result is not known
-- statically, so the sliced tensor is handed to a continuation that works
-- for any shape with a 'SingI' instance.
--
-- NOTE(review): 'withSomeSing' on an @[Integer]@ assumes the singletons
-- version in use demotes 'Nat' to 'Integer'. Newer releases demote it to
-- 'Natural'; confirm.
headsFromList
  :: SingI ms
  => [Integer]                                -- ^ how many leading entries to keep per dimension
  -> Tensor ms                                -- ^ tensor to slice
  -> (forall ns. SingI ns => Tensor ns -> r)  -- ^ continuation receiving the sliced tensor
  -> r
headsFromList dims tensor k =
  withSomeSing dims (\sDims ->
    -- 'slice' needs SingI for the new shape, which only exists inside
    -- withSingI, so the slicing must happen in here.
    withSingI sDims (k (slice (sliceHeads sDims) tensor)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment