Skip to content

Instantly share code, notes, and snippets.

@oliver-batchelor
Last active October 2, 2016 23:55
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save oliver-batchelor/568665c16506cbe2068123350cee9904 to your computer and use it in GitHub Desktop.
Save oliver-batchelor/568665c16506cbe2068123350cee9904 to your computer and use it in GitHub Desktop.
Attempts at static tensor dimensioning
{-# LANGUAGE TemplateHaskell, FlexibleContexts, FlexibleInstances, GADTs, DataKinds,
TypeInType, KindSignatures, InstanceSigs, TypeOperators,
ConstraintKinds, RankNTypes, ScopedTypeVariables, TypeFamilies,
UndecidableInstances, MultiParamTypeClasses, TypeApplications, PartialTypeSignatures #-}
--{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise -fplugin GHC.TypeLits.KnownNat.Solver #-}
-- Four attempts at implementing a 'concat' operation for tensors of arbitrary dimension.
-- concat dim xs ys is valid only if tensors xs and ys have the same shape
-- (except in the dimension being joined).
-- The fourth attempt (the most promising) uses type-level Peano numbers for the dimension,
-- but now I have two kinds of natural number (the Peano N and GHC's Nat)?!
-- The first attempt uses a type function (implemented with singletons and wrapped in a custom type error),
-- which seems to suffer from type inference problems: e.g. if any of the inputs is not fully known,
-- GHC cannot even see that the output tensor has the same dimensionality.
-- The second attempt writes explicit constraints for each dimension (e.g. concat0, concat1), pattern
-- matching on the shapes directly; this seems much better, but is not quite as general.
-- The third attempt tries to get the best of both worlds by using a type class indexed by the concat
-- dimension, with instances carrying the explicit constraints.
-- It seems no better than the second method.
-- My plan is to create a neural network library (done right) and ship the heavy lifting off to external
-- GPU libraries, e.g. NVIDIA's cuDNN.
-- My idea so far is to use observable sharing to build an expression tree; I've been adapting
-- the data-treify examples, which seem able to do most of what I need.
--
-- Important concepts:
-- * Ability to have parameters and data separate from the network structure,
-- in order that it's not tied to one particular implementation (and for serialization)
-- * Run on well tested GPU implementation in batch mode e.g. cuDNN (possibly sub in accelerate for general array slicing and dicing)
-- * Annotation for nice graph-viz diagrams for academic papers would be nice, too.
-- So I'd love it if anyone has an improvement here or any ideas!
module HDNN.Type.Shape where
import Data.Singletons.TH
import Data.Singletons.Prelude
import Data.Singletons.Prelude.Maybe
import Prelude hiding ( take )
import Data.Kind (Type)
import GHC.TypeLits
import GHC.TypeLits.List
import Data.Proxy
-- | A tensor described only at the type level: @dims@ is its list of
-- dimension sizes and @a@ its element type. It carries no runtime data yet;
-- the single nullary constructor exists so the shape-checking signatures in
-- this module can be exercised by the type checker.
data Tensor (dims :: [Nat]) (a :: Type) where
  Tensor :: Tensor dims a
-- | Rank-1 tensor of length @n@.
type Vector (n :: Nat) (a :: Type) = Tensor '[n] a

-- | Rank-2 tensor whose shape is @[m, n]@.
type Matrix (m :: Nat) (n :: Nat) (a :: Type) = Tensor '[m, n] a

-- | Rank-3 tensor whose shape is @[h, w, c]@ (presumably height, width,
-- channels, going by the names).
type Image (h :: Nat) (w :: Nat) (c :: Nat) (a :: Type) = Tensor '[h, w, c] a
-- | Placeholder 'Show' instance: a 'Tensor' carries no data, so every tensor
-- renders as the constant string @"Tensor"@.
instance Show (Tensor ds a) where
  -- Must be indented: at column 0 the layout rule closes the instance body
  -- as empty and turns 'show' into an unrelated top-level binding.
  show _ = "Tensor"
--
-- Attempt 4: this seems to be the cleanest yet - except that there are now two kinds of natural number
-- around (the Peano N and GHC's Nat), which doesn't seem ideal. Can N be replaced with Nat somehow?
-- There does not seem to be an easy way to pattern match on a Nat.
--
-- | Type-level Peano naturals used to index the concat dimension in
-- Attempt 4. The 'singletons' splice also generates a singleton type for 'N'
-- (constructors 'SZ' and 'SS', used below) so a dimension can be passed as a
-- value.
--
-- The declaration lines must be indented inside the quote: a @deriving@
-- clause at column 0 would be parsed as a new declaration.
$(singletons [d|
  data N = Z | S N
    deriving (Eq)
  |])
-- | Attempt 4: concatenation indexed by a Peano dimension @d@.
--
-- @Concats d xs ys@ holds when tensors of shapes @xs@ and @ys@ can be joined
-- along dimension @d@: every other dimension must agree, and @d@ must be
-- within the rank (there is no instance for the empty shape). The associated
-- family 'ConcatShape' computes the resulting shape.
--
-- NOTE(review): the closed type family of Attempt 1 later in this file has the
-- same name 'ConcatShape'; the two cannot coexist in one module.
class Concats (d :: N) (xs :: [Nat]) (ys :: [Nat]) where
  type ConcatShape d xs ys :: [Nat]

-- Joining along the leading dimension: the leading sizes add and the
-- remaining dimensions must be identical.
instance (xs ~ ys) => Concats 'Z (x : xs) (y : ys) where
  type ConcatShape 'Z (x : xs) (y : ys) = (x + y) : xs

-- Joining along a deeper dimension: the leading sizes must match, and the
-- tails must themselves be concatenable one dimension further in. The
-- recursive constraint makes the class check every dimension, not just
-- the first.
instance (x ~ y, Concats n xs ys) => Concats ('S n) (x : xs) (y : ys) where
  type ConcatShape ('S n) (x : xs) (y : ys) = x : ConcatShape n xs ys
-- | Value-level witnesses for the first five concat dimensions, built from
-- the singleton constructors generated for 'N'.
dim0 :: Dim D0
dim0 = SZ

dim1 :: Dim D1
dim1 = SS dim0

dim2 :: Dim D2
dim2 = SS dim1

dim3 :: Dim D3
dim3 = SS dim2

dim4 :: Dim D4
dim4 = SS dim3

-- | Type-level names for the first five dimensions, matching 'dim0' .. 'dim4'.
type D0 = 'Z
type D1 = 'S D0
type D2 = 'S D1
type D3 = 'S D2
type D4 = 'S D3

-- | A concat dimension passed as a value: the singleton of a Peano index.
type Dim (d :: N) = Sing d
-- | Attempt 4: concatenate two tensors along the dimension given by the
-- singleton @d@. The result shape is computed by the associated family
-- 'ConcatShape'.
--
-- NOTE(review): the signature does not require @Concats d xs ys@, so the
-- instance contexts (agreement of the non-joined dimensions) are never
-- checked at call sites. The shapes are constrained only insofar as
-- 'ConcatShape' must reduce.
-- NOTE(review): this name also clashes with 'Prelude.concat', which the
-- @import Prelude hiding ( take )@ line does not hide, so unqualified uses
-- of 'concat' in this module are ambiguous.
concat :: Dim d -> Tensor xs a -> Tensor ys a -> Tensor (ConcatShape d xs ys) a
concat _ Tensor Tensor = Tensor
-- | Concatenate two matrices along a dimension. The @(dim :< D2) ~ 'True@
-- constraint requires that the dimension compares below 'D2', i.e. is one of
-- the two dimensions a matrix has.
concatMat :: (dim :< D2) ~ 'True => Dim dim -> Matrix m n a -> Matrix p q a -> Tensor (ConcatShape dim [m, n] [p, q]) a
concatMat which lhs rhs = concat which lhs rhs
-- | Append two vectors: joining along dimension 0 adds their lengths.
concatVec :: Vector n a -> Vector m a -> Vector (n + m) a
concatVec front back = concat dim0 front back
-- | Inference probe: with a bare wildcard signature, GHC reports the type it
-- infers for concatenating two tensors along dimension 1.
-- NOTE(review): 'baz' is defined again in the Attempt 3 section of this file;
-- the two definitions cannot coexist in one module.
baz :: _
baz m1 m2 = concat dim1 m1 m2
-- | Inference probe: the first argument is pinned to a square 'Float' matrix;
-- the second argument and the result are left to GHC via wildcards.
-- NOTE(review): 'baz2' is defined again in the Attempt 3 section of this file.
baz2 :: Matrix m m Float -> _ -> _
baz2 m1 m2 = concatMat dim1 m1 m2
-- | Inference probe: joins two 'baz2' results along dimension 0 with
-- 'concatMat', leaving everything after the first argument to GHC.
-- NOTE(review): 'baz3' is defined again in the Attempt 3 section of this file.
baz3 :: Matrix 4 4 Float -> _
baz3 m1 m2 = concatMat dim0 (baz2 m1 m2) x
  where
    -- Must be indented under 'where': at column 0 the layout rule ends the
    -- where-clause, making this a top-level binding with m1 out of scope.
    (x :: _) = baz2 m1 m1
-- Attempt 1: Using type function to compute the shape of concat dim x y
-- Attempt 1 helpers, promoted (and singled) by the 'singletons' splice.
-- The quoted declarations must be indented; guards at column 0 do not parse.
$(singletons [d|
  -- 'fmap' specialised to 'Maybe', written out so that it can be promoted.
  fmapMaybe :: (a -> b) -> Maybe a -> Maybe b
  fmapMaybe _ Nothing = Nothing
  fmapMaybe f (Just a) = Just (f a)

  -- maybeConcat d xs ys is the shape obtained by joining shapes xs and ys
  -- along dimension d, or Nothing if they are incompatible: every dimension
  -- other than d must agree, and d must lie within both shapes.
  maybeConcat :: (Num i, Eq i) => i -> [i] -> [i] -> Maybe [i]
  maybeConcat d (x:xs) (y:ys)
    | d == 0 && xs == ys = Just (x + y : xs)
    -- Stop here instead of recursing with d - 1: at the type level that
    -- would compute 0 - 1, which does not reduce for Nat.
    | d == 0             = Nothing
    | x == y             = (x:) `fmapMaybe` maybeConcat (d - 1) xs ys
    | otherwise          = Nothing
  -- Dimension beyond the rank, or shapes of different rank. These clauses are
  -- disjoint from the one above (no overlapping wildcard), so the function
  -- can still be singled, and the promoted family yields 'Nothing, not a
  -- stuck application.
  maybeConcat _ [] _ = Nothing
  maybeConcat _ (_:_) [] = Nothing
  |])
-- | Attempt 1's result shape: the shape computed by 'MaybeConcat' (the
-- promoted 'maybeConcat'), or a custom type error when the shapes cannot be
-- joined along dimension @d@.
--
-- NOTE(review): the associated family of 'Concats' (Attempt 4) earlier in this
-- file has the same name; the two cannot coexist in one module.
type family ConcatShape d ds ds' where
  -- The equation must be indented; at column 0 it is not part of the family.
  ConcatShape d ds ds' =
    FromMaybe
      (TypeError (Text "Cannot concat dimension " :<>: ShowType d :<>: Text " for ("
                  :<>: ShowType ds :<>: Text ", " :<>: ShowType ds' :<>: Text ")"))
      (MaybeConcat d ds ds')
-- Problematic for inference because there are many properties the compiler can't deduce from
-- the type function - e.g. that the two shapes even have the same length.
-- | Attempt 1 concat: the dimension is a type-level 'Nat' carried by a
-- 'Proxy', and the result shape is whatever 'ConcatShape' computes.
-- The 'KnownNat' constraint is not used by the body.
cat :: KnownNat dim => Proxy dim -> Tensor ds a -> Tensor ds' a -> Tensor (ConcatShape dim ds ds') a
cat _ lhs rhs = case (lhs, rhs) of
  (Tensor, Tensor) -> Tensor
-- cat0 below is rejected because GHC does not reduce the type function:
-- Couldn't match type ‘Data.Singletons.Prelude.Maybe.Case_1628089981
-- (TypeError ...)
-- (Case_1627432150
-- 0
-- d0
-- ....
--cat0 :: ((ds :: [Nat]) ~ (ds' :: [Nat])) => Tensor (d0 ': ds) a -> Tensor (d0' : ds') a -> Tensor (d0 + d0' : ds) a
--cat0 = cat (Proxy @ 0)
-- OK
--Found type wildcard ‘_’ standing for ‘Tensor '[n, m + p] Float’
-- | Inference probe for 'cat' along dimension 1. Both matrices share the
-- first dimension @n@, and GHC fills in the wildcard result type.
foo :: Matrix n m Float -> Matrix n p Float -> _
foo m1 m2 = cat (Proxy @ 1) m1 m2
-- Cannot concat dimension 0 for (ds, ds')
-- (because ConcatShape does not imply ds and ds' have the same dimensions even)
--bar m1 m2 = cat (Proxy :: Proxy 0) m1 m2
-- Cannot concat dimension 0 for ('[n, p + p], ds')
-- (Cannot infer that m2 is a Matrix)
--bar1 m1 m2 = cat (Proxy :: Proxy 0) (foo m1 m1) m2
-- Attempt 2, most direct - this seems to have much better inference but less flexibility
-- | Attempt 2: concatenate along dimension 0. The leading sizes add, and the
-- equality constraint forces the remaining dimensions to agree.
concat0 :: (ds ~ ds') => Tensor (d0 ': ds) a -> Tensor (d0' : ds') a -> Tensor (d0 + d0' : ds) a
concat0 top bottom = case top of
  Tensor -> case bottom of
    Tensor -> Tensor
-- | Attempt 2: concatenate along dimension 1. The leading sizes must agree,
-- the second sizes add, and all remaining dimensions must agree.
concat1 :: (d0 ~ d0', ds ~ ds') => Tensor (d0 : d1 : ds) a -> Tensor (d0' : d1' : ds') a -> Tensor (d0 : d1 + d1' : ds) a
concat1 left right = case left of
  Tensor -> case right of
    Tensor -> Tensor
-- OK
-- Found type wildcard ‘_’
-- standing for ‘Matrix d0' p Float
-- -> Tensor '[d0', d1'] Float -> Tensor '[d0', (p + p) + d1'] Float’
-- | Inference probe for Attempt 2: joins @m1@ with itself along dimension 1,
-- then joins that result with @m2@ along dimension 0. The whole signature is
-- left to GHC.
bar3 :: _
bar3 m1 m2 = concat0 (concat1 m1 m1) m2
-- Attempt 3, try to help the compiler by breaking down into dimensions
-- | Attempt 3: concatenation indexed by a type-level dimension, with one
-- hand-written instance per supported dimension (0 and 1). Each instance
-- context carries the agreement constraints for the non-joined dimensions.
-- Associated declarations must be indented to belong to the class/instance.
class ConcatShape2 d (xs :: [Nat]) (ys :: [Nat]) where
  -- | Shape produced by joining @xs@ and @ys@ along dimension @d@.
  type ConcatDim d xs ys :: [Nat]

-- Dimension 0: the leading sizes add and the remaining dimensions must agree.
-- (The context previously named xs/ys, which are not bound by this head, so
-- the tails were never constrained.)
instance (ds ~ ds') => ConcatShape2 0 (d : ds) (d' : ds') where
  type ConcatDim 0 (d : ds) (d' : ds') = (d + d') : ds

-- Dimension 1: the leading sizes agree, the second sizes add, and the rest
-- must agree.
instance (d0 ~ d0', ds ~ ds') => ConcatShape2 1 (d0 : d1 : ds) (d0' : d1' : ds') where
  type ConcatDim 1 (d0 : d1 : ds) (d0' : d1' : ds') = d0 : (d1 + d1') : ds
-- | Attempt 3 concat: the dimension is carried by a 'Proxy', and a
-- 'ConcatShape2' instance must exist for that dimension and the two shapes.
-- The result shape is the associated 'ConcatDim'.
catClass :: ConcatShape2 d ds ds' => Proxy d -> Tensor ds a -> Tensor ds' a -> Tensor (ConcatDim d ds ds') a
catClass _ lhs rhs = case (lhs, rhs) of
  (Tensor, Tensor) -> Tensor
-- | Inference probe for 'catClass' along dimension 1, with both matrices
-- sharing the first dimension @n@.
-- NOTE(review): 'baz' is also defined in the Attempt 4 section of this file;
-- the two definitions cannot coexist in one module.
baz :: Matrix n m Float -> Matrix n p Float -> _
baz m1 m2 = catClass (Proxy @ 1) m1 m2
-- These are accepted without type signatures, but become ambiguous when GHC
-- is asked for the explicit type:
-- baz2 :: _ <-- causes error
--   Ambiguous type variables ‘ds0’, ‘ds'0’ arising from a use of ‘catClass’
--   prevents the constraint ‘(ConcatShape2
--   0 ds0 ds'0)’ from being solved.
-- NOTE(review): 'baz2' and 'baz3' are also defined in the Attempt 4 section of
-- this file; the duplicate definitions cannot coexist in one module.
baz2 m1 m2 = catClass (Proxy @ 0) m1 m2
baz3 m1 m2 = catClass (Proxy @ 0) (baz m2 m1) m2
-- With an explicit signature that spells both shapes out as cons cells, GHC
-- can at least be persuaded to select the dimension-0 'ConcatShape2' instance.
catClass0 :: (ds ~ ds') => Tensor (d0 ': ds) a -> Tensor (d0' : ds') a -> Tensor (d0 + d0' : ds) a
catClass0 lhs rhs = catClass (Proxy @ 0) lhs rhs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment