-
-
Save mstksg/8a923bfb7a7c67d3316275950c34aada to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE DeriveFoldable #-} | |
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE GeneralizedNewtypeDeriving #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE MultiParamTypeClasses #-} | |
{-# LANGUAGE NoImplicitPrelude #-} | |
{-# LANGUAGE OverloadedLists #-} | |
{-# LANGUAGE OverloadedStrings #-} | |
{-# LANGUAGE PolyKinds #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE TypeApplications #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE TypeInType #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
import Data.Distributive | |
import Data.Functor.Rep | |
import Data.List ((!!)) | |
import Data.Singletons | |
import Data.Singletons.Prelude | |
import Data.Singletons.TH (promote) | |
import Data.Singletons.Prelude.List | |
import Data.Singletons.TypeLits | |
import GHC.Exts | |
import GHC.Show | |
import GHC.TypeLits | |
import Protolude hiding (show, (<.>)) | |
import qualified Data.Vector as V | |
import Data.Foldable | |
import Data.Kind | |
import Data.Maybe | |
-- | A flat 'V.Vector' of elements tagged with a type-level list of
-- naturals @r@.
--
-- @r@ is a phantom index: it does not appear in the stored value, so nothing
-- in this declaration ties the length of the vector to @r@.
-- NOTE(review): @r@ presumably lists the size of each dimension — confirm.
newtype Tensor (r :: [Nat]) a = Tensor
  { v :: V.Vector a  -- ^ underlying element storage
  }
  deriving (Eq, Functor, Foldable)
-- | A pairwise relation between two type-level lists of equal length.
--
-- A value of @ProdMap f xs ys@ holds one @f x y@ for every pair of elements
-- at the same position in @xs@ and @ys@.
data ProdMap (f :: a -> b -> Type) :: [a] -> [b] -> Type where
  -- | Both lists are empty.
  PMZ :: ProdMap f '[] '[]
  -- | Relate the two heads, then carry on with the two tails.
  PMS :: f x y -> ProdMap f xs ys -> ProdMap f (x ': xs) (y ': ys)
-- | Relates a dimension of size @l + c + r@ to a dimension of size @c@.
--
-- The three singletons carry the run-time values of @l@, @c@ and @r@.
-- NOTE(review): presumably @l@ items are skipped on the left, @c@ are kept
-- and @r@ are dropped on the right — confirm against the implementation of
-- @slice@.
data Slice (n :: Nat) (m :: Nat) where
  Slice :: forall l c r. Sing l -> Sing c -> Sing r -> Slice ((l + c) + r) c
-- | Turn a tensor indexed by @ns@ into one indexed by @ms@, given one 'Slice'
-- per dimension.
--
-- NOTE: this is a placeholder; it is currently bound to 'undefined' and
-- fails as soon as it is evaluated.
slice :: (SingI ns, SingI ms) => ProdMap Slice ns ms -> Tensor ns a -> Tensor ms a
slice = undefined
-- | Evidence that @n <= m@, packaged with the singletons for both numbers.
--
-- Pattern matching on the constructor brings the @n <= m@ constraint into
-- scope.
data IsLTE (n :: Nat) (m :: Nat) where
  IsLTE :: forall n m. (n <= m) => Sing n -> Sing m -> IsLTE n m
-- | Given, for every dimension, the number of items to take (@ns@) together
-- with a proof that it does not exceed the size of that dimension (@ms@),
-- return the @ProdMap Slice ms ns@ that encodes taking those leading items.
--
-- Each @IsLTE n m@ entry becomes the 'Slice' built from @0@, @n@ and
-- @m - n@.
--
-- NOTE(review): the entry built here has type @Slice (0 + n + (m - n)) n@,
-- while the signature asks for @Slice m n@.  GHC will not prove that equality
-- on its own; presumably a type-checker plugin or an explicit proof is still
-- needed — confirm before relying on this compiling as-is.
sliceHeads :: forall ms ns. ProdMap IsLTE ns ms -> ProdMap Slice ms ns
sliceHeads = \case
  PMZ -> PMZ
  IsLTE sN sM `PMS` ltes ->
    -- Skip nothing, keep @n@ items, drop the remaining @m - n@.
    Slice (SNat @0) sN (sM %:- sN) `PMS` sliceHeads ltes
-- | A pair of singletons, one for each of the two indices.
--
-- It has the two-argument shape that 'ProdMap' expects, so it can pair up
-- the elements of two type-level lists without asserting anything about
-- them.
data CurrySing (a :: k1) (b :: k2) where
  CurrySing :: Sing x -> Sing y -> CurrySing x y
-- | Try to upgrade a list of singleton pairs into a list of @n <= m@ proofs.
--
-- Returns 'Nothing' as soon as one pair fails the comparison, and
-- @Just@ the full list of 'IsLTE' witnesses otherwise.
mkLTE :: ProdMap CurrySing ns ms -> Maybe (ProdMap IsLTE ns ms)
mkLTE = \case
  PMZ -> Just PMZ
  -- Matching on 'SNat' brings the 'KnownNat' instances into scope, which
  -- the comparison below requires.
  CurrySing (sN@SNat :: Sing n) (sM@SNat :: Sing m) `PMS` pms ->
    -- %<=? from http://hackage.haskell.org/package/typelits-witnesses-0.2.3.0/docs/GHC-TypeLits-Compare.html
    case Proxy @n %<=? Proxy @m of
      -- The 'Refl' provides @(n <=? m) ~ 'True@, i.e. @n <= m@, so the
      -- 'IsLTE' constructor type-checks; cons it onto the rest.
      LE Refl -> (IsLTE sN sM `PMS`) <$> mkLTE pms
      _       -> Nothing
-- | Pair up two singleton lists element by element.
--
-- Yields 'Nothing' when one list runs out before the other, and otherwise a
-- 'ProdMap' holding a 'CurrySing' for every position.
zipSings :: Sing as -> Sing bs -> Maybe (ProdMap CurrySing as bs)
zipSings SNil             SNil             = Just PMZ
zipSings SNil             (_ `SCons` _)    = Nothing
zipSings (_ `SCons` _)    SNil             = Nothing
zipSings (x `SCons` xs)   (y `SCons` ys)   =
  PMS (CurrySing x y) <$> zipSings xs ys
-- | Slice a tensor using a run-time list of per-dimension counts, and hand
-- the result to a continuation that works for any resulting index @ns@.
--
-- The list is promoted to a singleton and paired with the tensor's own index
-- @ms@.  Each pair is checked for @n <= m@, and the resulting proofs drive
-- 'sliceHeads' and 'slice'.
--
-- Calls 'error' when the list and @ms@ differ in length ("dimensions don't
-- line up") or when some entry fails the @n <= m@ check ("dimensions out of
-- range").
--
-- The explicit @forall@ is required: with ScopedTypeVariables it is what
-- brings @ms@ into scope for the type application in the body.
headsFromList
  :: forall ms a r. SingI ms
  => [Integer]                                  -- ^ requested count per dimension
  -> Tensor ms a                                -- ^ tensor to slice
  -> (forall ns. SingI ns => Tensor ns a -> r)  -- ^ consumer of the sliced tensor
  -> r
headsFromList ns t f = withSomeSing ns $ \nsSing ->
  withSingI nsSing $
    case zipSings nsSing (sing @_ @ms) of
      Nothing   -> error "dimensions don't line up"
      Just nsms -> case mkLTE nsms of
        Nothing  -> error "dimensions out of range"
        Just lte -> f (slice (sliceHeads lte) t)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment