-- Source: GitHub gist yongqli/5dc94aeaeda2e24c6637.
-- GitHub warns that this file contains bidirectional Unicode text that may be
-- interpreted or compiled differently than it appears; review it in an editor
-- that reveals hidden Unicode characters.
-- Author: Yongqian Li
{-# OPTIONS_GHC -funfolding-use-threshold=2000 -funfolding-creation-threshold=2000 #-} | |
{-# LANGUAGE TemplateHaskell, RankNTypes, BangPatterns, ConstraintKinds, MultiParamTypeClasses, | |
FlexibleInstances, TypeOperators, GADTs, TypeFamilies, FlexibleContexts #-} | |
module Main where | |
import Control.Lens | |
import Data.Strict.Tuple (Pair(..), (:!:)) | |
import Linear | |
import qualified Data.Vector as V | |
import qualified Data.Vector.Generic as G | |
import qualified Data.Vector.Unboxed as VU | |
import Data.Vector.Unboxed (Unbox) | |
import Data.Vector.Unboxed.Deriving (derivingUnbox) | |
-- | Constraint bundle for the scalar types used below: a 'Floating' type
-- with an 'Epsilon' instance (from "Linear") that can also be stored in
-- unboxed vectors ('Unbox').
type LinAlgNum f = (Epsilon f, Floating f, Unbox f)
-- | LinAlg represents a vector type we can do linear algebra in.
--
-- @v@ is the vector functor (e.g. 'V1') and @a@ the scalar type. The
-- superclasses require that scalars (@a@), vectors (@v a@) and matrices
-- (@v (v a)@, nested vectors as in "Linear") are all 'Floating' and
-- 'Unbox', so each of them can be stored in unboxed vectors.
class ( Traversable v, Additive v, Applicative v
      , Floating a, Floating (v a), Floating (v (v a))
      , Unbox a, Unbox (v a), Unbox (v (v a)) )
      => LinAlg v a where
  -- | Inverse of a square matrix given as nested @v@s.
  inv :: v (v a) -> v (v a)
-- | One-dimensional case: the inverse of the 1x1 matrix @[[x]]@ is
-- @[[1/x]]@. There is no guard for @x == 0@; for 'Double' that yields
-- infinity rather than an error.
instance LinAlgNum f => LinAlg V1 f where
  inv (V1 (V1 x)) = V1 (V1 (1/x))
-- Orphan 'Unbox' instance for the strict pair @a :!: b@ from
-- "Data.Strict.Tuple". Elements are stored via the unboxed representation of
-- ordinary tuples @(a, b)@, converting with the two lambdas below.
derivingUnbox "Pair"
  [t| forall a b. (Unbox a, Unbox b) => a :!: b -> (a, b) |]
  [| \(a :!: b) -> (a, b) |]
  [| \(a, b) -> a :!: b |]
-- | A strict vector (@v a@) paired with a strict matrix (@v (v a)@).
-- The name and the @μ@/@Σ@ binders used elsewhere suggest the parameters of
-- a multivariate normal (mean and covariance). That is not enforced by the
-- type; treat it as an assumption.
--
-- GADT syntax lets the constructor capture a 'LinAlg' dictionary. Pattern
-- matching on 'MVN' brings that dictionary back into scope.
data MVN v a where
  MVN :: LinAlg v a => !(v a) -> !(v (v a)) -> MVN v a
-- | Extract the first field (the vector @μ@) of an 'MVN'.
getμ :: MVN v a -> v a
getμ (MVN _μ _) = _μ
-- | Apply the matrix @op@ (rows indexed by @v@, columns by @t@) to an 'MVN'.
-- The new first field is @op !* μ@. The input matrix @Σ@ is discarded, and
-- the result's matrix is always 'identity'.
--
-- NOTE(review): a linear map of a Gaussian would normally transform the
-- covariance as @op Σ opᵀ@. Using 'identity' here looks like a deliberate
-- simplification for this reproduction case; confirm before reusing.
opOn :: (LinAlg t a, LinAlg v a) => v (t a) -> MVN t a -> MVN v a
{-# INLINE opOn #-}
opOn op (MVN _μ _Σ) = MVN (op !* _μ) identity
-- 'Unbox' instance for 'MVN', stored as unboxed @(μ, Σ)@ tuples.
-- The 'LinAlg' context is required because rebuilding the GADT constructor
-- 'MVN' needs that dictionary.
derivingUnbox "MVN"
  [t| forall v a. LinAlg v a => MVN v a -> (v a, v (v a)) |]
  [| \(MVN _μ _Σ) -> (_μ, _Σ) |]
  [| \(_μ, _Σ) -> MVN _μ _Σ |]
-- | Zip two vectors of any 'G.Vector' flavour into a boxed vector of pairs.
-- Each input is first converted to a boxed 'V.Vector'. The result is as long
-- as the shorter input, as with 'V.zip'.
zipC :: (G.Vector vi1 a, G.Vector vi2 b)
     => vi1 a -> vi2 b -> V.Vector (a, b)
{-# INLINE zipC #-}
-- 'V.zip' already returns the boxed result type, so no final 'G.convert'
-- is needed (matching 'zip3C').
zipC va vb = V.zip (G.convert va) (G.convert vb)
-- | Zip three vectors of any 'G.Vector' flavour into a boxed vector of
-- triples. Each input is first converted to a boxed 'V.Vector'. The result is
-- as long as the shortest input, as with 'V.zip3'.
zip3C :: (G.Vector v1 a, G.Vector v2 b, G.Vector v3 c)
      => v1 a -> v2 b -> v3 c -> V.Vector (a, b, c)
{-# INLINE zip3C #-}
zip3C va vb vc =
  V.zip3 (G.convert va) (G.convert vb) (G.convert vc)
-- | Forward pass: starting from @s0@, repeatedly apply 'opOn' with the next
-- matrix from @ops@, using a strict 'G.scanl''. Every intermediate state is
-- returned, so the result starts with @s0@ and has one more element than the
-- shortest of @ops@, @errs@ and @xs@.
--
-- @errs@ and @xs@ are zipped in, so they bound the number of steps, but @fwd@
-- ignores their elements. They are presumably placeholders for the
-- noise/observation terms of a full filter; TODO confirm.
ffwd :: forall s a m. ( LinAlg s a, LinAlg m a, VU.Unbox (m (s a))
                      , (G.Vector VU.Vector (m (s a) :!: MVN m a)) )
     => MVN s a
     -> VU.Vector (s (s a))
     -> VU.Vector (MVN s a)
     -> VU.Vector (m (s a) :!: MVN m a)
     -> VU.Vector (MVN s a)
ffwd s0 ops errs xs =
  let fwd prevS (op, _err, _x) = op `opOn` prevS
  in G.convert $ G.scanl' fwd s0 (zip3C ops errs xs)
{-# INLINE ffwd #-}
-- | Run 'ffwd' on a fixed one-step, one-dimensional ('V1') problem and keep
-- only the first field ('getμ') of each state.
--
-- NOTE(review): the argument is never used, so the result does not depend
-- on it. The number of steps is hard-wired to 1 by the @G.replicate 1@
-- calls, which gives a two-element result (the initial state plus one step).
runFfwd :: VU.Vector Double -> VU.Vector (V1 Double)
runFfwd _rT =
  G.map getμ $ ffwd
    (MVN 0 identity)
    (G.replicate 1 $ pure 0)
    (G.replicate 1 $ MVN 0 identity)
    (G.replicate 1 (V1 1 :!: MVN (V1 1) (V1 (V1 1))))
{-# INLINE runFfwd #-}
-- | Project the @_x@ component out of every element of 'runFfwd' and wrap
-- the result in 'V1'.
--
-- This is the reproduction case. The local @runFfwd2@ carries an
-- @INLINE[1]@ pragma (not inlined before simplifier phase 1). Per the
-- author's comment, this is required to trigger the bug; a plain @INLINE@
-- does not trigger it. Keep the pragmas and code shape exactly as they are.
-- The @do@ block (in the 'V1' monad) contains only a @let@ and a final
-- expression, so it involves no monadic binds.
run1 :: V1 (VU.Vector Double)
run1 = do
  let
    {-# INLINE[1] runFfwd2 #-} -- must be INLINE[1] rather than INLINE to trigger bug
    runFfwd2 rT = V1 . G.map (^._x) $ runFfwd rT
  runFfwd2 (VU.replicate 1 0)
{-# NOINLINE run1 #-}
-- | Print 'run1'. Evaluating the definitions by hand gives two zero
-- components, so the output is presumably @V1 [0.0,0.0]@ when the bug does
-- not occur. TODO confirm against a known-good GHC.
main :: IO ()
main = print run1
-- (Gist page footer: sign up for free, or sign in if you already have an
-- account, to comment on this gist on GitHub.)