Last active
October 2, 2016 23:55
-
-
Save oliver-batchelor/568665c16506cbe2068123350cee9904 to your computer and use it in GitHub Desktop.
Attempts at static tensor dimensioning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TemplateHaskell, FlexibleContexts, FlexibleInstances, GADTs, DataKinds, | |
TypeInType, KindSignatures, InstanceSigs, TypeOperators, | |
ConstraintKinds, RankNTypes, ScopedTypeVariables, TypeFamilies, | |
UndecidableInstances, MultiParamTypeClasses, TypeApplications, PartialTypeSignatures #-} | |
--{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise -fplugin GHC.TypeLits.KnownNat.Solver #-} | |
-- Four attempts at implementing a 'concat' operation for tensors with an arbitrary number of dimensions.
-- 'concat dim xs ys' is valid only if tensors xs and ys have the same shape
-- (except along the dimension being joined).
-- Fourth attempt (most promising): use type-level Peano numbers for the dimension -- but now there are two kinds of natural number around?!
-- First attempt uses a type function (implemented with singletons and wrapped up with a custom type error),
-- which seems to suffer from type inference problems: e.g. if any of the inputs are not completely known,
-- GHC cannot even see that the output tensor has the same number of dimensions.
-- Second attempt: write a separate function with explicit constraints for each dimension (e.g. concat0, concat1),
-- pattern matching on the shape directly. Seems much better, but not as general.
-- Third attempt: try to get the best of both worlds with a type class indexed by the concat dimension,
-- with explicit constraints in each instance.
-- It does not seem much better than the second method.
-- My plan is to create a neural network library (done right) and ship the heavy lifting off to external GPU libraries,
-- e.g. NVIDIA's cuDNN.
-- My idea so far is to use observable sharing to create an expression tree; I've been butchering
-- the data-treify examples, which seem able to do most of what I'd want.
-- | |
-- Important concepts:
--  * Parameters and data kept separate from the network structure,
--    so that it is not tied to one particular implementation (and for serialization).
--  * Run on a well-tested GPU implementation in batch mode, e.g. cuDNN (possibly using accelerate for general array slicing and dicing).
--  * Annotations for nice graphviz diagrams for academic papers would be nice, too.
-- So I'd love it if anyone has an improvement here, or any ideas!
module HDNN.Type.Shape where | |
import Data.Singletons.TH | |
import Data.Singletons.Prelude | |
import Data.Singletons.Prelude.Maybe | |
import Prelude hiding ( take ) | |
import Data.Kind (Type) | |
import GHC.TypeLits | |
import GHC.TypeLits.List | |
import Data.Proxy | |
-- | A tensor whose shape (@dims@) and element type (@a@) live only at the
-- type level. There is no runtime payload yet; the type exists so that
-- shape inference can be worked out first.
data Tensor (dims :: [Nat]) (a :: Type)
  = Tensor

-- | Rank-1 tensor of length @n@.
type Vector (n :: Nat) a = Tensor '[n] a

-- | Rank-2 tensor with @m@ rows and @n@ columns.
type Matrix (m :: Nat) (n :: Nat) a = Tensor '[m, n] a

-- | Rank-3 tensor, shape @[h, w, c]@.
type Image (h :: Nat) (w :: Nat) (c :: Nat) a = Tensor '[h, w, c] a

-- | Every tensor renders as the fixed string @Tensor@, independent of shape.
instance Show (Tensor dims a) where
  showsPrec _ _ = showString "Tensor"
-- Attempt 4: index the concat dimension by a type-level Peano number.
-- This is the cleanest version so far, but it introduces a second natural
-- number kind (N) alongside GHC.TypeLits' Nat. Can N be replaced by Nat?
-- There does not seem to be an easy way to do the pattern match on Nat.
--
-- Ord is derived (in addition to Eq) so that singletons generates the
-- promoted comparison families for N; the constraint @(dim :< D2) ~ 'True@
-- in concatMat only reduces with them.
$(singletons [d|
  data N = Z | S N
    deriving (Eq, Ord)
  |])
-- | Shapes @xs@ and @ys@ can be concatenated along dimension @d@
-- (a Peano index). The instance contexts carry the shape checks.
class Concats (d :: N) (xs :: [Nat]) (ys :: [Nat]) where
  -- | Shape of the result of concatenating @xs@ and @ys@ along @d@.
  type ConcatShape d xs ys :: [Nat]

-- | Joining along the leading dimension: the remaining dimensions must
-- agree, and the leading sizes add.
instance (xs ~ ys) => Concats 'Z (x ': xs) (y ': ys) where
  type ConcatShape 'Z (x ': xs) (y ': ys) = (x + y) ': xs

-- | Joining along a deeper dimension: the leading sizes must agree, and the
-- rest of the shapes must themselves be concatenable one dimension further in.
-- The 'Concats n xs ys' constraint is what makes the check recursive.
instance (x ~ y, Concats n xs ys) => Concats ('S n) (x ': xs) (y ': ys) where
  type ConcatShape ('S n) (x ': xs) (y ': ys) = x ': ConcatShape n xs ys
-- | Type-level names for the first few dimension indices.
type D0 = 'Z
type D1 = 'S D0
type D2 = 'S D1
type D3 = 'S D2
type D4 = 'S D3

-- | A dimension index passed at the value level as a singleton.
type Dim (d :: N) = Sing d

-- | Value-level singletons for the dimension indices D0 .. D4, each built
-- from the previous one.
dim0 :: Dim D0
dim0 = SZ

dim1 :: Dim D1
dim1 = SS dim0

dim2 :: Dim D2
dim2 = SS dim1

dim3 :: Dim D3
dim3 = SS dim2

dim4 :: Dim D4
dim4 = SS dim3
-- | Concatenate two tensors along dimension @d@ (attempt 4).
--
-- The 'Concats' constraint is what checks that the two shapes agree on
-- every dimension except @d@. The associated 'ConcatShape' family only
-- computes the result shape and never compares the remaining dimensions.
-- Without the constraint, the instance contexts would never be consulted.
-- The explicit forall keeps the original type-variable order (d, xs, a, ys)
-- for TypeApplications users.
--
-- NOTE(review): this name clashes with 'Prelude.concat' (the module only
-- hides 'take' from Prelude), so unqualified uses below are ambiguous
-- unless 'concat' is added to the @Prelude hiding@ list.
concat
  :: forall d xs a ys. Concats d xs ys
  => Dim d -> Tensor xs a -> Tensor ys a -> Tensor (ConcatShape d xs ys) a
concat _ Tensor Tensor = Tensor

-- | Concatenate two matrices along dimension 0 or 1 (enforced by
-- @dim :< D2@). The 'Concats' constraint is required here because @dim@ is
-- not known until the call site.
concatMat
  :: forall dim m n a p q.
     ((dim :< D2) ~ 'True, Concats dim '[m, n] '[p, q])
  => Dim dim -> Matrix m n a -> Matrix p q a
  -> Tensor (ConcatShape dim '[m, n] '[p, q]) a
concatMat = concat

-- | Concatenate two vectors; their lengths add.
concatVec :: Vector n a -> Vector m a -> Vector (n + m) a
concatVec = concat dim0

-- Inference demos for attempt 4.

-- The shapes are unknown here, so the 'Concats' constraint cannot be
-- discharged; the extra-constraints wildcard lets GHC infer it.
baz :: _ => _
baz m1 m2 = concat dim1 m1 m2

baz2 :: Matrix m m Float -> _ -> _
baz2 m1 m2 = concatMat dim1 m1 m2

baz3 :: Matrix 4 4 Float -> _
baz3 m1 m2 = concatMat dim0 (baz2 m1 m2) x
  where
    (x :: _) = baz2 m1 m1
-- Attempt 1: compute the shape of @concat dim x y@ with a type function.
--
-- 'maybeConcat' is written at the value level and promoted by singletons,
-- which also provides the type family 'MaybeConcat'. It returns 'Nothing'
-- whenever the shapes cannot be joined along the requested dimension.
$(singletons [d|
  fmapMaybe :: (a -> b) -> Maybe a -> Maybe b
  fmapMaybe _ Nothing  = Nothing
  fmapMaybe f (Just a) = Just (f a)

  maybeConcat :: (Num i, Eq i) => i -> [i] -> [i] -> Maybe [i]
  maybeConcat d (x:xs) (y:ys)
    | d == 0 && xs == ys = Just (x + y : xs)
    -- Joining at this dimension but the remaining dimensions differ:
    -- stop here instead of recursing with a negative index.
    | d == 0             = Nothing
    | x == y             = (x:) `fmapMaybe` maybeConcat (d - 1) xs ys
    | otherwise          = Nothing
  -- Ran out of dimensions before reaching d, or the shapes have different
  -- lengths. Without these equations the function is partial and the
  -- promoted family gets stuck instead of producing Nothing.
  maybeConcat _ []    _  = Nothing
  maybeConcat _ (_:_) [] = Nothing
  |])
-- | Result shape of concatenating @ds@ and @ds'@ along dimension @d@.
-- Defers to the promoted 'maybeConcat'; when that yields 'Nothing', the
-- family evaluates to a custom 'TypeError' naming the offending shapes.
--
-- NOTE(review): this shares its name with the associated type of the
-- attempt-4 class 'Concats'; only one of them can exist in the module at a
-- time.
type family ConcatShape d ds ds' where
  ConcatShape d ds ds' =
    FromMaybe
      (TypeError
        (      Text "Cannot concat dimension " :<>: ShowType d
          :<>: Text " for ("                   :<>: ShowType ds
          :<>: Text ", "                       :<>: ShowType ds'
          :<>: Text ")"))
      (MaybeConcat d ds ds')
-- | Attempt-1 concatenation: the result shape comes from the 'ConcatShape'
-- type function.
--
-- This is problematic for inference: the type function does not tell the
-- compiler properties such as "both inputs have the same number of
-- dimensions", so partially known shapes leave it stuck.
cat :: KnownNat dim => Proxy dim -> Tensor ds a -> Tensor ds' a -> Tensor (ConcatShape dim ds ds') a
cat _dim lhs rhs =
  case lhs of
    Tensor -> case rhs of
      Tensor -> Tensor
-- Does not reduce the type function, since GHC reports:
-- Couldn't match type ‘Data.Singletons.Prelude.Maybe.Case_1628089981 | |
-- (TypeError ...) | |
-- (Case_1627432150 | |
-- 0 | |
-- d0 | |
-- .... | |
--cat0 :: ((ds :: [Nat]) ~ (ds' :: [Nat])) => Tensor (d0 ': ds) a -> Tensor (d0' : ds') a -> Tensor (d0 + d0' : ds) a | |
--cat0 = cat (Proxy @ 0) | |
-- OK: GHC reports the wildcard as standing for ‘Tensor '[n, m + p] Float’.
foo :: Matrix n m Float -> Matrix n p Float -> _
foo left right = cat (Proxy :: Proxy 1) left right
-- Cannot concat dimension 0 for (ds, ds') | |
-- (because ConcatShape does not even imply that ds and ds' have the same number of dimensions)
--bar m1 m2 = cat (Proxy :: Proxy 0) m1 m2 | |
-- Cannot concat dimension 0 for ('[n, p + p], ds') | |
-- (Cannot infer that m2 is a Matrix) | |
--bar1 m1 m2 = cat (Proxy :: Proxy 0) (foo m1 m1) m2 | |
-- Attempt 2: the most direct approach, with one function per dimension and
-- the shape constraints written in each signature. Inference is much
-- better, at the cost of flexibility.

-- | Concatenate along dimension 0: the trailing dimensions must match and
-- the leading sizes add.
concat0
  :: (ds ~ ds')
  => Tensor (d0 ': ds) a
  -> Tensor (d0' ': ds') a
  -> Tensor (d0 + d0' ': ds) a
concat0 t u = case t of
  Tensor -> case u of
    Tensor -> Tensor

-- | Concatenate along dimension 1: dimension 0 and every dimension after 1
-- must match, and the sizes of dimension 1 add.
concat1
  :: (d0 ~ d0', ds ~ ds')
  => Tensor (d0 ': d1 ': ds) a
  -> Tensor (d0' ': d1' ': ds') a
  -> Tensor (d0 ': d1 + d1' ': ds) a
concat1 t u = case t of
  Tensor -> case u of
    Tensor -> Tensor
-- OK: GHC reported the wildcard as standing for
--   Matrix d0' p Float -> Tensor '[d0', d1'] Float -> Tensor '[d0', (p + p) + d1'] Float
-- NOTE(review): that recorded type looks like it came from
-- @concat1 (concat1 m m) rest@ rather than the concat0 expression below;
-- re-check the reported wildcard.
bar3 :: _
bar3 m rest = concat0 doubled rest
  where
    doubled = concat1 m m
-- Attempt 3: help the compiler by breaking concatenation down per
-- dimension, using a class indexed by the dimension with one instance per
-- dimension, each stating its shape constraints explicitly.

-- | Shapes @xs@ and @ys@ can be concatenated along dimension @d@.
class ConcatShape2 d (xs :: [Nat]) (ys :: [Nat]) where
  -- | Shape of the concatenated tensor.
  type ConcatDim d xs ys :: [Nat]

-- | Dimension 0: the trailing dimensions must agree, and the leading sizes
-- add. The context must mention the tails actually bound in the head
-- (@ds@, @ds'@); otherwise the equality constrains unrelated variables and
-- the check never happens.
instance (ds ~ ds') => ConcatShape2 0 (d ': ds) (d' ': ds') where
  type ConcatDim 0 (d ': ds) (d' ': ds') = (d + d') ': ds

-- | Dimension 1: dimension 0 and everything after dimension 1 must agree,
-- and the sizes of dimension 1 add.
instance (d0 ~ d0', ds ~ ds') => ConcatShape2 1 (d0 ': d1 ': ds) (d0' ': d1' ': ds') where
  type ConcatDim 1 (d0 ': d1 ': ds) (d0' ': d1' ': ds') = d0 ': (d1 + d1') ': ds

-- | Concatenate two tensors along dimension @d@; the shape checks come from
-- the matching 'ConcatShape2' instance.
catClass :: ConcatShape2 d ds ds' => Proxy d -> Tensor ds a -> Tensor ds' a -> Tensor (ConcatDim d ds ds') a
catClass _ Tensor Tensor = Tensor
-- Concatenate two matrices with the same row count along dimension 1.
baz :: Matrix n m Float -> Matrix n p Float -> _
baz a b = catClass (Proxy :: Proxy 1) a b

-- baz2 and baz3 typecheck without signatures, but asking GHC for their
-- type with a wildcard signature fails. For example, writing
--   baz2 :: _
-- gives:
--   Ambiguous type variables ‘ds0’, ‘ds'0’ arising from a use of ‘catClass’
--   prevents the constraint ‘(ConcatShape2 0 ds0 ds'0)’ from being solved.
baz2 a b = catClass (Proxy :: Proxy 0) a b

baz3 a b = catClass (Proxy :: Proxy 0) (baz b a) b
-- Specialise the class-based concatenation to dimension 0. The explicit
-- signature states the shape equality up front, so GHC does not have to
-- infer the shapes from the class constraint alone.
catClass0
  :: (ds ~ ds')
  => Tensor (d0 ': ds) a
  -> Tensor (d0' ': ds') a
  -> Tensor (d0 + d0' ': ds) a
catClass0 xs ys = catClass (Proxy :: Proxy 0) xs ys
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment