Last active
August 29, 2015 14:26
-
-
Save wyager/df1809badc7c6a75cd5f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- A first exploration into fusing ST freeze/thaw round-trips with rewrite rules.
-- will.yager@gmail.com
{-# LANGUAGE RankNTypes, MultiParamTypeClasses, BangPatterns #-} | |
import Control.Monad.ST (ST, runST) | |
import qualified Data.Vector as V | |
import qualified Data.Vector.Mutable as M | |
import Control.Monad.Primitive (PrimMonad, PrimState) | |
-- | Immutable matrix: a thin wrapper around a boxed 'V.Vector' of elements.
--
-- Only the element vector is stored here (no dimensions). It is a @newtype@
-- rather than a single-field @data@ so the wrapper adds no extra box or
-- indirection at runtime.
newtype Matrix a = Matrix {unMatrix :: V.Vector a}
-- | Mutable counterpart of 'Matrix': wraps a boxed 'M.MVector' whose state
-- token is @s@. A @newtype@ so the wrapper is free at runtime.
newtype MMatrix s a = MMatrix (M.MVector s a)
-- | Pairs an immutable container type @f@ with its mutable counterpart @t@.
-- The mutable side is indexed by the 'PrimState' of the monad doing the
-- conversion, so the same class serves both 'ST' and 'IO'.
class Freezable f t where
  -- | Convert an immutable value into its mutable counterpart.
  thaw :: PrimMonad m => f a -> m (t (PrimState m) a)
  -- | Convert a mutable value back into its immutable counterpart.
  freeze :: PrimMonad m => t (PrimState m) a -> m (f a)
-- | Boxed vectors: both directions delegate straight to the vector library.
instance Freezable V.Vector M.MVector where
  thaw = V.thaw
  freeze = V.freeze
-- | Matrices: unwrap the backing vector, convert it with the vector
-- instance, and rewrap the result in the other wrapper.
instance Freezable Matrix MMatrix where
  thaw (Matrix frozen) = MMatrix <$> thaw frozen
  freeze (MMatrix mutable) = Matrix <$> freeze mutable
-- | Thaw an immutable value, apply an in-place mutation to the mutable
-- version, then freeze it again and return the immutable result.
runM :: (Freezable f t, PrimMonad m) => (t (PrimState m) a -> m ()) -> f a -> m (f a)
runM op frozen =
  thaw frozen >>= \mutable ->
    op mutable >> freeze mutable
-- Rewrite rule: two nested 'run' calls, i.e. thaw/inner/freeze followed by
-- thaw/outer/freeze, become a single 'run'. That single run thaws once,
-- applies @inner@ and then @outer@ to the same mutable value, and freezes
-- once.
{-# RULES
    "runFusion" forall
        (outer :: (forall s . t s a -> ST s ()))
        (inner :: (forall s . t s a -> ST s ()))
        v .
        run outer (run inner v) = run (\mv -> inner mv >> outer mv) v
  #-}
-- | Pure entry point: thaw the input, apply the 'ST' mutation via 'runM',
-- freeze, and escape 'ST' with 'runST'.
--
-- NOINLINE keeps 'run' calls visible so the "runFusion" rewrite rule can
-- match @run op2 (run op1 m)@ instead of seeing the inlined body.
{-# NOINLINE run #-}
run :: (Freezable f t) => (forall s . t s a -> ST s ()) -> f a -> f a
run op frozen = runST $ runM op frozen
-- | Add one to every element of a 'Matrix', producing a new 'Matrix'.
--
-- INLINE makes each use site expose the underlying 'run' call. Without it,
-- whether GHC inlines 'add1' is left to its heuristics. The "runFusion"
-- rule can only merge compositions such as @add1 . add1@ into one
-- thaw/freeze pair once it sees @run op2 (run op1 m)@ directly.
{-# INLINE add1 #-}
add1 :: Num a => Matrix a -> Matrix a
add1 = run add1ST
-- | Increment every element of the mutable matrix by one, in place.
--
-- Each element is read, incremented, and written back. Both the read value
-- and the new value are forced to WHNF before the write. Using 'M.modify'
-- here would be lazy and store an unevaluated @x + 1@ thunk in each cell,
-- which is slow.
add1ST :: Num a => MMatrix s a -> ST s ()
add1ST (MMatrix v) = go 0
  where
    -- Length is loop-invariant, so compute it once instead of threading it.
    len = M.length v
    go !i
      | i == len = return ()
      | otherwise = do
          !x <- M.read v i
          let !x' = x + 1
          M.write v i x'
          go (i + 1)
-- | Build a 10-million-element zero 'Matrix', increment it four times, and
-- print the sum of its elements (4 per element, so 40000000).
--
-- The four chained 'add1' calls are the fusion test case. If "runFusion"
-- fires, the vector is thawed and frozen once rather than four times.
main :: IO ()
main = do
  let huge = Matrix (V.replicate (10 * 1000 * 1000) 0) :: Matrix Int
  print . V.sum . unMatrix $ (add1 . add1 . add1 . add1) huge
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment