Skip to content

Instantly share code, notes, and snippets.

@sorawee
Last active July 1, 2018 01:09
Show Gist options
  • Save sorawee/552cc671dcbad9b6255f033c21453ad1 to your computer and use it in GitHub Desktop.
Save sorawee/552cc671dcbad9b6255f033c21453ad1 to your computer and use it in GitHub Desktop.
#include <iostream>
#include <vector>
using namespace std;
// Returns the smallest power of two that is >= n; any n <= 1 yields 1.
// NOTE(review): for n > 2^30 the answer does not fit in int and the doubling
// overflows, so callers must keep n <= 2^30.
inline int smallest_pow2(int n) {
    int power = 1;
    for (; power < n; power <<= 1) {
    }
    return power;
}
// Returns how many integer cells the closed ranges [a, b] and [x, y] have in
// common; 0 when they are disjoint (or either range is empty, i.e. start > end).
inline int intersect(int a, int b, int x, int y) {
    const int lo = (a > x) ? a : x;  // later of the two starts
    const int hi = (b < y) ? b : y;  // earlier of the two ends
    if (hi < lo) {
        return 0;
    }
    return hi - lo + 1;
}
// One node of the segment tree.
// It covers the closed cell range [l, r], stores a range sum in `val`, and
// holds in `lazy` a pending per-cell addition that has not yet been folded
// into `val` (it contributes lazy * (r - l + 1) once applied).
// Every member defaults to 0, so a default-constructed Node is never left
// with indeterminate values; brace/aggregate initialization still works.
struct Node {
    int l = 0;     // first covered cell (inclusive)
    int r = 0;     // last covered cell (inclusive)
    int val = 0;   // sum of the range, excluding pending `lazy` additions
    int lazy = 0;  // pending addition per cell; 0 means nothing pending
    // it's actually possible to compute l and r from an index,
    // so we could save 8 bytes per node if we really need it
};
struct SegmentTree {
vector<Node> nodes;
SegmentTree(int n) {
int offset = smallest_pow2(n);
nodes.resize(2*offset);
// could save space by resizing to only offset+n+1, but will need more checks in various places
// to prevent index out of bound
/*
E.g., if 5 <= n <= 8, we want to start at the index offset = smallest_pow2(n) = 8
1
2 3
4 5 6 7
here>> 8 9 10 11 12 13 14 15
*/
// set the leaves (the bottommost row)
for (int i = 0; i < offset; ++i) {
nodes[offset + i].l = i;
nodes[offset + i].r = i;
nodes[offset + i].val = 0; // if initial values are provided, can set them here (for i < n)
nodes[offset + i].lazy = 0; // in the context of summation, lazy = 0 means no lazy value
}
// set the inner nodes
for (int i = offset - 1; i >= 1; --i) {
nodes[i].l = nodes[i*2].l;
nodes[i].r = nodes[i*2 + 1].r;
nodes[i].val = nodes[i*2].val + nodes[i*2 + 1].val; // can just set to 0 if everything is initially 0
nodes[i].lazy = 0;
}
}
void propagate(int v) {
if (nodes[v].lazy == 0) return; // this line is simply to shortcut. It's not really needed...
nodes[v].val += nodes[v].lazy * (nodes[v].r - nodes[v].l + 1);
if (v*2 < int(nodes.size())) nodes[v*2].lazy += nodes[v].lazy;
if (v*2 + 1 < int(nodes.size())) nodes[v*2 + 1].lazy += nodes[v].lazy;
nodes[v].lazy = 0;
}
int query(int l, int r) {
return query_iter(1, l, r);
}
int query_iter(int v, int l, int r) {
if (not intersect(nodes[v].l, nodes[v].r, l, r)) return 0;
propagate(v);
if (l <= nodes[v].l and nodes[v].r <= r) return nodes[v].val;
return query_iter(v*2, l, r) + query_iter(v*2 + 1, l, r);
}
void update(int l, int r, int val) {
update_iter(1, l, r, val);
}
void update_iter(int v, int l, int r, int val) {
int intersecting_cells = intersect(nodes[v].l, nodes[v].r, l, r);
if (not intersecting_cells) return;
if (l <= nodes[v].l and nodes[v].r <= r) {
nodes[v].lazy += val; // only set lazy. No need to bother with val since propagation will deal with that.
return;
}
// can't update the entire node's lazy, so we need to maintain val explicitly
nodes[v].val += val * intersecting_cells;
update_iter(v*2, l, r, val);
update_iter(v*2 + 1, l, r, val);
}
};
int main() {
SegmentTree st(10);
// 0 0 0 0 0 0 0 0 0 0
st.update(2, 5, 3);
// 0 0 3 3 3 3 0 0 0 0
cout << st.query(0, 9) << endl; // expect 12
cout << st.query(1, 2) << endl; // expect 3
st.update(1, 3, 7);
// 0 7 10 10 3 3 0 0 0 0
st.update(5, 8, 1);
// 0 7 10 10 3 4 1 1 1 0
cout << st.query(2, 3) << endl; // expect 20
cout << st.query(1, 6) << endl; // expect 35
cout << st.query(0, 9) << endl; // expect 37
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment