Skip to content

Instantly share code, notes, and snippets.

@mrkgnao
Forked from thoughtpolice/LICENSE.txt
Created June 25, 2017 09:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mrkgnao/c32459f69effe10bf8b0b87b47e404f9 to your computer and use it in GitHub Desktop.
Save mrkgnao/c32459f69effe10bf8b0b87b47e404f9 to your computer and use it in GitHub Desktop.
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Naperian where
import qualified Prelude
import Prelude hiding ( lookup, length, replicate, zipWith )
import qualified Data.IntMap as IntMap
import Data.List ( intercalate )
import Data.Kind ( Type, Constraint )
import Control.Applicative ( liftA2 )
import qualified GHC.Exts as L (IsList(..))
import GHC.Prim
import GHC.TypeLits
import qualified Data.Vector as Vector
import Data.Foldable ( toList )
--------------------------------------------------------------------------------
-- Miscellaneous
-- | A natural number tagged with a type-level bound.
--
-- A value of type @'Finite' n@ is meant to hold one of the @n@ numbers
-- @[0..n-1]@. The 'Fin' constructor itself performs no range check; use
-- 'finite' to build values that are guaranteed to respect the bound.
data Finite :: Nat -> Type where
  Fin :: Int -> Finite n
  deriving (Eq, Show)

-- | Create a bounded number @'Finite' n@ from a runtime integer, checked
-- against the statically known limit @n@.
--
-- Returns @'Just' ('Fin' x)@ when @0 <= x < n@ and 'Nothing' otherwise,
-- i.e. for negative input and for @x >= n@.
finite :: forall n. KnownNat n => Int -> Maybe (Finite n)
finite x
  | x >= 0 && toInteger x < bound = Just (Fin x)
  | otherwise                     = Nothing
  where
    -- Compare as 'Integer' so that a bound larger than @maxBound :: Int@
    -- cannot wrap around during conversion.
    bound = natVal' (proxy# :: Proxy# n)
-- | \"'Applicative' zipping\": lift a binary function over two applicative
-- values, combining them with the functor's own '<*>'.
azipWith :: Applicative f => (a -> b -> c) -> f a -> f b -> f c
azipWith h xs ys = h <$> xs <*> ys
-- | Render already-shown elements as a vector literal: the pieces are
-- separated by commas and enclosed in angle brackets, e.g. @\<1,2,3\>@.
showVector :: [String] -> String
showVector xs = concat ["<", intercalate "," xs, ">"]
--------------------------------------------------------------------------------
-- Pairs
-- | A homogeneous pair: two values of the same type @a@, equivalent to
-- @(a, a)@.
data Pair a = Pair a a
  deriving (Show, Eq, Ord, Functor, Foldable, Traversable)

-- | Component-wise application: 'pure' duplicates its argument into both
-- slots, and '<*>' applies the first function to the first value and the
-- second function to the second value.
instance Applicative Pair where
  pure x = Pair x x
  Pair f g <*> Pair x y = Pair (f x) (g y)
--------------------------------------------------------------------------------
-- Vectors
-- | A boxed vector (from "Data.Vector") tagged with a type-level length @n@.
--
-- The newtype does not itself check that the wrapped vector really holds
-- @n@ elements; the functions in this module that build a 'Vector' derive
-- the element count from @n@.
newtype Vector (n :: Nat) (a :: Type) = Vector (Vector.Vector a)
  deriving (Eq, Ord, Functor, Foldable, Traversable)
-- | Vectors are shown via 'showVector': each element is rendered with its
-- own 'show' and the results are joined as @\<x,y,z\>@.
instance Show a => Show (Vector n a) where
  show v = showVector (show <$> toList v)
-- | \"Zippy\" applicative: 'pure' fills all @n@ slots with the same value
-- (via 'replicate') and '<*>' applies functions to arguments slot by slot
-- (via 'zipWith').
instance KnownNat n => Applicative (Vector n) where
  pure x = replicate x
  fs <*> xs = zipWith id fs xs
-- | Allows list-literal syntax (with @OverloadedLists@) for sized vectors.
--
-- 'L.toList' is the 'Foldable' conversion. 'L.fromList' delegates to the
-- module-level 'fromList' and calls 'error' when the list length does not
-- match the type-level length @n@.
--
-- Only @'KnownNat' n@ is required: the 'Foldable' instance used by 'toList'
-- is available for every @'Vector' n@, so no further context is needed.
instance KnownNat n => L.IsList (Vector n a) where
  type Item (Vector n a) = a
  toList = Data.Foldable.toList
  -- The right-hand 'fromList' is this module's checked constructor, not the
  -- class method being defined.
  fromList xs = case fromList xs of
    Just ys -> ys
    Nothing -> error "Demanded vector of a list that wasn't the proper length"
-- | Drop the first element of a vector whose type-level length is @n + 1@,
-- yielding a vector of type-level length @n@.
tail :: Vector (n + 1) a -> Vector n a
tail vec = case vec of
  Vector v -> Vector (Vector.tail v)
-- | Build a sized vector from a list.
--
-- Returns 'Just' the vector when the list has exactly @n@ elements, and
-- 'Nothing' for any other length.
fromList :: forall n a. KnownNat n => [a] -> Maybe (Vector n a)
fromList xs
  | Prelude.length xs == expected = Just (Vector (Vector.fromList xs))
  | otherwise                     = Nothing
  where
    expected :: Int
    expected = fromIntegral (natVal' (proxy# :: Proxy# n))
-- | Combine two vectors of the same type-level length element by element.
zipWith :: (a -> b -> c) -> Vector n a -> Vector n b -> Vector n c
zipWith f (Vector xs) (Vector ys) = Vector $ Vector.zipWith f xs ys
-- | The length of a sized vector, read off the type-level index @n@.
-- The vector argument itself is never inspected.
length :: forall n a. KnownNat n => Vector n a -> Int
length = const (fromIntegral (natVal' (proxy# :: Proxy# n)))
-- | A vector whose @n@ slots all hold the given value; the element count is
-- taken from the type-level index @n@.
replicate :: forall n a. KnownNat n => a -> Vector n a
replicate x =
  let count :: Int
      count = fromIntegral (natVal' (proxy# :: Proxy# n))
  in  Vector (Vector.replicate count x)
-- | Fetch the element at the position carried by a @'Finite' n@, using the
-- underlying vector's @('Vector.!')@ operator.
index :: Vector n a -> Finite n -> a
index (Vector elems) (Fin i) = elems Vector.! i
-- | The vector of all positions: slot @i@ holds @'Fin' i@, for @i@ running
-- from @0@ to @n - 1@.
viota :: forall n. KnownNat n => Vector n (Finite n)
viota = Vector (Vector.generate count Fin)
  where
    count :: Int
    count = fromIntegral (natVal' (proxy# :: Proxy# n))
--------------------------------------------------------------------------------
-- Naperian functors
-- | Naperian functors.
--
-- Think of a value @v :: f a@ of a @'Naperian' f@ as a container with a
-- fixed set of positions, each of which holds an @a@ that can be picked out
-- individually. For a list @[a]@, for instance, such a picking function
-- would have the type @[a] -> Int -> a@, which is exactly @'(Prelude.!!)'@.
--
-- The lookup function plays the role of a logarithm of the @'Functor' f@.
-- Reading types algebraically, a function type @a -> b@ behaves like the
-- exponential @b^a@, and the logarithm @log_b(b^a) = a@ recovers the
-- exponent. Correspondingly, @f a@ behaves like the function type
-- @'Log' f -> a@: 'lookup' turns a container into that function and
-- 'tabulate' turns such a function back into a container.
--
-- In short, a @'Naperian'@ functor @f@ is a functor for which every @a@
-- inside an @f a@ can be reached through a position of type @'Log' f@.
class Functor f => Naperian f where
  {-# MINIMAL lookup, (tabulate | positions) #-}

  -- | The \"logarithm\" of @f@: the type of positions used to look up
  -- values inside an @f a@. For a list @[a]@ one would index with an
  -- @'Int'@, so @'Log' [a] = Int@. For a length-indexed
  -- @'Vector' (n :: 'Nat') a@, @'Log' ('Vector' n)@ is the range of
  -- integers @[0..n-1]@, represented here as @'Finite' n@.
  type Log f

  -- | Look up the element stored at a position. Read as a sentence: \"give
  -- me an @f a@ and I will give you a function from positions to the @a@
  -- stored there\".
  lookup :: f a -> (Log f -> a)

  -- | Build an @f a@ by computing the element for every position.
  -- The default maps the given function over 'positions'.
  tabulate :: (Log f -> a) -> f a
  tabulate h = h <$> positions

  -- | The container in which every slot holds its own position.
  -- The default is @'tabulate' 'id'@.
  positions :: f (Log f)
  positions = tabulate id
-- | Swap the nesting order of two @'Naperian'@ functors: the element found
-- at position @i@ of the outer @f@ and position @j@ of the inner @g@ ends
-- up at position @j@ of the outer @g@ and position @i@ of the inner @f@.
transpose :: (Naperian f, Naperian g) => f (g a) -> g (f a)
transpose xss =
  tabulate (\j -> tabulate (\i -> lookup (lookup xss i) j))
-- | A 'Pair' has two positions, addressed by a 'Bool': 'False' selects the
-- first component and 'True' the second.
instance Naperian Pair where
  type Log Pair = Bool
  lookup (Pair x y) pickSecond
    | pickSecond = y
    | otherwise  = x
  positions = Pair False True
-- | Positions of a @'Vector' n@ are values of @'Finite' n@; lookup is
-- 'index' and the table of positions is 'viota'.
instance KnownNat n => Naperian (Vector n) where
  type Log (Vector n) = Finite n
  lookup v i = index v i
  positions = viota
--------------------------------------------------------------------------------
-- Dimensions
-- | A functor usable as one dimension of a hypercuboid: it must be
-- 'Applicative', 'Naperian' and 'Traversable'.
class (Applicative f, Naperian f, Traversable f) => Dimension f where
  -- | Number of elements held by the structure. The default counts the
  -- elements of its 'toList'.
  size :: f a -> Int
  size xs = Prelude.length (toList xs)
-- | A 'Pair' always has size 2; the argument is not inspected.
instance Dimension Pair where
  size _ = 2
-- | The size of a sized vector is its type-level 'length'.
instance KnownNat n => Dimension (Vector n) where
  size v = length v
-- | Inner product over one dimension: multiply the two structures together
-- through their 'Applicative' instance and sum the products.
inner :: (Num a, Dimension f) => f a -> f a -> a
inner xs ys = sum products
  where
    products = liftA2 (*) xs ys
-- | Matrix product of an @f@-by-@g@ matrix with a @g@-by-@h@ matrix.
--
-- Every row of the first argument is spread across @h@ with 'pure', the
-- second argument is transposed and spread across @f@ with 'pure', and the
-- two are combined entry by entry with 'inner'.
matrix
  :: (Num a, Dimension f, Dimension g, Dimension h)
  => f (g a)
  -> g (h a)
  -> f (h a)
matrix xss yss = liftA2 (liftA2 inner) rows columns
  where
    rows    = fmap pure xss
    columns = pure (transpose yss)
--------------------------------------------------------------------------------
-- Hyper-dimensional stuff
-- | Arbitrary-rank hypercuboids, indexed by the list of their dimensions.
--
-- The head of the type-level list is the innermost functor, i.e. the one
-- that directly wraps the elements.
data Hyper :: [Type -> Type] -> Type -> Type where
  -- | A rank-0 hypercuboid holding a single value.
  Scalar :: a -> Hyper '[] a
  -- | Add a dimension @f@ on top of a hypercuboid whose elements are
  -- @f@-structures.
  Prism  :: (Dimension f, Shapely fs) => Hyper fs (f a) -> Hyper (f ': fs) a
-- | Extract the single value of a rank-0 hypercuboid.
point :: Hyper '[] a -> a
point h = case h of
  Scalar a -> a
-- | Peel off the outermost 'Prism' constructor, exposing the hypercuboid of
-- @f@-structures underneath.
crystal :: Hyper (f ': fs) a -> Hyper fs (f a)
crystal h = case h of
  Prism rest -> rest
-- | Show a hypercuboid by first showing every element and then collapsing
-- one dimension at a time with 'showVector' until a single string is left.
instance Show a => Show (Hyper fs a) where
  show = render . fmap show
    where
      render :: Hyper gs String -> String
      render (Scalar s)   = s
      render (Prism rest) = render (fmap (showVector . toList) rest)
{--
class HyperLift f fs where
hyper :: (Shapely fs, Dimension f) => f a -> Hyper (f : fs) a
instance HyperLift f '[] where
hyper = Prism . Scalar
instance (Shapely fs, HyperLift f fs) => HyperLift f (f : fs) where
hyper = Prism . (\x -> (hyper $ _ x))
--}
-- | Dimension lists for which a hypercuboid can be built from, and measured
-- by, the shape alone.
class Shapely (fs :: [Type -> Type]) where
  -- | A hypercuboid of shape @fs@ holding the given value in every slot.
  hreplicate :: a -> Hyper fs a
  -- | The total number of elements of a hypercuboid of shape @fs@.
  hsize :: Hyper fs a -> Int
-- | The empty shape: a single 'Scalar', of size 1.
instance Shapely '[] where
  hreplicate = Scalar
  hsize _ = 1
-- | A non-empty shape: replicate through the new dimension with 'pure', and
-- measure it as the 'size' of that dimension (taken from the 'first'
-- @f@-structure) times the size of the remaining shape.
instance (Dimension f, Shapely fs) => Shapely (f ': fs) where
  hreplicate = Prism . hreplicate . pure
  hsize (Prism rest) = size (first rest) * hsize rest
-- | Map over every element, at whatever depth it is stored.
instance Functor (Hyper fs) where
  fmap f h = case h of
    Scalar a   -> Scalar (f a)
    Prism rest -> Prism (fmap f <$> rest)
-- | 'pure' fills the whole shape with one value ('hreplicate'); '<*>'
-- applies functions to arguments through 'hzipWith'.
instance Shapely fs => Applicative (Hyper fs) where
  pure a = hreplicate a
  fs <*> xs = hzipWith id fs xs
-- | Combine two hypercuboids of identical shape with a binary function,
-- descending through each dimension with 'azipWith'.
hzipWith :: (a -> b -> c) -> Hyper fs a -> Hyper fs b -> Hyper fs c
hzipWith f lhs rhs = case lhs of
  Scalar a -> case rhs of
    Scalar b -> Scalar (f a b)
  Prism x -> case rhs of
    Prism y -> Prism (hzipWith (azipWith f) x y)
-- | The first element of a hypercuboid, i.e. the head of the 'toList' of
-- the first structure at every level.
--
-- Calls 'error' with a descriptive message if a structure on that path
-- turns out to be empty, since there is then no element to return.
first :: Shapely fs => Hyper fs a -> a
first (Scalar a) = a
first (Prism x)  = case toList (first x) of
  a : _ -> a
  []    -> error "Naperian.first: hypercuboid has an empty dimension"
-- | Generalized transposition over arbitrary-rank hypercuboids: swaps the
-- first two dimensions by applying 'transpose' to every stored
-- @g (f a)@ structure.
transposeH :: Hyper (f ': (g ': fs)) a
           -> Hyper (g ': (f ': fs)) a
transposeH hyper = case hyper of
  Prism (Prism cells) -> Prism (Prism (transpose <$> cells))
-- | Fold away the first dimension of a hypercuboid: every @f@-structure is
-- reduced with @'foldr' f z@, leaving a hypercuboid of one rank less.
foldrH :: (a -> a -> a) -> a -> Hyper (f ': fs) a -> Hyper fs a
foldrH f z (Prism rest) = foldr f z <$> rest
-- | Lift a unary function on values to a function on hypercuboids of
-- values; this is 'fmap' under another name.
unary :: Shapely fs => (a -> b) -> (Hyper fs a -> Hyper fs b)
unary f h = fmap f h
-- | Lift a binary function on values to two hypercuboids whose shapes are
-- 'Compatible'. Both operands are first brought to the common shape
-- @'Max' fs gs@ with 'align' and then combined with 'hzipWith'.
binary
  :: ( Compatible fs gs
     , Max fs gs ~ hs
     , Alignable fs hs
     , Alignable gs hs
     )
  => (a -> b -> c)
  -> Hyper fs a
  -> Hyper gs b
  -> Hyper hs c
binary f lhs rhs = hzipWith f (align lhs) (align rhs)
-- | Add a new first dimension @f@ to a hypercuboid by replacing every
-- element with a 'pure' @f@-structure of that element.
up :: (Shapely fs, Dimension f) => Hyper fs a -> Hyper (f ': fs) a
up h = Prism (pure <$> h)
-- | Generalized, rank-polymorphic inner product: multiply the two
-- (aligned) hypercuboids element by element with 'binary', then sum away
-- the first dimension with 'foldrH'.
innerH
  :: ( Max fs gs ~ (f ': hs)
     , Alignable fs (f ': hs)
     , Alignable gs (f ': hs)
     , Compatible fs gs
     , Num a
     )
  => Hyper fs a
  -> Hyper gs a
  -> Hyper hs a
innerH xs ys = foldrH (+) 0 $ binary (*) xs ys
-- | Generalized, rank-polymorphic matrix product.
--
-- The first operand has its outer 'Prism' removed with 'crystal'; the
-- second is transposed with 'transposeH' and unwrapped completely. Both
-- are then lifted to a common shape with 'up' and combined entry by entry
-- with 'inner'.
matrixH
  :: (Num a, Dimension f, Dimension g, Dimension h)
  => Hyper '[ g, f ] a
  -> Hyper '[ h, g ] a
  -> Hyper '[ h, f ] a
matrixH x y =
  case transposeH y of
    Prism (Prism ys) -> hzipWith inner (up (crystal x)) (Prism (up ys))
--------------------------------------------------------------------------------
-- Alignment
-- | Shapes @fs@ that can be widened to the shape @gs@.
class (Shapely fs, Shapely gs)
    => Alignable (fs :: [Type -> Type]) (gs :: [Type -> Type]) where
  -- | Convert a hypercuboid of shape @fs@ into one of shape @gs@.
  align :: Hyper fs a -> Hyper gs a
-- | Two empty shapes are already aligned.
instance Alignable '[] '[] where
  align scalar = scalar
-- | Shapes that share their first dimension are aligned by aligning the
-- remaining dimensions underneath it.
instance (Dimension f, Alignable fs gs) => Alignable (f ': fs) (f ': gs) where
  align (Prism rest) = Prism (align rest)
-- | A scalar is aligned with any non-empty shape by replicating its value
-- into every slot of that shape.
instance (Dimension f, Shapely fs) => Alignable '[] (f ': fs) where
  align (Scalar value) = hreplicate value
-- | The larger of two compatible shapes: the shapes are walked in lockstep
-- while their heads coincide, and whatever remains of the longer one is
-- kept. No equation applies when the heads differ.
type family Max (fs :: [Type -> Type]) (gs :: [Type -> Type]) :: [Type -> Type] where
  Max '[]       '[]       = '[]
  Max '[]       (f ': gs) = f ': gs
  Max (f ': fs) '[]       = f ': fs
  Max (f ': fs) (f ': gs) = f ': Max fs gs
-- | Constraint that two shapes can be aligned: one must be a prefix of the
-- other. Any other combination is rejected at compile time with a custom
-- 'TypeError' naming both shapes.
type family Compatible (fs :: [Type -> Type]) (gs :: [Type -> Type]) :: Constraint where
  Compatible '[]       '[]       = ()
  Compatible '[]       (f ': gs) = ()
  Compatible (f ': fs) '[]       = ()
  Compatible (f ': fs) (f ': gs) = Compatible fs gs
  Compatible a         b         = TypeError
    (     'Text "Mismatched dimensions!"
    ':$$: 'Text "The dimension " ':<>: 'ShowType a ':<>: 'Text " can't be aligned with"
    ':$$: 'Text "the dimension " ':<>: 'ShowType b
    )
--------------------------------------------------------------------------------
-- Flattened, sparse Hypercuboids
-- | All elements of a hypercuboid as a flat list: the elements of the
-- underlying hypercuboid are listed first, and each of those
-- @f@-structures is then expanded in place with 'toList'.
elements :: Shapely fs => Hyper fs a -> [a]
elements (Scalar a)   = [a]
elements (Prism rest) = concatMap toList (elements rest)
-- | A hypercuboid stored as one flat "Data.Vector" of elements; the shape
-- @fs@ is kept only at the type level.
data Flat (fs :: [Type -> Type]) a where
  Flat :: Shapely fs => Vector.Vector a -> Flat fs a
-- | Map over every element of the underlying flat vector.
instance Functor (Flat fs) where
  fmap f (Flat v) = Flat (Vector.map f v)
-- | A flat hypercuboid is shown as a single vector of its shown elements,
-- formatted by 'showVector'.
instance Show a => Show (Flat fs a) where
  show (Flat v) = showVector (map show (toList v))
-- | Store the 'elements' of a hypercuboid in a flat vector, keeping the
-- shape at the type level.
flatten :: Shapely fs => Hyper fs a -> Flat fs a
flatten = Flat . Vector.fromList . elements
-- | A sparse hypercuboid: one value plus an 'IntMap.IntMap' of explicitly
-- stored values. 'unsparse' uses the single value for every position
-- that has no entry in the map.
data Sparse (fs :: [Type -> Type]) a where
  Sparse :: Shapely fs => a -> IntMap.IntMap a -> Sparse fs a
-- | Expand a sparse hypercuboid into its dense, flattened form.
--
-- The result has as many slots as a hypercuboid of shape @fs@ has elements
-- (its 'hsize'). Every slot starts out as the default value and is then
-- overwritten by the map entry stored under its index, if there is one.
--
-- The update uses the bounds-checked @('Vector.//')@, so a key that lies
-- outside the dense vector is reported as an index error rather than
-- writing to memory outside the vector.
unsparse :: forall fs a. Shapely fs => Sparse fs a -> Flat fs a
unsparse (Sparse def entries) = Flat (dense Vector.// IntMap.assocs entries)
  where
    -- The element count depends only on the shape, so a throw-away
    -- hypercuboid of units is enough to measure it.
    total = hsize (hreplicate () :: Hyper fs ())
    dense = Vector.replicate total def
--------------------------------------------------------------------------------
-- Examples
-- | An @n@-by-@m@ matrix: a vector of @n@ rows, each a vector of @m@
-- values.
type Matrix (n :: Nat) (m :: Nat) v = Vector n (Vector m v)
-- | Inner product of two 3-element vectors, i.e. @1*4 + 2*5 + 3*6@.
example1 :: Int
example1 = inner v1 v2
  where
    v1, v2 :: Vector 3 Int
    v1 = [1, 2, 3]
    v2 = [4, 5, 6]
-- | Product of a 2-by-3 matrix with a 3-by-2 matrix, computed with
-- 'matrix'.
example2 :: Matrix 2 2 Int
example2 = matrix m1 m2
  where
    m1 :: Matrix 2 3 Int
    m1 =
      [ [1, 2, 3]
      , [4, 5, 6]
      ]

    m2 :: Matrix 3 2 Int
    m2 =
      [ [9, 8]
      , [6, 5]
      , [3, 2]
      ]
-- | The same two vectors as 'example1', wrapped as rank-1 hypercuboids and
-- multiplied with the rank-polymorphic 'innerH'.
example3 :: Hyper '[] Int
example3 = innerH v1 v2
  where
    v1, v2 :: Hyper '[Vector 3] Int
    v1 = Prism (Scalar [1, 2, 3])
    v2 = Prism (Scalar [4, 5, 6])
-- | The same two matrices as 'example2', wrapped as rank-2 hypercuboids and
-- multiplied with the rank-polymorphic 'matrixH'.
example4 :: Hyper '[Vector 2, Vector 2] Int
example4 = matrixH v1 v2
  where
    x :: Matrix 2 3 Int
    x =
      [ [1, 2, 3]
      , [4, 5, 6]
      ]

    y :: Matrix 3 2 Int
    y =
      [ [9, 8]
      , [6, 5]
      , [3, 2]
      ]

    v1 :: Hyper '[Vector 3, Vector 2] Int
    v1 = Prism (Prism (Scalar x))

    v2 :: Hyper '[Vector 2, Vector 3] Int
    v2 = Prism (Prism (Scalar y))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment