Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
A half-working dependently typed tensor library based on the linear package
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Foo where
import Data.Constraint
import Data.Foldable
import Data.Kind
import Data.Singletons
import Data.Singletons.Prelude (Fmap, Product)
import Data.Singletons.TypeLits ()
import qualified Data.Vector as Vector
import GHC.TypeLits (type (+), KnownNat, Nat, SomeNat(SomeNat), natVal, someNatVal)
import Linear
import Linear.V
import Unsafe.Coerce (unsafeCoerce)
----------------- Fin Stuff ---------------------------
-- | An index intended to lie in @[0, reflectDim bound)@. The raw constructor
-- performs no range check; build values through 'fin' / 'finUnsafe' instead.
newtype Fin (bound :: k) = Fin
  { unFin :: Int
  }
  deriving stock (Show)
-- | Checked constructor for 'Fin'.
--
-- Returns @Just (Fin i)@ exactly when @0 <= i < reflectDim n@ and 'Nothing'
-- otherwise. The lower bound matters: a negative 'Fin' would later be passed
-- to 'Vector.!' by 'indexVec' and crash there.
fin :: forall k (n :: k). Dim n => Int -> Maybe (Fin n)
fin i
  | i >= 0 && i < reflectDim (Proxy @n) = Just (Fin i)
  | otherwise = Nothing
-- | Like 'fin', but calls 'error' (message @"finUnsafe"@) when the index is
-- rejected.
finUnsafe :: forall k (n :: k). Dim n => Int -> Fin n
finUnsafe = maybe (error "finUnsafe") id . fin
---------------- HList stuff --------------------------
-- | A heterogeneous list whose type index records the type of every element,
-- in order.
data HList (ts :: [k]) where
  EmptyHList :: HList '[]
  ConsHList :: t -> HList rest -> HList (t ': rest)
---------------- Vector Stuff --------------------------
-- | A length-@n@ vector in which every element is the given value.
replicateVec :: forall n a. Dim n => a -> V n a
replicateVec x = V (Vector.replicate size x)
  where
    size = reflectDim (Proxy :: Proxy n)
-- | Turn a runtime 'Int' into a type-level 'Nat' and pass it, with its 'Dim'
-- evidence, to the continuation.
--
-- 'someNatVal' returns 'Nothing' only for negative input, so the error
-- message names that case and includes the offending value (it used to be a
-- placeholder).
reifyDim' :: Int -> (forall k (n :: k). Dim n => Proxy n -> r) -> r
reifyDim' i f =
  case someNatVal (fromIntegral i) of
    Just (SomeNat proxy) -> f @Nat proxy
    Nothing -> error ("reifyDim': negative dimension " ++ show i)
-- | Look up the element at a 'Fin' index. The lookup uses 'Vector.!', so it
-- relies on the 'Fin' actually being in range for the vector.
indexVec :: forall k (n :: k) a. Fin n -> V n a -> a
indexVec ix (V elems) = (Vector.!) elems (unFin ix)
-- | Build a length-@n@ vector from the first @n@ elements of a list.
--
-- Returns 'Nothing' when the list has fewer than @n@ elements. Longer lists
-- are silently truncated to their first @n@ elements.
--
-- Only the needed prefix is materialised ('Vector.fromListN'). This avoids
-- O(length) work on long inputs and also makes infinite lists work; the old
-- version converted the entire list first.
fromListVec :: forall k (n :: k) a. Dim n => [a] -> Maybe (V n a)
fromListVec xs
  | Vector.length prefix == size = Just (V prefix)
  | otherwise = Nothing
  where
    size = reflectDim (Proxy @n)
    prefix = Vector.fromListN size xs
-- | Like 'fromListVec', but calls 'error' (message @"fromListVecUnsafe"@)
-- when the list is too short.
fromListVecUnsafe :: forall k (n :: k) a. Dim n => [a] -> V n a
fromListVecUnsafe xs = maybe (error "fromListVecUnsafe") id (fromListVec xs)
-- | Discard the first @m@ elements of a vector of length @m + n@. Callers
-- typically need @TypeApplications@ to fix @m@.
dropVec :: forall m n a. KnownNat m => V (m + n) a -> V n a
dropVec (V elems) = V (Vector.drop prefixLen elems)
  where
    prefixLen = fromIntegral (natVal (Proxy :: Proxy m))
-- | Tabulate a length-@n@ vector: position @i@ holds @f (Fin i)@. Each index
-- goes through 'finUnsafe' on the way to @f@.
genVec :: forall n a. Dim n => (Fin n -> a) -> V n a
genVec f = V (Vector.generate len (f . finUnsafe))
  where
    len = reflectDim (Proxy :: Proxy n)
---------------- Matrix Stuff --------------------------
-- | A tensor indexed by a type-level list of dimensions. Its elements are
-- stored in one flat 'V' whose length is the 'Product' of those dimensions.
newtype Matrix (dimensions :: [k]) (a :: Type) = Matrix
  { unMatrix :: V (Product dimensions) a
  }
  deriving stock (Show)
-- | Element-wise equality of two same-shaped matrices, decided by comparing
-- their underlying flat vectors.
eqMatrix :: Eq a => Matrix ns a -> Matrix ns a -> Bool
eqMatrix x y = unMatrix x == unMatrix y
-- | Apply a function to every element; the shape is unchanged.
fmapMatrix :: (a -> b) -> Matrix ns a -> Matrix ns b
fmapMatrix f = Matrix . fmap f . unMatrix
-- | A matrix of shape @ns@ in which every element is the given value.
--
-- 'prodDimFromDims' supplies the @Dim (Product ns)@ evidence needed to size
-- the flat storage vector.
--
-- The kind variable @k@ is now bound explicitly. Leaving it implicit inside
-- an explicit @forall@ breaks the forall-or-nothing rule, which newer GHCs
-- enforce. It is still the first type argument, so @replicateMatrix \@_ \@ns@
-- keeps working.
replicateMatrix :: forall k (ns :: [k]) a. Dims ns => a -> Matrix ns a
replicateMatrix a =
  case prodDimFromDims @k @ns of
    Sub Dict -> Matrix (replicateVec @(Product ns) a)
-- | Intended to build a 'Matrix' by evaluating @f@ at every per-axis index
-- (an 'HList' holding one 'Fin' for each dimension in @ns@).
--
-- NOTE(review): unfinished. The typed hole @_@ marks where the flat position
-- @i@ should be converted into an @HList (Fmap Fin ns)@. GHC reports typed
-- holes as errors by default, so this definition does not compile as written.
-- TODO: implement the flat-index -> multi-index conversion.
genMatrix :: forall k (ns :: [k]) a. Dims ns => (HList (Fmap Fin ns) -> a) -> Matrix ns a
genMatrix f =
  Matrix $
    V $
      Vector.generate
        -- total number of elements: the product of all dimensions
        (product (reflectDims (Proxy @ns)))
        (\i -> f _)
-- | Entailment: a 'Dims' instance for @ns@ yields a 'Dim' instance for
-- @Product ns@, the flat length of a 'Matrix'.
--
-- How it works:
--
-- * The dimensions are multiplied at runtime.
-- * That 'Int' is reified as a fresh type @m@ (with @Dim m@) via 'reifyDim'.
-- * @Dict (Dim m)@ is 'unsafeCoerce'd to @Dict (Dim (Product ns))@.
--
-- NOTE(review): nothing type-checks that the runtime product matches what the
-- 'Product' type family computes. Correctness relies on the two dictionaries
-- being interchangeable and yielding the same 'reflectDim' value — confirm.
prodDimFromDims :: forall k (ns :: [k]). Dims ns :- Dim (Product ns)
prodDimFromDims =
  Sub $
    let prod = product $ reflectDims (Proxy @ns)
    in reifyDim prod f
  where
    -- The caller fixes @x@ (here @Product ns@). The coercion hands back
    -- @m@'s dictionary under @x@'s type.
    f :: forall m x. Dim m => Proxy m -> Dict (Dim x)
    f _ = unsafeCoerce (Dict :: Dict (Dim m))
-- | Entailment: a 'Dims' instance for the list @ns@ yields a 'Dim' instance
-- for every element (@AllC Dim ns@).
--
-- Strategy:
--
-- * Reflect @ns@ to a list of 'Int's.
-- * Rebuild that list as a type-level list of 'Nat's, @m ': ms@.
-- * Recursively obtain @AllC Dim ms@ and combine it with @Dim m@ via
--   'proveAll'.
-- * 'unsafeCoerce' the resulting dictionary back to @ns@.
--
-- NOTE(review): 'reifyDimNat' is not defined in this view. From its use it
-- presumably reifies one 'Int' as a type-level 'Nat' — confirm.
-- NOTE(review): the coercions assume the 'Nat'-based dictionaries and the
-- dictionaries for the original @ns@ (whose kind may differ) report the same
-- 'reflectDim' values. Also, @k@ is bound implicitly inside an explicit
-- forall, which newer GHCs reject.
allDimFromDims :: forall (ns :: [k]). Dims ns :- AllC Dim ns
allDimFromDims =
  Sub $
    case reflectDims (Proxy @ns) of
      -- No dimensions: @AllC Dim '[]@ reduces to @()@.
      [] -> unsafeCoerce (Dict :: Dict ())
      (h:t) ->
        reifyDimNat h (f t)
  where
    -- Head dimension @h@ has been reified as the 'Nat' @m@.
    f :: forall (m :: Nat). Dim m => [Int] -> Proxy m -> Dict (AllC Dim ns)
    f t Proxy =
      reifyDimsNat t g
      where
        -- Tail reified as @ms@: recurse for @AllC Dim ms@, add @Dim m@, and
        -- coerce @AllC Dim (m ': ms)@ to @AllC Dim ns@.
        g :: forall (ms :: [Nat]). Dims ms => Proxy ms -> Dict (AllC Dim ns)
        g Proxy =
          case allDimFromDims @_ @ms of
            Sub (Dict :: Dict (AllC Dim ms)) ->
              case proveAll (Proxy @Dim) (Proxy @m) (Proxy @ms) of
                Sub Dict -> unsafeCoerce (Dict :: Dict (AllC Dim (m ': ms)))
-- | Smoke test for 'allDimFromDims': from just @Dims '[n, m, o]@, reflect the
-- middle dimension @m@.
testtest :: forall n m o. Dims '[n, m, o] => Proxy '[n, m, o] -> Int
testtest _ = reflectDim (Proxy @m) \\ allDimFromDims @_ @'[n, m, o]
-- | Unary (Peano) naturals. With DataKinds, the constructors are also used at
-- the type level as dimensions.
data MyNat = MyZero | MySucc MyNat
-- | The promoted zero reflects to dimension 0.
instance Dim MyZero where
  reflectDim = const 0
-- | A successor reflects to one more than its predecessor.
instance Dim n => Dim (MySucc n) where
  reflectDim _ = 1 + reflectDim (Proxy :: Proxy n)
-- | Promoted 'MyNat' values for one through four.
type MyOne = 'MySucc 'MyZero
type MyTwo = 'MySucc ('MySucc 'MyZero)
type MyThree = 'MySucc MyTwo
type MyFour = 'MySucc MyThree
-- | Things whose dimensions can be reflected to a runtime list of 'Int's:
-- type-level lists of 'Dim's, plus 'V' and 'Matrix' by their indices.
class Dims ns where
  reflectDims :: forall p. p ns -> [Int]

-- | A 'V' has a single dimension.
instance Dim n => Dims (V n a) where
  reflectDims _ = pure (reflectDim (Proxy :: Proxy n))

-- | A 'Matrix' reports the dimensions in its index list.
instance Dims ns => Dims (Matrix ns a) where
  reflectDims _ = reflectDims (Proxy :: Proxy ns)

instance Dims '[] where
  reflectDims = const []

instance (Dim n, Dims ns) => Dims (n ': ns) where
  reflectDims _ = [reflectDim (Proxy :: Proxy n)] ++ reflectDims (Proxy :: Proxy ns)
-- | The runtime dimension list of a matrix. Only the type is used; the
-- argument is never inspected.
dims :: forall ns a. Dims ns => Matrix ns a -> [Int]
dims = const (reflectDims (Proxy :: Proxy ns))
-- | Reify a runtime list of dimensions as a type-level list of 'Nat's and
-- pass it, with its 'Dims' evidence, to the continuation @f@.
--
-- NOTE(review): 'reifyDimNat' is not defined in this view. From @go1@'s
-- signature it presumably reifies one 'Int' as a 'KnownNat' type — confirm,
-- including how it treats negative entries.
reifyDimsNat :: forall r. [Int] -> (forall (ns :: [Nat]). Dims ns => Proxy ns -> r) -> r
reifyDimsNat [] f = f @'[] Proxy
reifyDimsNat (h:t) f =
  reifyDimNat h go1
  where
    -- The head dimension is now the type @m@; reify the tail next.
    go1 :: forall (m :: Nat). KnownNat m => Proxy m -> r
    go1 Proxy =
      reifyDimsNat t go2
      where
        -- The tail is now @ms@. 'knownNatToDim' supplies the @Dim m@ that the
        -- @Dims (m ': ms)@ instance requires.
        go2 :: forall (ms :: [Nat]). Dims ms => Proxy ms -> r
        go2 Proxy =
          case knownNatToDim @m of
            Sub Dict -> f (Proxy :: Proxy (m ': ms))
-- | Kind-polymorphic front end to 'reifyDimsNat'. The runtime dimensions are
-- reified as a type-level list of 'Nat's and given to a continuation that
-- accepts a list of any kind.
--
-- The old body duplicated 'reifyDimsNat' line for line, differing only in the
-- final @f \@Nat@. It now delegates: instantiating the continuation at 'Nat'
-- gives exactly the callback type 'reifyDimsNat' expects.
reifyDims :: forall r. [Int] -> (forall k (ns :: [k]). Dims ns => Proxy ns -> r) -> r
reifyDims ds f = reifyDimsNat ds (f @Nat)
-- | Entailment: a 'KnownNat' is a 'Dim'. The 'Dict' is solved by the
-- @KnownNat n => Dim n@ instance (presumably the one provided by
-- "Linear.V" — confirm).
knownNatToDim :: KnownNat n :- Dim n
knownNatToDim = Sub Dict
-- | @AllC c xs@ holds when @c x@ holds for every @x@ in the type-level list
-- @xs@. It unfolds to a right-nested tuple of constraints ending in @()@.
type family AllC (c :: k -> Constraint) (xs :: [k]) :: Constraint where
  AllC _ '[] = ()
  AllC c (x ': rest) = (c x, AllC c rest)
-- | Cons lemma for 'AllC': evidence for the head plus 'AllC' for the tail is
-- 'AllC' for the whole list. The proxies only fix the types; their values are
-- ignored.
proveAll
  :: forall (n :: k) (ns :: [k]) (constraint :: k -> Constraint) proxy1 proxy2 proxy3
   . proxy1 constraint
  -> proxy2 n
  -> proxy3 ns
  -> (constraint n, AllC constraint ns) :- AllC constraint (n ': ns)
proveAll = \_ _ _ -> Sub Dict
let
  # Pinned nixpkgs snapshot (nixpkgs-unstable as of 2019/05/30) so the
  # development shell is reproducible.
  nixpkgsSrc = builtins.fetchTarball {
    url = "https://github.com/NixOS/nixpkgs/archive/eccb90a2d99.tar.gz";
    sha256 = "0ffa84mp1fgmnqx2vn43q9pypm3ip9y67dkhigsj598d8k1chzzw";
  };
  nixpkgs = import nixpkgsSrc {};
  haskellPkgs = nixpkgs.haskellPackages;
  # GHC plus every package the Haskell module imports directly:
  # constraints (Data.Constraint), linear (Linear, Linear.V),
  # singletons (Data.Singletons*), vector (Data.Vector). `vector` was
  # previously available only as a transitive dependency of `linear`.
  ghcEnv = haskellPkgs.ghcWithPackages (pkgs: with pkgs; [
    constraints
    linear
    singletons
    vector
  ]);
in
nixpkgs.mkShell {
  buildInputs = [ ghcEnv ];
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment