@dimchansky
Created November 6, 2013 21:37
Sparse matrix-vector multiplication in Haskell
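{-# LANGUAGE ParallelArrays #-}
{-# OPTIONS -fvectorise #-}
module SMVMVectorised (smvmPA) where

import Data.Array.Parallel.Prelude
import Data.Array.Parallel.Prelude.Double as D
import Data.Array.Parallel.Prelude.Int as I
import qualified Prelude as P

-- | Entry point for non-vectorised code: converts the 'PArray' arguments to
-- parallel arrays, runs 'smvm', and converts the result back.
smvmPA :: PArray (PArray (Int, Double)) -> PArray Double -> PArray Double
{-# NOINLINE smvmPA #-}
smvmPA m v = toPArrayP (smvm (fromNestedPArrayP m) (fromPArrayP v))

-- | Each row holds (column index, value) pairs for its non-zero entries;
-- each result element is the dot product of one row with 'v'.
smvm :: [:[: (Int, Double) :]:] -> [:Double:] -> [:Double:]
smvm m v = [: D.sumP [: x D.* (v !: i) | (i,x) <- row :] | row <- m :]
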
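{-# LANGUAGE RecordWildCards #-}

import Data.Vector.Unboxed as U

-- | A compressed row storage (CRS) sparse matrix.
data CRS a = CRS {
      crsValues  :: Vector a    -- ^ non-zero entries, stored row by row
    , colIndices :: Vector Int  -- ^ column index of each entry in 'crsValues'
    , rowIndices :: Vector Int  -- ^ offset of each row's first entry, plus one final end offset
    } deriving (Show)

-- | Multiply a CRS matrix by a dense vector. Indexing is unchecked, so every
-- stored index must be in range. The result has one element per matrix row,
-- i.e. one fewer than the length of 'rowIndices'.
multiplyMV :: CRS Double -> Vector Double -> Vector Double
multiplyMV CRS{..} x = generate (U.length rowIndices - 1) outer
  where outer i = U.sum . U.map inner $ U.enumFromN start (end-start)
          where inner j = (crsValues ! j) * (x ! (colIndices ! j))
                start   = rowIndices ! i
                end     = rowIndices ! (i+1)
        (!) a b = unsafeIndex a b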
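
The two files above use different sparse layouts: smvm takes one array of (column index, value) pairs per row, while multiplyMV takes three flat CRS arrays. The following is a minimal sketch of how the first layout could be flattened into the second. It uses plain lists from base instead of parallel arrays or unboxed vectors so that it runs without extra packages, and the names CRSList, toCRS, smvmList, multiplyCRS and example are hypothetical, not part of the gist.

```haskell
-- Sparse rows: one list of (column index, value) pairs per matrix row,
-- the same shape that smvm consumes.
type SparseRows = [[(Int, Double)]]

-- Plain-list stand-in for the CRS record (hypothetical, for illustration only).
data CRSList = CRSList
  { values  :: [Double]  -- non-zero entries, row by row
  , columns :: [Int]     -- column index of each entry in 'values'
  , rowPtrs :: [Int]     -- offset where each row starts, plus a final end offset
  } deriving (Show, Eq)

-- Flatten the per-row pairs; row offsets are the running sum of row lengths.
toCRS :: SparseRows -> CRSList
toCRS rows = CRSList
  { values  = map snd (concat rows)
  , columns = map fst (concat rows)
  , rowPtrs = scanl (+) 0 (map length rows)
  }

-- Product on the sparse-row form, mirroring the nested comprehension in smvm.
smvmList :: SparseRows -> [Double] -> [Double]
smvmList m v = [ sum [ x * (v !! i) | (i, x) <- row ] | row <- m ]

-- Product on the CRS form: consecutive offsets delimit each row's entries.
multiplyCRS :: CRSList -> [Double] -> [Double]
multiplyCRS (CRSList vals cols ptrs) x =
  [ sum [ (vals !! j) * (x !! (cols !! j)) | j <- [start .. end - 1] ]
  | (start, end) <- zip ptrs (drop 1 ptrs) ]

-- The 3x3 matrix [[1,0,2],[0,0,0],[0,3,0]]; the middle row has no entries.
example :: SparseRows
example = [ [(0, 1), (2, 2)], [], [(1, 3)] ]

main :: IO ()
main = do
  let v = [1, 2, 3]
  print (rowPtrs (toCRS example))          -- [0,2,2,3]
  print (multiplyCRS (toCRS example) v)    -- [7.0,0.0,6.0]
  print (smvmList example v)               -- [7.0,0.0,6.0]
```

List indexing with !! is linear in the index, so this sketch is only meant to show the layout; the unboxed-vector version above is the one to use for real data.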