Skip to content

Instantly share code, notes, and snippets.

@tmcdonell
Last active October 18, 2018 12:53
Show Gist options
  • Save tmcdonell/eceeffe33eecd4be08fd5ed85301a82f to your computer and use it in GitHub Desktop.
Save tmcdonell/eceeffe33eecd4be08fd5ed85301a82f to your computer and use it in GitHub Desktop.
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Data.Bits as A
import Data.Array.Accelerate.Interpreter as I
import Prelude as P
-- simple test
--
-- Fixtures: four segments described by a start index ('starts') and a
-- length ('segs'). The third segment has length zero.

-- | Length of each segment.
segs :: Segments Int
segs = fromList (Z:.4) [1,7,0,2]

-- | Index into 'vec' at which each segment begins.
starts :: Vector Int
starts = fromList (Z:.4) [0,10,5,5]

-- | The @(start, length)@ descriptors handed to 'segmentedIndices'.
info :: Acc (Vector (Int,Int))
info = A.zip (use starts) (use segs)

-- | Source data to gather from: the values 0 .. 19.
vec :: Acc (Vector Float)
vec = use $ fromList (Z:.20) [0..]

-- | Gather from 'vec' the elements selected by the segment descriptors in
-- 'info'. The intended expansion of each @(start, length)@ pair is the index
-- range @start .. start+length-1@. Evaluate with a backend's @run@, e.g. the
-- interpreter imported above as @I@.
test :: Acc (Vector Float)
test = gather (segmentedIndices info) vec
-- implementation
-- | Expand a vector of @(start, length)@ segment descriptors into one flat
-- vector of indices: every descriptor contributes
-- @start, start+1 .. start+length-1@, in descriptor order. Zero-length
-- segments contribute nothing.
--
-- Implemented as an inclusive segmented sum over 'mkHeadElements': the first
-- slot of each segment carries its start index and every other slot carries
-- 1, so the running sum counts upwards from the start and restarts at every
-- segment head.
--
-- Assumes all lengths are non-negative.
segmentedIndices
    :: Acc (Vector (Int,Int))
    -> Acc (Vector Int)
segmentedIndices segInfo
  = A.map A.snd
  $ A.scanl1 (segmented (+)) (mkHeadElements segInfo)

-- | Build the @(flag, value)@ input of the segmented scan. The result has one
-- slot per output index (the sum of all segment lengths):
--
--   * the first slot of every non-empty segment holds @(1, start)@;
--   * every other slot holds @(0, 1)@.
--
-- The head flag is kept separate from the start value so that a segment
-- starting at index 0 is still recognised as a head.
mkHeadElements
    :: Acc (Vector (Int,Int))
    -> Acc (Vector (Int,Int))
mkHeadElements segInfo
  = A.permute const background
      -- Empty segments own no output slot, so they must not be scattered.
      (\ix -> seg ! ix A.== 0 ? ( ignore, index1 (offset ! ix) ))
      heads
  where
    (start, seg)  = A.unzip segInfo
    -- Exclusive prefix sum of the lengths: 'offset' is the output position of
    -- each segment's first element, 'len' the total output length.
    (offset, len) = unlift (scanl' (+) 0 seg)
    heads         = A.zip (fill (shape start) 1) start
    background    = fill (index1 (the len)) (constant (0, 1))

-- | For every output slot: the segment's start index if the slot is the first
-- element of a non-empty segment, and 0 otherwise.
--
-- Note that a head whose start index is 0 cannot be told apart from a
-- non-head slot in this encoding; use 'mkHeadElements' when that matters.
mkHeadIndices
    :: Acc (Vector (Int,Int))
    -> Acc (Vector Int)
mkHeadIndices segInfo
  = A.map (\e -> A.fst e A./= 0 ? ( A.snd e, 0 ))
  $ mkHeadElements segInfo

-- | Lift a binary operator to the segmented-scan operator over
-- @(flag, value)@ pairs. A non-zero flag on the right operand marks the start
-- of a new segment, in which case the accumulated left value is discarded and
-- the right value is taken as-is; otherwise the two values are combined with
-- @f@. Flags are OR-ed so that a combined element remembers whether it
-- contains a segment head.
--
-- This form is associative whenever @f@ is, which scans need in order to be
-- evaluated in parallel.
segmented
    :: (Exp Int -> Exp Int -> Exp Int)
    -> Exp (Int, Int)
    -> Exp (Int, Int)
    -> Exp (Int, Int)
segmented f a b =
  let (aF, aV) = unlift a :: (Exp Int, Exp Int)
      (bF, bV) = unlift b :: (Exp Int, Exp Int)
  in
  lift ( aF A..|. bF
       , bF A./= 0 ? ( bV, f aV bV ) )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment