Skip to content

Instantly share code, notes, and snippets.

@yangzhixuan
Created April 21, 2023 11:17
Show Gist options
  • Save yangzhixuan/2f47081f2ffe8eee285f4331c1d797dc to your computer and use it in GitHub Desktop.
Save yangzhixuan/2f47081f2ffe8eee285f4331c1d797dc to your computer and use it in GitHub Desktop.
Treap-based array
import Control.Monad (foldM)
import Prelude hiding (sum)
import Control.Monad.Random (MonadRandom(getRandom))
-- | A zero-based position (index) into a 'TArray'; position 0 is the first
-- element (see 'split', which behaves like take/drop on the element list).
type Pos = Int
-- | An array of 'Int's stored as a treap keyed implicitly by position.
--
-- Elements are ordered by in-order traversal; the tree never stores an
-- element's position explicitly. Priorities are heap-ordered with the
-- smallest priority at the root (see 'merge'). Each node caches the size and
-- sum of its subtree so 'size' and 'sum' are constant-time field reads.
--
-- The Int fields are strict and unpacked. As a result, the size and sum
-- computed by 'node' are evaluated when the node is built, instead of being
-- kept as thunks in every node.
data TArray
  = E
  | N TArray                -- Left subtree
      {-# UNPACK #-} !Int   -- Data payload (not position!)
      {-# UNPACK #-} !Int   -- Priority (smaller priorities sit nearer the root)
      {-# UNPACK #-} !Int   -- Size of the tree
      {-# UNPACK #-} !Int   -- Sum of the tree
      TArray                -- Right subtree
-- | Number of elements in the array, taken from the node's cached size field
-- (0 for the empty array).
size :: TArray -> Int
size t = case t of
  E -> 0
  N _ _ _ count _ _ -> count
-- | Total of all elements in the array, taken from the node's cached sum
-- field (0 for the empty array). Shadows 'Prelude.sum', which this module
-- hides.
sum :: TArray -> Int
sum t = case t of
  E -> 0
  N _ _ _ _ total _ -> total
-- | Smart constructor for an inner node. It takes the left subtree, payload,
-- priority and right subtree, and recomputes the cached size and sum from the
-- children. Every node in this module should be built through this function
-- so the caches stay consistent.
node :: TArray -> Int -> Int -> TArray -> TArray
node l d p r = N l d p count total r
  where
    count = size l + size r + 1
    total = sum l + sum r + d
-- | The elements of the array in position order (an in-order traversal).
-- An accumulator list is threaded through the walk, so each element is
-- consed exactly once and no (++) is needed.
toList :: TArray -> [Int]
toList tree = walk tree []
  where
    walk E rest = rest
    walk (N left x _ _ _ right) rest = walk left (x : walk right rest)
-- | Concatenate two arrays: every element of the first argument precedes
-- every element of the second. Of the two roots, the one with the smaller
-- priority becomes the new root, and its inner side is merged recursively.
-- This keeps the priorities heap-ordered.
merge :: TArray -> TArray -> TArray
merge E right = right
merge left E = left
merge left@(N ll x p _ _ lr) right@(N rl y q _ _ rr)
  | q <= p = node (merge left rl) y q rr
  | otherwise = node ll x p (merge lr right)
-- | Split an array at position @n@. Semantically, @split t n@ is the pair
-- @(take n xs, drop n xs)@ on the element list @xs@ of @t@. A non-positive
-- @n@ yields an empty first half; @n >= size t@ yields an empty second half.
split :: TArray -> Int -> (TArray, TArray)
split E _ = (E, E)
split (N lt x p _ _ rt) n
  | n <= leftCount = (lFront, node lBack x p rt)
  | otherwise = (node lt x p rFront, rBack)
  where
    leftCount = size lt
    -- Only one of these lazy bindings is demanded, depending on the guard.
    (lFront, lBack) = split lt n
    (rFront, rBack) = split rt (n - leftCount - 1)
-- | Split the array into three consecutive segments covering positions
-- @[0, x)@, @[x, y)@ and @[y, inf)@. If @y <= x@ the middle segment is empty.
--
-- The start is clamped at 0. With a negative @x@, the first split returns the
-- whole array as the remainder. Splitting that remainder at @y - x@ would then
-- take @y - x > y@ elements, making the middle segment too long (for example,
-- 'delete' would remove elements at positions @>= y@).
split3 :: TArray -> Pos -> Pos -> (TArray, TArray, TArray)
split3 t x y = (front, middle, back)
  where
    start = max 0 x
    (front, rest) = split t start            -- rest represents [start, inf)
    (middle, back) = split rest (y - start)  -- middle represents [start, y)
-- | Insert element @d@ so that it ends up at position @k@ of the result,
-- shifting later elements right by one. A position past the end appends, and
-- a negative position prepends (both follow from 'split').
--
-- The new node gets a fresh random priority from 'getRandom'; that priority
-- decides where the node sits once it is merged back in.
insert :: MonadRandom m => TArray -> Pos -> Int -> m TArray
insert xs k d = place <$> getRandom
  where
    place p =
      let (front, back) = split xs k
       in (front `merge` node E d p E) `merge` back
-- | Remove the elements at positions in the interval @[x, y)@. The segments
-- before and after the interval are joined, so later elements shift left.
delete :: TArray -> Pos -> Pos -> TArray
delete t x y = before `merge` after
  where
    (before, _, after) = split3 t x y
-- | The elements at positions in the interval @[x, y)@, in position order.
get :: TArray -> Pos -> Pos -> [Int]
get t x y = case split3 t x y of
  (_, slice, _) -> toList slice
-- | Sum of the elements at positions in the interval @[x, y)@. The result is
-- read from the cached sum of the isolated middle segment, without
-- traversing its elements.
rsum :: TArray -> Pos -> Pos -> Int
rsum t x y = sum middle
  where
    (_, middle, _) = split3 t x y
-- | Smoke test. Inserting @reverse [0 .. 10000]@ one value at a time at
-- position 0 builds the array @[0 .. 10000]@. Summing positions @[0, 100)@
-- should therefore print @sum [0 .. 99] = 4950@.
test :: IO ()
test = do
  t1 <- foldM (\t v -> insert t 0 v) E (reverse [0 .. 10000])
  -- 'print' already returns IO (), so no trailing 'return ()' is needed.
  print (rsum t1 0 100)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment