Skip to content

Instantly share code, notes, and snippets.

@yongqli

yongqli/Main.hs Secret

Created April 16, 2015 08:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yongqli/5dc94aeaeda2e24c6637 to your computer and use it in GitHub Desktop.
Save yongqli/5dc94aeaeda2e24c6637 to your computer and use it in GitHub Desktop.
-- Author: Yongqian Li
{-# OPTIONS_GHC -funfolding-use-threshold=2000 -funfolding-creation-threshold=2000 #-}
{-# LANGUAGE TemplateHaskell, RankNTypes, BangPatterns, ConstraintKinds, MultiParamTypeClasses,
FlexibleInstances, TypeOperators, GADTs, TypeFamilies, FlexibleContexts #-}
module Main where
import Control.Lens
import Data.Strict.Tuple (Pair(..), (:!:))
import Linear
import qualified Data.Vector as V
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as VU
import Data.Vector.Unboxed (Unbox)
import Data.Vector.Unboxed.Deriving (derivingUnbox)
-- | Constraints required of a scalar type @f@ for 'LinAlg' instances in this
-- file (see @instance LinAlg V1 f@): an unboxable, 'Floating' type that also
-- has an 'Epsilon' instance from the linear package.
type LinAlgNum f = (Unbox f, Floating f, Epsilon f)
-- | 'LinAlg' represents a vector type @v@ (with scalar type @a@) that we can
-- do linear algebra in.
--
-- The superclasses require the scalars (@a@), the vectors (@v a@) and the
-- square matrices (@v (v a)@) to be both 'Floating' and 'Unbox'able, and the
-- container @v@ itself to be 'Traversable', 'Additive' and 'Applicative'.
class ( Traversable v, Additive v, Applicative v
      , Floating a, Floating (v a), Floating (v (v a))
      , Unbox a, Unbox (v a), Unbox (v (v a))
      ) => LinAlg v a where
  -- | Matrix inverse of a @v (v a)@ matrix (for 'V1' this is the reciprocal
  -- of the single entry).
  inv :: v (v a) -> v (v a)
-- | One-dimensional instance: a 1x1 matrix is inverted by replacing its only
-- entry @x@ with @1 / x@. There is no guard against a zero entry.
instance LinAlgNum f => LinAlg V1 f where
  inv = fmap (fmap (1 /))
-- Unbox instance for strict pairs ('Pair' / ':!:'), stored through their
-- @(x, y)@ tuple representation; both components must themselves be Unbox.
derivingUnbox "Pair"
  [t| forall x y. (Unbox x, Unbox y) => x :!: y -> (x, y) |]
  [| \(l :!: r) -> (l, r) |]
  [| \(l, r) -> l :!: r |]
-- | A pair of a vector (@v a@) and a square matrix (@v (v a)@) over a
-- 'LinAlg' space; both fields are strict. The fields are named @_μ@ and
-- @_Σ@ at their use sites in this file (presumably a mean and covariance —
-- TODO confirm).
--
-- The 'LinAlg' dictionary is packed into the constructor, so pattern
-- matching on 'MVN' brings @LinAlg v a@ into scope.
data MVN v a where
  MVN :: forall v a. LinAlg v a
      => !(v a)        -- first field: vector
      -> !(v (v a))    -- second field: square matrix
      -> MVN v a
-- | Project out the first (vector) field of an 'MVN'. It is named @_μ@
-- throughout this file; presumably the mean — TODO confirm.
getμ :: MVN v a -> v a
getμ (MVN _μ _) = _μ
-- | Apply the matrix @op@ (a @v@-of-@t@ matrix, multiplied via linear's
-- '!*') to the first field of an 'MVN'. The result's second field is set to
-- 'identity'; the input's second field is discarded.
--
-- NOTE(review): if the second field is a covariance, a linear map would
-- normally carry it to @op Σ opᵀ@ rather than resetting it to 'identity' —
-- TODO confirm the reset is intentional.
opOn :: (LinAlg t a, LinAlg v a) => v (t a) -> MVN t a -> MVN v a
{-# INLINE opOn #-}
opOn op mvn = case mvn of
  MVN mu _ -> MVN (op !* mu) identity
-- Unbox instance for 'MVN', stored as the tuple of its two fields. The
-- required Unbox instances for the vector and matrix types come from the
-- 'LinAlg' superclasses.
derivingUnbox "MVN"
  [t| forall f x. LinAlg f x => MVN f x -> (f x, f (f x)) |]
  [| \(MVN vec mat) -> (vec, mat) |]
  [| \(vec, mat) -> MVN vec mat |]
-- | Zip two vectors of arbitrary (possibly different) 'G.Vector'
-- representations into a boxed vector of pairs. Both inputs are converted
-- to boxed vectors first; the result is truncated to the shorter input.
--
-- 'V.zip' already yields a boxed 'V.Vector', so no final conversion is
-- needed (matches 'zip3C').
zipC :: (G.Vector vi1 a, G.Vector vi2 b)
     => vi1 a -> vi2 b -> V.Vector (a, b)
{-# INLINE zipC #-}
zipC va vb =
  V.zip (G.convert va) (G.convert vb)
-- | Zip three vectors of arbitrary (possibly different) 'G.Vector'
-- representations into a boxed vector of triples, truncated to the
-- shortest input. Each input is first converted to a boxed 'V.Vector'.
zip3C :: (G.Vector v1 a, G.Vector v2 b, G.Vector v3 c)
      => v1 a -> v2 b -> v3 c -> V.Vector (a, b, c)
{-# INLINE zip3C #-}
zip3C va vb vc = V.zip3 boxedA boxedB boxedC
  where
    boxedA = G.convert va
    boxedB = G.convert vb
    boxedC = G.convert vc
-- | Forward pass: starting from @s0@, apply each operator in @ops@ with
-- 'opOn' in turn, and return every state produced, including @s0@ itself,
-- as an unboxed vector.
--
-- The three input vectors are zipped together, so the number of steps is
-- the length of the shortest one. The result has one more element than
-- that.
--
-- NOTE(review): the @errs@ and @xs@ entries are zipped in but never read by
-- the step function; only the operator is applied. Presumably a placeholder
-- for an update step — TODO confirm.
ffwd :: forall s a m.
        ( LinAlg s a, LinAlg m a, VU.Unbox (m (s a))
        , G.Vector VU.Vector (m (s a) :!: MVN m a) )
     => MVN s a                          -- ^ initial state
     -> VU.Vector (s (s a))              -- ^ per-step operators
     -> VU.Vector (MVN s a)              -- ^ per-step errs (currently ignored)
     -> VU.Vector (m (s a) :!: MVN m a)  -- ^ per-step xs (currently ignored)
     -> VU.Vector (MVN s a)
ffwd s0 ops errs xs =
  -- Wildcards: the err/x components are not used by the step.
  let fwd prevS (op, _, _) = op `opOn` prevS
  in G.convert $ G.scanl' fwd s0 (zip3C ops errs xs)
{-# INLINE ffwd #-}
-- | Run 'ffwd' in one dimension on fixed inputs and return the first field
-- ('getμ') of every resulting state.
--
-- The fixed inputs are:
--   * an initial state @MVN 0 identity@;
--   * a single all-zero operator;
--   * a single err entry @MVN 0 identity@;
--   * a single x entry.
--
-- NOTE(review): the argument is ignored, so the result is the same for
-- every input — presumably a leftover of minimising this reproduction;
-- TODO confirm before relying on it.
runFfwd :: VU.Vector Double -> VU.Vector (V1 Double)
runFfwd _rT =
  G.map getμ $ ffwd
    (MVN 0 identity)
    (G.replicate 1 $ pure 0)
    (G.replicate 1 $ MVN 0 identity)
    (G.replicate 1 (V1 1 :!: MVN (V1 1) (V1 (V1 1))))
{-# INLINE runFfwd #-}
-- | Reproduction driver, kept as its own top-level binding by NOINLINE.
--
-- It calls 'runFfwd' (through the local wrapper @runFfwd2@) on a
-- one-element zero vector. It then takes the @_x@ component of every
-- resulting 'V1' and wraps the vector in 'V1'.
--
-- The @do@ contains only a @let@ and a final expression, so it is plain
-- @let ... in ...@ and needs no Monad.
--
-- The INLINE[1] phase on the local wrapper is deliberate (see the inline
-- comment). Per the author, plain INLINE does not trigger the bug this
-- file reproduces — presumably a GHC optimiser issue, TODO confirm.
-- Do not "tidy" the pragmas here.
run1 :: V1 (VU.Vector Double)
run1 = do
  let
    {-# INLINE[1] runFfwd2 #-} -- must be INLINE[1] rather than INLINE to trigger bug
    runFfwd2 rT = V1 . G.map (^._x) $ runFfwd rT
  runFfwd2 (VU.replicate 1 0)
{-# NOINLINE run1 #-}
-- | Entry point: print the result of 'run1'.
main :: IO ()
main = print run1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment