Last active: October 18, 2018 12:53
-
-
Save tmcdonell/eceeffe33eecd4be08fd5ed85301a82f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Data.Array.Accelerate as A | |
import Data.Array.Accelerate.Data.Bits as A | |
import Data.Array.Accelerate.Interpreter as I | |
import Prelude as P | |
-- | Simple test: gather the elements of 'vec' selected by the segments
-- described in 'info'.
--
-- With the fixtures below the expected gather indices are
-- @[0, 10,11,12,13,14,15,16, 5,6]@ (the empty third segment contributes
-- nothing). Since @vec@ holds @[0 .. 19]@, the expected result is those
-- same values as 'Float's. Evaluate with e.g. @I.run test@.
test :: Acc (Vector Float)
test = gather (segmentedIndices info) vec
-- | Lengths of the four test segments (the third one is empty).
segs :: Segments Int
segs = fromList (Z :. P.length sizes) sizes
  where
    sizes = [1, 7, 0, 2]
-- | Start offset into 'vec' of each of the four test segments.
starts :: Vector Int
starts = fromList (Z :. P.length offs) offs
  where
    offs = [0, 10, 5, 5]
-- | Segment descriptors for the test: each element pairs a segment's
-- start (from 'starts') with its length (from 'segs').
info :: Acc (Vector (Int,Int))
info = use starts `A.zip` use segs
-- | Source data for the gather: twenty values counting up from 0.
vec :: Acc (Vector Float)
vec = use values
  where
    values = fromList (Z :. 20) [0 ..]
-- implementation

-- | Expand @(start, length)@ segment descriptors into the concatenation of
-- the index runs @[start, start+1 .. start+length-1]@, one run per segment,
-- in segment order. Zero-length segments contribute nothing.
--
-- For the descriptors @[(0,1), (10,7), (5,0), (5,2)]@ the result is
-- @[0, 10,11,12,13,14,15,16, 5,6]@.
--
-- Implemented as an inclusive segmented scan. The first slot of each run
-- holds @(1, start)@ (flag set, value = start). Every other slot holds
-- @(0, 1)@ (flag clear, add one to the previous index). Keeping the flag
-- separate from the value means a segment starting at index 0 is still
-- recognised as a head, and the first output element is the first start
-- rather than the increment.
segmentedIndices
  :: Acc (Vector (Int,Int))   -- ^ @(start, length)@ per segment
  -> Acc (Vector Int)
segmentedIndices info
  = A.map A.snd
  $ A.scanl1 (segmented (+))
  $ A.map toStep (mkHeads info)
  where
    -- Head slots already carry (1, start); turn the empty (0, 0) slots
    -- into (0, 1) so the scan counts up by one within each run.
    toStep :: Exp (Int, Int) -> Exp (Int, Int)
    toStep h = (A.fst h A./= 0) ? (h, constant (0, 1))

-- | Scatter @(1, start)@ for every non-empty segment to the position where
-- that segment's run begins in the flattened output. That position is the
-- exclusive prefix sum of the lengths. All other positions hold @(0, 0)@.
-- The result has one element per output index; its length is the sum of
-- the segment lengths.
--
-- Zero-length segments are skipped. Their offset coincides with the next
-- segment's offset (or equals the total length), so scattering them would
-- either collide or fall out of bounds.
-- NOTE(review): assumes lengths are non-negative.
mkHeads
  :: Acc (Vector (Int,Int))   -- ^ @(start, length)@ per segment
  -> Acc (Vector (Int,Int))
mkHeads info
  = A.permute const zeros target (A.map markHead start)
  where
    (start, lens)   = A.unzip info
    (offset, total) = unlift (A.scanl' (+) 0 lens) :: (Acc (Vector Int), Acc (Scalar Int))
    zeros           = fill (index1 (the total)) (constant (0, 0))

    -- Destination of segment ix's head, or 'ignore' for empty segments.
    target :: Exp DIM1 -> Exp DIM1
    target ix = (lens ! ix A.== 0) ? (ignore, index1 (offset ! ix))

    markHead :: Exp Int -> Exp (Int, Int)
    markHead s = lift (1 :: Exp Int, s)

-- | For each output position, the start of the segment whose run begins
-- there, or 0 if no run begins there.
--
-- In this encoding a segment starting at 0 is indistinguishable from
-- "no head". 'segmentedIndices' therefore uses 'mkHeads', which carries an
-- explicit flag.
mkHeadIndices
  :: Acc (Vector (Int,Int))   -- ^ @(start, length)@ per segment
  -> Acc (Vector Int)
mkHeadIndices = A.map A.snd . mkHeads

-- | Lift a binary operator to a segmented-scan operator on @(flag, value)@
-- pairs.
--
-- A non-zero flag on the right-hand element marks the start of a new
-- segment; in that case its own value is taken. Otherwise the two values
-- are combined with @f@. Flags are merged with bitwise or, so a combined
-- element records whether any segment boundary occurred inside it.
--
-- This is the standard segmented-scan operator. It is associative whenever
-- @f@ is, which 'A.scanl1' requires of its combination function.
segmented
  :: (Exp Int -> Exp Int -> Exp Int)
  -> Exp (Int, Int)
  -> Exp (Int, Int)
  -> Exp (Int, Int)
segmented f a b =
  let (aF, aV) = unlift a :: (Exp Int, Exp Int)
      (bF, bV) = unlift b :: (Exp Int, Exp Int)
  in
  lift ( aF A..|. bF
       , (bF A./= 0) ? (bV, f aV bV) )
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.