Created
April 21, 2023 11:17
-
-
Save yangzhixuan/2f47081f2ffe8eee285f4331c1d797dc to your computer and use it in GitHub Desktop.
Treap-based array
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Control.Monad (foldM) | |
import Prelude hiding (sum) | |
import Control.Monad.Random (MonadRandom(getRandom)) | |
-- | Zero-based position inside the sequence represented by a 'TArray'.
type Pos = Int

-- | A treap used as an array: the in-order traversal of the tree is the
-- sequence of payloads. Every node caches the size and the payload sum of
-- the subtree rooted at it.
--
-- All fields are strict so the cached size and sum are computed when a node
-- is built, instead of accumulating as chains of unevaluated thunks across
-- repeated rebuilds.
data TArray
  = E -- The empty tree
  | N
      !TArray -- Left subtree
      !Int -- Data payload (not position!)
      !Int -- Priority
      !Int -- Size of the tree
      !Int -- Sum of the tree
      !TArray -- Right subtree
-- | Number of elements in the tree, read from the size cached in the root
-- node. The empty tree has size 0.
size :: TArray -> Int
size t = case t of
  E -> 0
  N _ _ _ cnt _ _ -> cnt
-- | Sum of all payloads in the tree, read from the sum cached in the root
-- node. The empty tree has sum 0.
--
-- This intentionally reuses the name of the Prelude function, which the
-- file hides on import.
sum :: TArray -> Int
sum t = case t of
  E -> 0
  N _ _ _ _ total _ -> total
-- | Smart constructor: build a node from a left subtree, a payload, a
-- priority and a right subtree, filling in the cached size and sum from
-- the children.
node :: TArray -> Int -> Int -> TArray -> TArray
node l d p r = N l d p count total r
  where
    -- One for this node plus everything below it.
    count = size l + 1 + size r
    -- This node's payload plus the cached sums of both children.
    total = sum l + d + sum r
-- | Payloads of the tree in in-order (left subtree, node, right subtree).
toList :: TArray -> [Int]
toList tree = walk tree []
  where
    -- 'walk sub rest' prepends the in-order payloads of 'sub' to 'rest',
    -- so no list concatenation is needed.
    walk sub rest = case sub of
      E -> rest
      N left payload _ _ _ right -> walk left (payload : walk right rest)
-- | Concatenate two trees: every element of the first argument comes before
-- every element of the second in the result.
--
-- The root with the strictly smaller priority stays on top; on a tie the
-- root of the second tree is used.
merge :: TArray -> TArray -> TArray
merge xs ys = case (xs, ys) of
  (E, _) -> ys
  (_, E) -> xs
  (N leftL valL prioL _ _ leftR, N rightL valR prioR _ _ rightR)
    -- Left root wins: the whole right tree goes into its right subtree.
    | prioL < prioR -> node leftL valL prioL (merge leftR ys)
    -- Right root wins: the whole left tree goes into its left subtree.
    | otherwise -> node (merge xs rightL) valR prioR rightR
-- | Semantically, 'split' is 'take' and 'drop' for lists: @split t n@
-- returns the first @n@ elements and the remaining elements as two trees.
split :: TArray -> Int -> (TArray, TArray)
split E _ = (E, E)
split (N lt x p _ _ rt) n
  -- The cut lies inside (or at the end of) the left subtree, so the root
  -- and the right subtree belong to the second half.
  | n <= leftSize = (fst inLeft, node (snd inLeft) x p rt)
  -- The root belongs to the first half; cut the right subtree after
  -- discounting the left subtree and the root itself.
  | otherwise = (node lt x p (fst inRight), snd inRight)
  where
    leftSize = size lt
    inLeft = split lt n
    inRight = split rt (n - leftSize - 1)
-- | Split the tree into three segments: [0, x), [x, y), [y, inf).
split3 :: TArray -> Pos -> Pos -> (TArray, TArray, TArray)
split3 t x y = (front, middle, back)
  where
    -- 'rest' represents [x, inf)
    (front, rest) = split t x
    -- 'middle' represents [x, x + (y - x))
    (middle, back) = split rest (y - x)
-- | @insert xs k d@ places payload @d@ so that it ends up right after the
-- first @k@ elements of @xs@. The new node's priority is drawn from the
-- random monad.
insert :: MonadRandom m => TArray -> Pos -> Int -> m TArray
insert xs k d = do
  prio <- getRandom
  let (before, after) = split xs k
      fresh = node E d prio E
  return (merge (merge before fresh) after)
-- | Remove the elements in the interval [x, y) and glue the two remaining
-- pieces back together.
delete :: TArray -> Pos -> Pos -> TArray
delete t x y = merge before after
  where
    (before, _, after) = split3 t x y
-- | Obtain the payloads in the interval [x, y), in order.
get :: TArray -> Pos -> Pos -> [Int]
get t x y = toList middle
  where
    (_, middle, _) = split3 t x y
-- | Sum of the payloads in the range [x, y), taken from the cached sum of
-- the middle segment.
rsum :: TArray -> Pos -> Pos -> Int
rsum t x y = sum middle
  where
    (_, middle, _) = split3 t x y
-- | Smoke test: insert 10000 down to 0 at the front one by one, then print
-- the sum of the range [0, 100) of the resulting array.
test :: IO ()
test = do
  let values = [10000, 9999 .. 0]
  t1 <- foldM (\acc v -> insert acc 0 v) E values
  print (rsum t1 0 100)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment