{-
  Gist m00nlight/5b6eeebc28ab15a35b10: Haskell segment tree with lazy
  propagation. Last active September 27, 2022, 23:24.

  (GitHub page notes: the gist can be saved to your computer with GitHub
  Desktop. The page warns that this file contains bidirectional Unicode text
  that may be interpreted or compiled differently than what appears below;
  to review it, open the file in an editor that reveals hidden Unicode
  characters. See GitHub's documentation on bidirectional Unicode characters.)
-}
import Control.Applicative | |
import Control.Monad | |
import qualified Data.ByteString.Char8 as BS | |
import Data.List | |
import Data.Maybe | |
import qualified Data.Vector as V | |
-- | A range-sum segment tree over a contiguous block of integer indices,
-- supporting lazily deferred range increments.
--
-- A node's 'val' is the sum over its index range /excluding/ the increment
-- held in its 'lazy' field. That increment (applied once per cell) has not
-- yet been folded into this node or pushed to anything beneath it.
data SegTree a
  = Node
      { val :: a                            -- ^ range sum, pending 'lazy' not included
      , lazy :: Maybe a                     -- ^ per-cell increment not applied yet
      , left, right :: Int                  -- ^ inclusive index range covered
      , leftChild, rightChild :: SegTree a  -- ^ the two sub-ranges (as built by 'initTree')
      }
  | Leaf
      { val :: a                            -- ^ value of the single cell
      , lazy :: Maybe a                     -- ^ increment not applied yet
      , left, right :: Int                  -- ^ both equal the cell's index
      }
  deriving (Show, Eq, Ord)
-- | Lets pending increments ('lazy' fields) be combined with '+'.
--
-- 'Nothing' means "no pending increment" and behaves like zero: it is the
-- identity for '+' and absorbing for '*'. The tree code only needs '+'.
-- The other methods used to be 'undefined', and 'negate' was missing, which
-- made 'negate' and '-' crash through the defaulted methods. They are now
-- defined totally.
--
-- NOTE(review): this is an orphan instance on a base type and will clash
-- with any other such instance in a program. A local helper such as
-- @addLazy :: Num a => Maybe a -> Maybe a -> Maybe a@ would be safer.
instance (Num a) => Num (Maybe a) where
  Nothing + y = y
  x + Nothing = x
  Just a + Just b = Just (a + b)

  Just a * Just b = Just (a * b)
  _ * _ = Nothing

  negate = fmap negate
  abs = fmap abs
  signum = fmap signum
  fromInteger = Just . fromInteger
-- | Short alias for 'fromIntegral'. It converts an integral value (such as
-- an 'Int' range length) to whichever numeric type the context requires.
fI :: (Integral a, Num b) => a -> b
fI x = fromIntegral x
-- | Build a range-sum segment tree over the given values.
--
-- Element @i@ of the list becomes the leaf for index range @[i, i]@. Each
-- internal node covering @[l, r]@ is split at @mid = (l + r) `div` 2@ into
-- @[l, mid]@ and @[mid + 1, r]@ and stores the sum of its children. No node
-- starts with a pending ('lazy') increment.
--
-- The list must be non-empty. Previously an empty list made the builder
-- recurse on the range @[0, -1]@ forever; it now fails with an error.
initTree :: (Num a) => [a] -> SegTree a
initTree [] = error "initTree: cannot build a segment tree from an empty list"
initTree xs = build 0 (V.length vs - 1)
  where
    -- random access to the input while splitting ranges
    vs = V.fromList xs
    build l r
      | l == r = Leaf { val = vs V.! l, lazy = Nothing, left = l, right = r }
      | otherwise =
          let mid = (l + r) `div` 2
              lChild = build l mid
              rChild = build (mid + 1) r
           in Node
                { val = val lChild + val rChild
                , lazy = Nothing
                , left = l
                , right = r
                , leftChild = lChild
                , rightChild = rChild
                }
-- | Fold a node's pending increment (its 'lazy' field), if any, into the node.
--
-- When @lazy rt@ is @Just d@:
--
--   * if the node covers a single index, @d@ is added to its 'val';
--   * otherwise 'val' grows by @d@ for every cell in the range, and @d@ is
--     queued onto both children's 'lazy' fields (added to any increment they
--     already carry) instead of being applied to them now.
--
-- In both cases the result has @lazy = Nothing@. A node with no pending
-- increment is returned unchanged.
updateNode :: (Num a, Eq a) => SegTree a -> SegTree a
updateNode rt = case lazy rt of
  Nothing -> rt
  Just d
    | lo == hi -> rt { val = val rt + d, lazy = Nothing }
    | otherwise ->
        rt { val = val rt + d * fI (hi - lo + 1)
           , leftChild = deferTo (leftChild rt)
           , rightChild = deferTo (rightChild rt)
           , lazy = Nothing
           }
    where
      lo = left rt
      hi = right rt
      -- stack d on top of whatever increment the child is already holding
      deferTo child = child { lazy = Just (maybe d (+ d) (lazy child)) }
-- | Sum of the cells in the inclusive index range @[l, r]@.
--
-- Also returns a tree representing the same values as @root@, in which the
-- pending increments along the visited paths have been pushed down with
-- 'updateNode'. Previously only the top-level push-down was kept and the
-- children rebuilt by the recursive calls were discarded, so every later
-- query repeated that work.
--
-- A range that does not intersect @root@ (or is empty, @l > r@)
-- contributes 0 and leaves the tree untouched.
queryTree :: (Num a, Eq a) => SegTree a -> Int -> Int -> (a, SegTree a)
queryTree root l r
  | l > r || r < left root || l > right root = (0, root)
  | left nr >= l && right nr <= r = (val nr, nr)
  | otherwise =
      -- Partial overlap. Only internal nodes reach this branch: a leaf that
      -- intersects [l, r] is always fully inside it.
      let (lSum, lc) = queryTree (leftChild nr) l r
          (rSum, rc) = queryTree (rightChild nr) l r
       in (lSum + rSum, nr { leftChild = lc, rightChild = rc })
  where
    nr = updateNode root
-- | Add @inc@ to every cell in the inclusive index range @[l, r]@.
--
-- An empty range (@l > r@) returns @root@ untouched. Otherwise the node's
-- pending increment is first applied with 'updateNode', and then:
--
--   * a node disjoint from @[l, r]@ is returned as is;
--   * a node fully inside @[l, r]@ has its 'val' raised by @inc@ per cell,
--     and, if it spans more than one index, @inc@ is queued on both
--     children's 'lazy' fields;
--   * a partially overlapping node updates both children recursively and
--     takes the sum of their new values.
updateTree :: (Num a, Eq a) => SegTree a -> Int -> Int -> a -> SegTree a
updateTree root l r inc
  | l > r = root
  | otherwise = go (updateNode root)
  where
    go node
      | hi < l || lo > r = node
      | lo >= l && hi <= r = cover
      | otherwise = descend
      where
        lo = left node
        hi = right node
        grown = val node + inc * fI (hi - lo + 1)
        queue child = child { lazy = lazy child + Just inc }
        cover
          | lo /= hi =
              node { val = grown
                   , leftChild = queue (leftChild node)
                   , rightChild = queue (rightChild node)
                   }
          | otherwise = node { val = grown }
        descend =
          let newL = updateTree (leftChild node) l r inc
              newR = updateTree (rightChild node) l r inc
           in node { val = val newL + val newR
                   , leftChild = newL
                   , rightChild = newR
                   }
-- | Parse the leading (optionally signed) decimal 'Int' of a token.
-- As with 'BS.readInt', any characters after the digits are ignored.
-- Input that does not start with a number fails with an error naming the
-- offending token, rather than the opaque @fromJust@ crash.
readInt' :: BS.ByteString -> Int
readInt' s = case BS.readInt s of
  Just (x, _) -> x
  Nothing -> error ("readInt': not an integer: " ++ show s)
-- | Parse the leading (optionally signed) decimal 'Integer' of a token.
-- As with 'BS.readInteger', any characters after the digits are ignored.
-- Input that does not start with a number fails with an error naming the
-- offending token, rather than the opaque @fromJust@ crash.
readInteger' :: BS.ByteString -> Integer
readInteger' s = case BS.readInteger s of
  Just (x, _) -> x
  Nothing -> error ("readInteger': not an integer: " ++ show s)
-- | Run a batch of queries against a tree of @n@ cells that all start at 0.
--
-- Each query is the list of 'Int's from one input line, with 1-based
-- inclusive indices (converted to 0-based here):
--
--   * @[0, p, q, v]@ adds @v@ to every cell in @[p, q]@;
--   * @[k, p, q]@ with @k /= 0@ asks for the sum of cells @[p, q]@.
--
-- Returns the final tree and the sum answers /newest first/; callers must
-- 'reverse' them to get input order.
--
-- 'foldl'' only forces the accumulator tuple. The new tree's sum and each
-- answer are therefore forced explicitly with 'seq'. Otherwise every
-- unevaluated answer would keep alive the tree version it was computed on,
-- and update thunks would pile up (a space leak).
solve :: Int -> [[Int]] -> (SegTree Integer, [Integer])
solve n qs = foldl' step (initTree (replicate n 0), []) qs
  where
    step (root, acc) query = case query of
      (0 : p : q : v : _) ->
        let root' = updateTree root (p - 1) (q - 1) (toInteger v)
         in val root' `seq` (root', acc)
      (0 : _) ->
        error ("solve: update query needs 4 fields, got " ++ show query)
      (_ : p : q : _) ->
        let (ans, root') = queryTree root (p - 1) (q - 1)
         in ans `seq` val root' `seq` (root', ans : acc)
      _ -> error ("solve: malformed query " ++ show query)
-- | Read @T@ test cases from stdin. Each case is a header line @N Q@
-- followed by @Q@ query lines in the format accepted by 'solve'. Prints one
-- line per sum query, in input order.
--
-- A case with no sum queries prints nothing. The previous
-- @putStrLn (intercalate "\\n" ...)@ printed a spurious blank line for such
-- cases.
main :: IO ()
main = do
  tc <- readLn :: IO Int
  replicateM_ tc $ do
    header <- map readInt' . BS.words <$> BS.getLine
    (cells, count) <- case header of
      [c, k] -> pure (c, k)
      _ ->
        ioError . userError $
          "main: expected a header line \"N Q\", got " ++ show header
    contents <- replicateM count BS.getLine
    let queries = map (map readInt' . BS.words) contents
        (_, answers) = solve cells queries
    mapM_ print (reverse answers)
{-
  (GitHub page footer: sign up for free to join this conversation on GitHub.
  Already have an account? Sign in to comment.)
-}