Skip to content

Instantly share code, notes, and snippets.

@yurahuna
Last active April 11, 2018 14:56
Show Gist options
  • Save yurahuna/c852cd78148497600c2fc88f220eeae3 to your computer and use it in GitHub Desktop.
Save yurahuna/c852cd78148497600c2fc88f220eeae3 to your computer and use it in GitHub Desktop.
抽象化半開区間セグツリー
/// Lazy-propagation segment tree over half-open index ranges [a, b).
///
/// T is the type of the aggregated values kept in the nodes and E is the type
/// of the pending range-update tags ("lazy" values).  E must be comparable
/// with operator== because a tag equal to d0 is treated as "nothing pending".
///
/// The behaviour is supplied through callables:
///   f(T, T)   -> T : merges the values of the left and right child.
///   g(T, E)   -> T : applies a tag to a node value.
///   h(E, E)   -> E : combines an already pending tag (first argument) with a
///                    newly arriving tag (second argument).
///   p(E, int) -> E : adapts a tag to the number of leaves covered by the node
///                    before it is handed to g; the default returns the tag
///                    unchanged.
///   d1             : value returned for an empty / non-overlapping range and
///                    the initial content of every node.
///   d0             : the "no pending update" tag.
template<class T, class E>
struct SegTree{
    using F = std::function<T(T,T)>;
    using G = std::function<T(T,E)>;
    using H = std::function<E(E,E)>;
    using P = std::function<E(E,int)>;

    int n;                  // number of leaves: smallest power of two >= requested size
    F f;
    G g;
    H h;
    P p;
    T d1;
    E d0;
    std::vector<T> dat;     // node values, heap layout: children of k are 2k+1 and 2k+2
    std::vector<E> lazy;    // pending tag per node, d0 when nothing is pending

    /// Creates a tree able to hold at least n_ elements, all initialised to d1.
    /// The initialiser list follows the member declaration order.
    SegTree(int n_, F f, G g, H h, T d1, E d0, P p=[](E a, int){ return a; })
        : n(1), f(f), g(g), h(h), p(p), d1(d1), d0(d0) {
        while (n < n_) n *= 2;
        dat.resize(2 * n - 1, d1);
        lazy.resize(2 * n - 1, d0);
    }

    /// Rebuilds the tree so that it holds exactly the elements of `a`.
    ///
    /// Element i of `a` becomes leaf i; leaves not covered by `a` are reset to
    /// d1 and every pending tag is discarded, so the call is also safe on a
    /// tree that has already been updated.  Elements beyond the leaf capacity
    /// n are ignored rather than written out of bounds.  U only has to be
    /// assignable to T.
    template<class U>
    void build(const std::vector<U>& a) {
        const int given = static_cast<int>(a.size());
        const int used = given < n ? given : n;
        for (int i = 0; i < used; i++) dat[i + n - 1] = a[i];
        for (int i = used; i < n; i++) dat[i + n - 1] = d1;
        // Internal nodes are recomputed bottom-up from their two children.
        for (int i = n - 2; i >= 0; i--) dat[i] = f(dat[2 * i + 1], dat[2 * i + 2]);
        for (int i = 0; i < 2 * n - 1; i++) lazy[i] = d0;
    }

    /// Overload kept for existing callers that pass a vector<E> or a braced
    /// list (a braced list cannot deduce U in the template above).
    void build(const std::vector<E>& a) {
        build<E>(a);
    }

    /// Resolves the pending tag of node k, which covers [l, r): applies it to
    /// dat[k], hands it down to both children (unless k is a leaf) and clears it.
    void eval(int k, int l, int r) {
        if (lazy[k] == d0) return;
        dat[k] = g(dat[k], p(lazy[k], r - l));
        if (k < n - 1) {
            // Existing child tag first, incoming parent tag second.
            lazy[2 * k + 1] = h(lazy[2 * k + 1], lazy[k]);
            lazy[2 * k + 2] = h(lazy[2 * k + 2], lazy[k]);
        }
        lazy[k] = d0;
    }

    /// Applies tag x to every position in [a, b).
    /// k, l and r describe the node currently visited and its interval [l, r);
    /// callers leave them at their defaults (r == -1 stands for r = n).
    void update(int a, int b, E x, int k = 0, int l = 0, int r = -1) {
        if (r == -1) r = n;
        // Always resolve the node first so dat[k] is current for the parent's merge.
        eval(k, l, r);
        if (b <= l || r <= a) return;
        if (a <= l && r <= b) {
            lazy[k] = h(lazy[k], x);
            eval(k, l, r);
            return;
        }
        const int mid = (l + r) / 2;
        update(a, b, x, 2 * k + 1, l, mid);
        update(a, b, x, 2 * k + 2, mid, r);
        dat[k] = f(dat[2 * k + 1], dat[2 * k + 2]);
    }

    /// Returns the f-aggregate of the positions in [a, b), or d1 when the
    /// range does not overlap the visited node.
    /// k: node id, [l, r): node interval (r == -1 stands for r = n).
    T query(int a, int b, int k = 0, int l = 0, int r = -1) {
        if (r == -1) r = n;
        eval(k, l, r);
        if (r <= a || b <= l) return d1;
        if (a <= l && r <= b) return dat[k];
        const int mid = (l + r) / 2;
        T vl = query(a, b, 2 * k + 1, l, mid);
        T vr = query(a, b, 2 * k + 2, mid, r);
        return f(vl, vr);
    }

    /// Dumps every node as "(value, tag)" to cerr, one tree level per line.
    void print() {
        int j = 2;  // j - 2 is the index of the last node on the current level
        for (int i = 0; i < 2 * n - 1; i++) {
            std::cerr << "(" << dat[i] << ", " << lazy[i] << ")";
            if (i == j - 2) { std::cerr << std::endl; j *= 2; }
            else { std::cerr << " "; }
        }
    }
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment