Skip to content

Instantly share code, notes, and snippets.

@andrewthad
Created June 24, 2018 20:53
Show Gist options
  • Save andrewthad/46dde4c28b4ea48c18f9c31087282316 to your computer and use it in GitHub Desktop.
Save andrewthad/46dde4c28b4ea48c18f9c31087282316 to your computer and use it in GitHub Desktop.
Optimized Fixed-Length Vectors
{-# language BangPatterns #-}
{-# language KindSignatures #-}
{-# language LambdaCase #-}
{-# language DataKinds #-}
{-# language TypeFamilies #-}
{-# language GADTs #-}
{-# language ScopedTypeVariables #-}
module Fancy
( Vec(..)
, Optimal
, dot
-- * Construction
, fromVector
, singleton
, doubleton
, triple
, quadruple
) where
import Nat (Nat(..),SingNat(..),ImplicitNat(..),N1,N2,N3,N4)
import Data.Functor.Const (Const(..))
import Data.Functor.Identity (Identity(..))
import Data.Kind (Type)
import Linear.V0 (V0(..))
import Linear.V1 (V1(..))
import Linear.V2 (V2(..))
import Linear.V3 (V3(..))
import Linear.V4 (V4(..))
import qualified Simple
import qualified Vector as V
import qualified Linear.Metric
-- | Selects the runtime representation for a vector of length @n@:
-- 'V0', 'V1', 'V2', 'V3' and 'V4' (from the @Linear.V*@ modules) for
-- lengths 0 through 4, and 'Simple.Vec' for every length of five or more.
--
-- This is a closed type family, so its equations are tried top to bottom.
-- The final equation matches any @n@ with at least five 'Succ's, no matter
-- what @m@ is.
type family Optimal (n :: Nat) :: Type -> Type where
  Optimal 'Zero = V0
  Optimal ('Succ 'Zero) = V1
  Optimal ('Succ ('Succ 'Zero)) = V2
  Optimal ('Succ ('Succ ('Succ 'Zero))) = V3
  Optimal ('Succ ('Succ ('Succ ('Succ 'Zero)))) = V4
  -- Lengths >= 5 fall back to the array-backed representation.
  Optimal ('Succ ('Succ ('Succ ('Succ ('Succ m))))) =
    Simple.Vec ('Succ ('Succ ('Succ ('Succ ('Succ m)))))
-- | A vector of exactly @n@ elements, stored in whichever representation
-- 'Optimal' picks for that length.
newtype Vec (n :: Nat) a = Vec (Optimal n a)
-- | Equality defers to the underlying representation. Inspecting the
-- singleton for @n@ tells the type checker which concrete type
-- @Optimal n a@ is, and that makes its 'Eq' instance usable.
instance (ImplicitNat n, Eq a) => Eq (Vec n a) where
  Vec lhs == Vec rhs = sameVia implicitNat
    where
      -- Every clause has the same right-hand side. Each one is typechecked
      -- under a different refinement of @n@, so each resolves '==' at a
      -- different concrete type.
      sameVia :: SingNat n -> Bool
      sameVia SingZero = lhs == rhs
      sameVia (SingSucc SingZero) = lhs == rhs
      sameVia (SingSucc (SingSucc SingZero)) = lhs == rhs
      sameVia (SingSucc (SingSucc (SingSucc SingZero))) = lhs == rhs
      sameVia (SingSucc (SingSucc (SingSucc (SingSucc SingZero)))) = lhs == rhs
      sameVia (SingSucc (SingSucc (SingSucc (SingSucc (SingSucc _))))) = lhs == rhs
-- | Ordering defers to the underlying representation. Inspecting the
-- singleton for @n@ tells the type checker which concrete type
-- @Optimal n a@ is, and that makes its 'Ord' instance usable.
instance (ImplicitNat n, Ord a) => Ord (Vec n a) where
  compare (Vec lhs) (Vec rhs) = orderVia implicitNat
    where
      -- The clauses are textually identical but are checked under
      -- different refinements of @n@, so 'compare' resolves differently.
      orderVia :: SingNat n -> Ordering
      orderVia SingZero = compare lhs rhs
      orderVia (SingSucc SingZero) = compare lhs rhs
      orderVia (SingSucc (SingSucc SingZero)) = compare lhs rhs
      orderVia (SingSucc (SingSucc (SingSucc SingZero))) = compare lhs rhs
      orderVia (SingSucc (SingSucc (SingSucc (SingSucc SingZero)))) = compare lhs rhs
      orderVia (SingSucc (SingSucc (SingSucc (SingSucc (SingSucc _))))) = compare lhs rhs
-- | Converts a length-indexed 'V.Vector' into a 'Vec' that uses the
-- representation chosen by 'Optimal'. Lengths 0 to 4 become 'V0' to 'V4'.
-- Every length of five or more becomes a 'Simple.Vec'.
fromVector :: V.Vector n a -> Vec n a
fromVector = \case
  V.VectorNil -> Vec V0
  V.VectorCons a V.VectorNil -> Vec (V1 a)
  V.VectorCons a (V.VectorCons b V.VectorNil) -> Vec (V2 a b)
  V.VectorCons a (V.VectorCons b (V.VectorCons c V.VectorNil)) ->
    Vec (V3 a b c)
  V.VectorCons a (V.VectorCons b (V.VectorCons c (V.VectorCons d V.VectorNil))) ->
    Vec (V4 a b c d)
  -- Five or more elements. The tail after the fifth cons is deliberately a
  -- wildcard so that lengths 6, 7, ... are covered as well. The last
  -- 'Optimal' equation reduces for any such length. Matching 'V.VectorNil'
  -- here (as the code previously did) made this function partial for every
  -- vector longer than five.
  v@(V.VectorCons _ (V.VectorCons _ (V.VectorCons _ (V.VectorCons _ (V.VectorCons _ _))))) ->
    Vec (Simple.fromVector v)
-- | A one-element vector, stored as a 'V1'.
singleton :: a -> Vec N1 a
singleton = Vec . V1
-- | A two-element vector, stored as a 'V2', with elements kept in argument order.
doubleton :: a -> a -> Vec N2 a
doubleton x y = Vec $ V2 x y
-- | A three-element vector, stored as a 'V3', with elements kept in argument order.
triple :: a -> a -> a -> Vec N3 a
triple x y z = Vec $ V3 x y z
-- | A four-element vector, stored as a 'V4', with elements kept in argument order.
quadruple :: a -> a -> a -> a -> Vec N4 a
quadruple w x y z = Vec $ V4 w x y z
-- | Dot product of two vectors of the same length.
--
-- Lengths 0 to 4 use 'Linear.Metric.dot' on the 'V0' to 'V4'
-- representation. Lengths of five or more use 'Simple.dot'.
dot :: forall n a. (Num a, ImplicitNat n) => Vec n a -> Vec n a -> a
dot (Vec u) (Vec w) = viaShape implicitNat
  where
    -- Matching on the singleton fixes @Optimal n@ to a concrete type,
    -- which selects the right dot-product implementation in each clause.
    viaShape :: SingNat n -> a
    viaShape SingZero = Linear.Metric.dot u w
    viaShape (SingSucc SingZero) = Linear.Metric.dot u w
    viaShape (SingSucc (SingSucc SingZero)) = Linear.Metric.dot u w
    viaShape (SingSucc (SingSucc (SingSucc SingZero))) = Linear.Metric.dot u w
    viaShape (SingSucc (SingSucc (SingSucc (SingSucc SingZero)))) = Linear.Metric.dot u w
    viaShape (SingSucc (SingSucc (SingSucc (SingSucc (SingSucc _))))) = Simple.dot u w
{-# language BangPatterns #-}
{-# language KindSignatures #-}
{-# language DataKinds #-}
{-# language GADTs #-}
module Nat
( Nat(..)
, SingNat(..)
, ImplicitNat(..)
, N0
, N1
, N2
, N3
, N4
, N5
) where
import Data.Kind (Type)
-- | Unary natural numbers. With DataKinds the constructors are promoted
-- (as 'Succ and 'Zero) and used to index types such as 'SingNat'.
data Nat
  = Succ !Nat -- ^ successor; the predecessor field is strict
  | Zero      -- ^ zero
-- | Value-level witness of a type-level 'Nat'. The constructors mirror the
-- structure of @n@ one-for-one, so pattern matching on a @SingNat n@
-- reveals @n@ to the type checker.
data SingNat (n :: Nat) :: Type where
  SingSucc :: !(SingNat n) -> SingNat ('Succ n)
  SingZero :: SingNat 'Zero
-- | Type-level naturals whose 'SingNat' witness can be produced implicitly,
-- built by instance resolution instead of being passed by the caller.
class ImplicitNat (k :: Nat) where
  -- | The 'SingNat' witness for @k@.
  implicitNat :: SingNat k

-- | Base case: the witness for zero.
instance ImplicitNat 'Zero where
  implicitNat = SingZero

-- | Inductive step: wrap the predecessor's witness in one more 'SingSucc'.
instance ImplicitNat k => ImplicitNat ('Succ k) where
  implicitNat = SingSucc implicitNat
-- | Shorthand names for the first few promoted naturals, written out in full.
type N0 = 'Zero
type N1 = 'Succ 'Zero
type N2 = 'Succ ('Succ 'Zero)
type N3 = 'Succ ('Succ ('Succ 'Zero))
type N4 = 'Succ ('Succ ('Succ ('Succ 'Zero)))
type N5 = 'Succ ('Succ ('Succ ('Succ ('Succ 'Zero))))
{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language DeriveFunctor #-}
{-# language GADTs #-}
{-# language KindSignatures #-}
{-# language MagicHash #-}
{-# language RankNTypes #-}
{-# language ScopedTypeVariables #-}
{-# language UnboxedTuples #-}
module Simple
( Vec
, dot
-- * Construction
, fromVector
, singleton
, doubleton
, triple
, quadruple
) where
import Nat
import Vector (Vector)
import Data.Primitive.SmallArray
import Control.Monad.ST (ST,runST)
import qualified Vector as V
-- | A vector of @n@ elements backed by a 'SmallArray'. The index @n@ is
-- phantom: this declaration does not tie the array's size to @n@. The
-- constructor is not exported, so only this module's construction
-- functions can build values, and they allocate arrays of the matching size.
-- 'Eq' and 'Ord' compare using the underlying 'SmallArray' instances.
newtype Vec (n :: Nat) a = Vec (SmallArray a)
  deriving (Functor,Eq,Ord)
-- | Dot product of two vectors: the sum of the products of corresponding
-- elements. The loop runs over the size of the first array. Both arguments
-- share the index @n@, so their sizes are expected to agree.
dot :: Num a => Vec n a -> Vec n a -> a
dot (Vec lhs) (Vec rhs) = loop 0 0
  where
    size = sizeofSmallArray lhs
    -- The bang on the accumulator keeps the running sum evaluated at every step.
    loop !total i
      | i >= size = total
      | otherwise =
          let (# l #) = indexSmallArray## lhs i
              (# r #) = indexSmallArray## rhs i
          in loop (l * r + total) (i + 1)
-- | Copies the elements of a length-indexed 'Vector' into a freshly
-- allocated 'SmallArray' sized with 'V.length', preserving order.
fromVector :: forall n a. Vector n a -> Vec n a
fromVector src = Vec (runST fill)
  where
    fill :: forall s. ST s (SmallArray a)
    fill = do
      -- Each slot starts as a placeholder and is overwritten below.
      marr <- newSmallArray (V.length src) unitializedElement
      let write :: forall k. Int -> Vector k a -> ST s ()
          write !_ V.VectorNil = pure ()
          write !i (V.VectorCons el rest) =
            writeSmallArray marr i el >> write (i + 1) rest
      write 0 src
      unsafeFreezeSmallArray marr
-- | A one-element vector: a size-1 array filled with the argument.
singleton :: a -> Vec N1 a
singleton x = Vec (runST (newSmallArray 1 x >>= unsafeFreezeSmallArray))
-- | A two-element vector. The array is created filled with the first
-- argument, then slot 1 is overwritten with the second.
doubleton :: a -> a -> Vec N2 a
doubleton x y = Vec $ runST $
  newSmallArray 2 x >>= \arr ->
    writeSmallArray arr 1 y >> unsafeFreezeSmallArray arr
-- | A three-element vector. The array is created filled with the first
-- argument, then slots 1 and 2 are overwritten in order.
triple :: a -> a -> a -> Vec N3 a
triple x y z = Vec $ runST $
  newSmallArray 3 x >>= \arr ->
    writeSmallArray arr 1 y
      >> writeSmallArray arr 2 z
      >> unsafeFreezeSmallArray arr
-- | A four-element vector. The array is created filled with the first
-- argument, then slots 1 to 3 are overwritten in order.
quadruple :: a -> a -> a -> a -> Vec N4 a
quadruple w x y z = Vec $ runST $
  newSmallArray 4 w >>= \arr ->
    writeSmallArray arr 1 x
      >> writeSmallArray arr 2 y
      >> writeSmallArray arr 3 z
      >> unsafeFreezeSmallArray arr
-- | Placeholder used to fill freshly allocated arrays before their slots
-- are written. Forcing it raises an 'error' naming this binding.
unitializedElement :: a
unitializedElement = error placeholderMessage
  where
    placeholderMessage = "Simple.unitializedElement"
{-# NOINLINE unitializedElement #-}
{-# language BangPatterns #-}
{-# language KindSignatures #-}
{-# language DataKinds #-}
{-# language GADTs #-}
module Vector
( Vector(..)
, length
) where
import Prelude hiding (length)
import Data.Kind (Type)
import Nat (Nat(Succ,Zero))
-- | A cons-list whose length is tracked in its type: 'VectorNil' has length
-- 'Zero, and each 'VectorCons' adds one 'Succ'.
data Vector (n :: Nat) (a :: Type) where
  VectorNil :: Vector 'Zero a
  VectorCons :: a -> Vector n a -> Vector ('Succ n) a
-- | Number of elements, counted by walking the list with a strict
-- accumulator.
length :: Vector n a -> Int
length vec = count vec 0
  where
    count :: Vector m b -> Int -> Int
    count VectorNil !acc = acc
    count (VectorCons _ rest) !acc = count rest (acc + 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment