Last active
May 18, 2019 12:51
-
-
Save tjkendev/0965bdcf9e3c1063a7f2e76005d285ca to your computer and use it in GitHub Desktop.
Segment Tree Beatsの例題実装 (Task 3, Task 4まわり)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include<algorithm> | |
using namespace std; | |
using ll = long long; | |
// Segment Tree Beats | |
// - l<=i<r について、 a_i の値を min(a_i, x) に更新 | |
// - l<=i<r の中の a_i の最大値を求める | |
// - l<=i<r の a_i の和を求める | |
#define N 10003 | |
class SegmentTree { | |
const ll inf = 1e18; | |
int n, n0; | |
ll max_v[4*N], smax_v[4*N], max_c[4*N]; | |
ll ladd[4*N]; // 区間加算クエリにおけるlazy tag | |
ll len[4*N]; // ノードの区間に含まれる要素の数 | |
ll sum_a[4*N]; // 最大値の総和 | |
ll sum_b[4*N]; // 非最大値の総和 | |
void update_node_max(int k, ll x) { | |
// 最大値の総和だけ加算 | |
sum_a[k] += (x - max_v[k]) * max_c[k]; | |
max_v[k] = x; | |
} | |
void addall(int k, ll x) { | |
max_v[k] += x; smax_v[k] += x; | |
// 最大値の総和・非最大値の総和それぞれに加算 | |
sum_a[k] += x * max_c[k]; | |
sum_b[k] += x * (len[k] - max_c[k]); | |
ladd[k] += x; | |
} | |
void push(int k) { | |
if(ladd[k]) { | |
addall(2*k+1, ladd[k]); | |
addall(2*k+2, ladd[k]); | |
ladd[k] = 0; | |
} | |
if(max_v[k] < max_v[2*k+1]) { | |
update_node_max(2*k+1, max_v[k]); | |
} | |
if(max_v[k] < max_v[2*k+2]) { | |
update_node_max(2*k+2, max_v[k]); | |
} | |
} | |
void update(int k) { | |
// 最大値の総和と非最大値の総和を子ノードから更新 | |
sum_b[k] = sum_b[2*k+1] + sum_b[2*k+2]; | |
if(max_v[2*k+1] < max_v[2*k+2]) { | |
max_v[k] = max_v[2*k+2]; | |
max_c[k] = max_c[2*k+2]; | |
smax_v[k] = max(max_v[2*k+1], smax_v[2*k+2]); | |
sum_a[k] = sum_a[2*k+2]; | |
sum_b[k] += sum_a[2*k+1]; | |
} else if(max_v[2*k+1] > max_v[2*k+2]) { | |
max_v[k] = max_v[2*k+1]; | |
max_c[k] = max_c[2*k+1]; | |
smax_v[k] = max(smax_v[2*k+1], max_v[2*k+2]); | |
sum_a[k] = sum_a[2*k+1]; | |
sum_b[k] += sum_a[2*k+2]; | |
} else { | |
max_v[k] = max_v[2*k+1]; | |
max_c[k] = max_c[2*k+1] + max_c[2*k+2]; | |
smax_v[k] = max(smax_v[2*k+1], smax_v[2*k+2]); | |
sum_a[k] = sum_a[2*k+1] + sum_a[2*k+2]; | |
} | |
} | |
void _update_min(ll x, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a || max_v[k] <= x) { | |
return; | |
} | |
if(a <= l && r <= b && smax_v[k] < x) { | |
update_node_max(k, x); | |
return; | |
} | |
push(k); | |
_update_min(x, a, b, 2*k+1, l, (l+r)/2); | |
_update_min(x, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
ll _query_max(int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return 0; | |
} | |
if(a <= l && r <= b) { | |
return max_v[k]; | |
} | |
push(k); | |
ll lv = _query_max(a, b, 2*k+1, l, (l+r)/2); | |
ll rv = _query_max(a, b, 2*k+2, (l+r)/2, r); | |
return max(lv, rv); | |
} | |
ll _query_sum(int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return 0; | |
} | |
if(a <= l && r <= b) { | |
// 最大値の総和・非最大値の総和を合わせる | |
return sum_a[k] + sum_b[k]; | |
} | |
push(k); | |
ll lv = _query_sum(a, b, 2*k+1, l, (l+r)/2); | |
ll rv = _query_sum(a, b, 2*k+2, (l+r)/2, r); | |
return lv + rv; | |
} | |
public: | |
SegmentTree(int n, ll *a) : n(n) { | |
n0 = 1; | |
while(n0 < n) n0 <<= 1; | |
len[0] = n0; | |
for(int i=0; i<n0-1; ++i) len[2*i+1] = len[2*i+2] = (len[i] >> 1); | |
for(int i=0; i<n; ++i) { | |
max_v[n0-1+i] = sum_a[n0-1+i] = a[i]; | |
smax_v[n0-1+i] = -inf; sum_b[n0-1+i] = 0; | |
max_c[n0-1+i] = 1; | |
} | |
for(int i=n; i<n0; ++i) { | |
max_v[n0-1+i] = smax_v[n0-1+i] = -inf; | |
sum_a[n0-1+i] = sum_b[n0-1+i] = max_c[n0-1+i] = 0; | |
} | |
for(int i=n0-2; i>=0; i--) update(i); | |
} | |
void update_min(int a, int b, ll x) { | |
return _update_min(x, a, b, 0, 0, n0); | |
} | |
ll query_max(int a, int b) { | |
return _query_max(a, b, 0, 0, n0); | |
} | |
ll query_sum(int a, int b) { | |
return _query_sum(a, b, 0, 0, n0); | |
} | |
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include<algorithm> | |
using namespace std; | |
using ll = long long; | |
// Segment Tree Beats (Task 3) | |
// - l<=i<r について a_i の値を min(a_i, x) に更新 | |
// - l<=i<r について a_i の値を max(a_i, x) に更新 | |
// - l<=i<r について a_i の値を a_i+x に更新 | |
// - l<=i<r の b_i の和を求める | |
// | |
// 各クエリで、a_i が変化した場合に b_i に1加算する | |
#define N 10003 | |
class SegmentTree { | |
const ll inf = 1e18; | |
int n, n0; | |
ll max_v[4*N], smax_v[4*N], max_c[4*N]; | |
ll min_v[4*N], smin_v[4*N], min_c[4*N]; | |
ll len[4*N]; // len[k]: ノードkの区間に含まれる要素数 | |
ll ladd[4*N]; // ladd[k]: RAQによって発生する、区間の a_i に加算するためのlazy tag | |
ll sum_b[4*N]; // b_i の区間総和 | |
ll min_a[4*N]; // a_i が最小値の b_i に加算するlazy tag (区間chmaxクエリ用) | |
ll max_a[4*N]; // a_i が最大値の b_i に加算するlazy tag (区間chminクエリ用) | |
ll all_a[4*N]; // 区間内の b_i 全てに加算するlazy tag (区間加算クエリ用) | |
void update_node_max(int k, ll x) { | |
if(max_v[k] == min_v[k]) { | |
max_v[k] = min_v[k] = x; | |
} else if(max_v[k] == smin_v[k]) { | |
max_v[k] = smin_v[k] = x; | |
} else { | |
max_v[k] = x; | |
} | |
} | |
void update_node_min(int k, ll x) { | |
if(max_v[k] == min_v[k]) { | |
max_v[k] = min_v[k] = x; | |
} else if(smax_v[k] == min_v[k]) { | |
min_v[k] = smax_v[k] = x; | |
} else { | |
min_v[k] = x; | |
} | |
} | |
// 親ノードpの最大値・最小値の状態を元に子ノードkの b_i の区間総和を更新 | |
// lazy valueの伝搬 | |
void add_b(int p, int k, ll mi, ll ma, ll aa) { | |
// 親ノードpの最小値とノードkの最小値が等しい場合のみ更新と伝搬を行う | |
if(p == -1 || min_v[p] == min_v[k]) { | |
// S_B への加算 | |
sum_b[k] += mi * min_c[k]; | |
min_a[k] += mi; | |
} | |
// 親ノードpの最大値とノードkの最大値が等しい場合のみ更新と伝搬を行う | |
if(p == -1 || max_v[p] == max_v[k]) { | |
// S_A への加算 | |
sum_b[k] += ma * max_c[k]; | |
max_a[k] += ma; | |
} | |
// 全体(S_A + S_B + S_C)に加算する値の伝搬 | |
sum_b[k] += aa * len[k]; | |
all_a[k] += aa; | |
} | |
void addall(int k, ll x) { | |
max_v[k] += x; smax_v[k] += x; | |
min_v[k] += x; smin_v[k] += x; | |
ladd[k] += x; | |
} | |
void push(int k) { | |
if(ladd[k]) { | |
addall(2*k+1, ladd[k]); | |
addall(2*k+2, ladd[k]); | |
ladd[k] = 0; | |
} | |
if(max_v[k] < max_v[2*k+1]) { | |
update_node_max(2*k+1, max_v[k]); | |
} | |
if(min_v[2*k+1] < min_v[k]) { | |
update_node_min(2*k+1, min_v[k]); | |
} | |
if(max_v[k] < max_v[2*k+2]) { | |
update_node_max(2*k+2, max_v[k]); | |
} | |
if(min_v[2*k+2] < min_v[k]) { | |
update_node_min(2*k+2, min_v[k]); | |
} | |
// 子ノードの最小値/最大値を更新した後に区間総和を更新 | |
if(min_a[k] != 0 || max_a[k] != 0 || all_a[k] != 0) { | |
add_b(k, 2*k+1, min_a[k], max_a[k], all_a[k]); | |
add_b(k, 2*k+2, min_a[k], max_a[k], all_a[k]); | |
min_a[k] = max_a[k] = all_a[k] = 0; | |
} | |
} | |
void update(int k) { | |
// b_i の区間総和の値を子ノードから計算し直す | |
sum_b[k] = sum_b[2*k+1] + sum_b[2*k+2]; | |
if(max_v[2*k+1] < max_v[2*k+2]) { | |
max_v[k] = max_v[2*k+2]; | |
max_c[k] = max_c[2*k+2]; | |
smax_v[k] = max(max_v[2*k+1], smax_v[2*k+2]); | |
} else if(max_v[2*k+1] > max_v[2*k+2]) { | |
max_v[k] = max_v[2*k+1]; | |
max_c[k] = max_c[2*k+1]; | |
smax_v[k] = max(smax_v[2*k+1], max_v[2*k+2]); | |
} else { | |
max_v[k] = max_v[2*k+1]; | |
max_c[k] = max_c[2*k+1] + max_c[2*k+2]; | |
smax_v[k] = max(smax_v[2*k+1], smax_v[2*k+2]); | |
} | |
if(min_v[2*k+1] < min_v[2*k+2]) { | |
min_v[k] = min_v[2*k+1]; | |
min_c[k] = min_c[2*k+1]; | |
smin_v[k] = min(smin_v[2*k+1], min_v[2*k+2]); | |
} else if(min_v[2*k+1] > min_v[2*k+2]) { | |
min_v[k] = min_v[2*k+2]; | |
min_c[k] = min_c[2*k+2]; | |
smin_v[k] = min(min_v[2*k+1], smin_v[2*k+2]); | |
} else { | |
min_v[k] = min_v[2*k+1]; | |
min_c[k] = min_c[2*k+1] + min_c[2*k+2]; | |
smin_v[k] = min(smin_v[2*k+1], smin_v[2*k+2]); | |
} | |
} | |
void _update_min(ll x, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a || max_v[k] <= x) { | |
return; | |
} | |
if(a <= l && r <= b && smax_v[k] < x) { | |
add_b(-1, k, 0, 1, 0); | |
update_node_max(k, x); | |
return; | |
} | |
push(k); | |
_update_min(x, a, b, 2*k+1, l, (l+r)/2); | |
_update_min(x, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
void _update_max(ll x, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a || x <= min_v[k]) { | |
return; | |
} | |
if(a <= l && r <= b && x < smin_v[k]) { | |
add_b(-1, k, 1, 0, 0); | |
update_node_min(k, x); | |
return; | |
} | |
push(k); | |
_update_max(x, a, b, 2*k+1, l, (l+r)/2); | |
_update_max(x, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
void _add_val(ll x, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return; | |
} | |
if(a <= l && r <= b) { | |
add_b(-1, k, 0, 0, 1); | |
addall(k, x); | |
return; | |
} | |
push(k); | |
_add_val(x, a, b, 2*k+1, l, (l+r)/2); | |
_add_val(x, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
ll _query_sum(int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return 0; | |
} | |
if(a <= l && r <= b) { | |
return sum_b[k]; | |
} | |
push(k); | |
ll lv = _query_sum(a, b, 2*k+1, l, (l+r)/2); | |
ll rv = _query_sum(a, b, 2*k+2, (l+r)/2, r); | |
return lv + rv; | |
} | |
public: | |
SegmentTree(int n, ll *a) : n(n) { | |
n0 = 1; | |
while(n0 < n) n0 <<= 1; | |
len[0] = n0; | |
for(int i=0; i<n0-1; ++i) len[2*i+1] = len[2*i+2] = (len[i] >> 1); | |
for(int i=0; i<2*n0-1; ++i) { | |
min_a[i] = max_a[i] = all_a[i] = 0; | |
sum_b[i] = 0; | |
ladd[i] = 0; | |
} | |
for(int i=0; i<n; ++i) { | |
max_v[n0-1+i] = min_v[n0-1+i] = a[i]; | |
smax_v[n0-1+i] = -inf; | |
smin_v[n0-1+i] = inf; | |
max_c[n0-1+i] = min_c[n0-1+i] = 1; | |
} | |
for(int i=n; i<n0; ++i) { | |
max_v[n0-1+i] = smax_v[n0-1+i] = -inf; | |
min_v[n0-1+i] = smin_v[n0-1+i] = inf; | |
max_c[n0-1+i] = min_c[n0-1+i] = 0; | |
} | |
for(int i=n0-2; i>=0; i--) update(i); | |
} | |
// 区間[a, b) について a_i の値を min(a_i, x) に更新 | |
void update_min(int a, int b, ll x) { | |
return _update_min(x, a, b, 0, 0, n0); | |
} | |
// 区間[a, b) について a_i の値を max(a_i, x) に更新 | |
void update_max(int a, int b, ll x) { | |
return _update_max(x, a, b, 0, 0, n0); | |
} | |
// 区間[a, b) について a_i の値を a_i+x に更新 | |
void add_val(int a, int b, ll x) { | |
if(!x) return; | |
_add_val(x, a, b, 0, 0, n0); | |
} | |
// 区間[a, b) の b_i の総和を求める | |
ll query_sum(int a, int b) { | |
return _query_sum(a, b, 0, 0, n0); | |
} | |
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include<algorithm> | |
using namespace std; | |
using ll = long long; | |
// Segment Tree Beats | |
// - l<=i<r について A_i の値を min(A_i, x) に更新 | |
// - l<=i<r について B_i の値を x に更新する | |
// - l<=i<r の中の A_i + B_i の最大値を求める | |
#define N 10003 | |
class SegmentTree { | |
const static ll inf = 1e18; | |
struct Pair { | |
ll a, b; | |
ll sum() const { return a + b; } | |
// Pair同士を比較する時は a_i + b_i の値が大きい方を大きいとする | |
bool operator<(const Pair &other) const { | |
return this->sum() < other.sum(); | |
} | |
}; | |
int n, n0; | |
// a_i の最大値、二番目の最大値 | |
ll max_v[4*N], smax_v[4*N]; | |
// 遅延させている b_i の更新値 | |
ll lval[4*N]; | |
// a_i が最大値となる中で a_i + b_i が最大値となるペア P_A = (a_i, b_i) | |
Pair max_p[4*N]; | |
// a_i が非最大値となる中で a_i + b_i が最大値となるペア P_B = (a_i, b_i) | |
Pair nmax_p[4*N]; | |
void update_node_max(int k, ll x) { | |
// ペアP_A の a_i の値を x に更新 | |
// ペアP_B の a_i は最大値未満であるため更新は必要ない | |
max_p[k].a = x; | |
max_v[k] = x; | |
} | |
void update_val(int k, ll x) { | |
// ペアP_A, P_Bに含まれる b_i の値を x に更新 | |
max_p[k].b = x; | |
if(nmax_p[k].b != -inf) { | |
nmax_p[k].a = smax_v[k]; | |
nmax_p[k].b = x; | |
} | |
lval[k] = x; | |
} | |
// 親ノードから子ノードに情報を伝搬 | |
void push(int k) { | |
// a_i の更新 | |
if(max_v[k] < max_v[2*k+1]) { | |
update_node_max(2*k+1, max_v[k]); | |
} | |
if(max_v[k] < max_v[2*k+2]) { | |
update_node_max(2*k+2, max_v[k]); | |
} | |
// b_i の更新 | |
if(lval[k] != inf) { | |
update_val(2*k+1, lval[k]); | |
update_val(2*k+2, lval[k]); | |
lval[k] = inf; | |
} | |
} | |
// 子ノードから親ノードの情報を更新 | |
void update(int k) { | |
// 2*k+1, 2*k+2 の非最大値のうちの最大値を 非最大値として更新 | |
nmax_p[k] = max(nmax_p[2*k+1], nmax_p[2*k+2]); | |
if(max_v[2*k+1] < max_v[2*k+2]) { | |
max_v[k] = max_v[2*k+2]; | |
smax_v[k] = max(max_v[2*k+1], smax_v[2*k+2]); | |
// 2*k+2 は最大値、2*k+1 は非最大値 として更新 | |
max_p[k] = max_p[2*k+2]; | |
nmax_p[k] = max(nmax_p[k], max_p[2*k+1]); | |
} else if(max_v[2*k+1] > max_v[2*k+2]) { | |
max_v[k] = max_v[2*k+1]; | |
smax_v[k] = max(smax_v[2*k+1], max_v[2*k+2]); | |
// 2*k+1 は最大値、2*k+2 は非最大値 として更新 | |
max_p[k] = max_p[2*k+1]; | |
nmax_p[k] = max(nmax_p[k], max_p[2*k+2]); | |
} else { | |
max_v[k] = max_v[2*k+1]; | |
smax_v[k] = max(smax_v[2*k+1], smax_v[2*k+2]); | |
// 2*k+1, 2*k+2 共に最大値 として更新 | |
max_p[k] = max(max_p[2*k+1], max_p[2*k+2]); | |
} | |
} | |
// 区間[a, b)に含まれる a_i の値を min(a_i, x) に更新 | |
void _update_min(ll x, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a || max_v[k] <= x) { | |
return; | |
} | |
if(a <= l && r <= b && smax_v[k] < x) { | |
update_node_max(k, x); | |
return; | |
} | |
push(k); | |
_update_min(x, a, b, 2*k+1, l, (l+r)/2); | |
_update_min(x, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
// 区間[a, b)に含まれる b_i の値を x に更新 | |
void _update_val(ll x, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return; | |
} | |
if(a <= l && r <= b) { | |
update_val(k, x); | |
return; | |
} | |
push(k); | |
_update_val(x, a, b, 2*k+1, l, (l+r)/2); | |
_update_val(x, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
// 区間[a, b)の中の a_i + b_i の最大値を求める | |
ll _query_max(int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return -inf; | |
} | |
if(a <= l && r <= b) { | |
// ペアP_A, P_Bのうち、a_i + b_i が最大な方の値を返す | |
return max(max_p[k], nmax_p[k]).sum(); | |
} | |
push(k); | |
ll lv = _query_max(a, b, 2*k+1, l, (l+r)/2); | |
ll rv = _query_max(a, b, 2*k+2, (l+r)/2, r); | |
return max(lv, rv); | |
} | |
public: | |
SegmentTree(int n, ll *a, ll *b) : n(n) { | |
n0 = 1; | |
while(n0 < n) n0 <<= 1; | |
for(int i=0; i<2*n0-1; ++i) lval[i] = inf; | |
for(int i=0; i<n; ++i) { | |
max_v[n0-1+i] = a[i]; | |
smax_v[n0-1+i] = -inf; | |
// [i:i+1)の最大値ペアは (a_i, b_i) | |
max_p[n0-1+i] = Pair{a[i], b[i]}; | |
// [i:i+1)の2番目の最大値ペアは (-∞, -∞) | |
nmax_p[n0-1+i] = Pair{-inf, -inf}; | |
} | |
for(int i=n; i<n0; ++i) { | |
max_v[n0-1+i] = smax_v[n0-1+i] = -inf; | |
max_p[n0-1+i] = nmax_p[n0-1+i] = Pair{-inf, -inf}; | |
} | |
for(int i=n0-2; i>=0; i--) update(i); | |
} | |
// 区間[a, b) について a_i の値を min(a_i, x) に更新 | |
void update_min(int a, int b, ll x) { | |
return _update_min(x, a, b, 0, 0, n0); | |
} | |
// 区間[a, b) について b_i の値を x に更新 | |
void update_val(int a, int b, ll x) { | |
_update_val(x, a, b, 0, 0, n0); | |
} | |
// 区間[a, b) の中で max(a_i + b_i) を求める | |
ll query_max(int a, int b) { | |
return _query_max(a, b, 0, 0, n0); | |
} | |
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include<algorithm> | |
using namespace std; | |
using ll = long long; | |
// Segment Tree Beats (Task 4) | |
// - l<=i<r について、 A_i の値を min(A_i, x) に更新 | |
// - l<=i<r について、 B_i の値を min(B_i, x) に更新 | |
// - l<=i<r について、 A_i の値を A_i + x に更新 | |
// - l<=i<r について、 B_i の値を B_i + x に更新 | |
// - l<=i<r の A_i + B_i の最大値を求める | |
#define N 10003 | |
class SegmentTree { | |
const static ll inf = 1e18; | |
struct Pair { | |
ll a, b; | |
ll sum() const { return a+b; } | |
bool operator<(const Pair& other) const { | |
return this->sum() < other.sum(); | |
} | |
}; | |
int n, n0; | |
// == ノードkにおける4種類のペアの情報 | |
// max_p[k][1][1]: a_i と b_i が最大値となる中で a_i+b_i が最大となるペア P_{11} = (a_i, b_i) | |
// max_p[k][1][0]: a_i が最大値、 b_i が非最大値となる中で a_i+b_i が最大となるペア P_{10} = (a_i, b_i) | |
// max_p[k][0][1]: a_i が非最大値、 b_i が最大値となる中で a_i+b_i が最大となるペア P_{01} = (a_i, b_i) | |
// max_p[k][0][0]: a_i と b_i が非最大値となる中で a_i+b_i が最大となるペア P_{00} = (a_i, b_i) | |
Pair max_p[4*N][2][2]; | |
// 数列AとBに関する最大値、二番目の最大値 | |
ll max_va[4*N], smax_va[4*N]; | |
ll max_vb[4*N], smax_vb[4*N]; | |
// 区間加算クエリで 遅延させている加算値 | |
ll ladd_a[4*N], ladd_b[4*N]; | |
// ノードの a_i, b_i の区間最大値を親ノードから更新する | |
void update_node_max(int k, ll xa, ll xb) { | |
if(xa < max_va[k]) { | |
// ペア P_{11}, P_{10} の a_i の値を更新 | |
if(max_p[k][1][1].a != -inf) max_p[k][1][1].a = xa; | |
if(max_p[k][1][0].a != -inf) max_p[k][1][0].a = xa; | |
max_va[k] = xa; | |
} | |
if(xb < max_vb[k]) { | |
// ペア P_{11}, P_{01} の b_i の値を更新 | |
if(max_p[k][1][1].b != -inf) max_p[k][1][1].b = xb; | |
if(max_p[k][0][1].b != -inf) max_p[k][0][1].b = xb; | |
max_vb[k] = xb; | |
} | |
} | |
// 区間内の a_i に xa を加算し、区間内の b_i に xb を加算する | |
inline void addall(int k, ll xa, ll xb) { | |
if(xa != 0) { | |
// 4種類のペア全ての a_i の値に加算 | |
for(int i=0; i<2; ++i) for(int j=0; j<2; ++j) { | |
if(max_p[k][i][j].a != -inf) max_p[k][i][j].a += xa; | |
} | |
max_va[k] += xa; | |
if(smax_va[k] != -inf) smax_va[k] += xa; | |
ladd_a[k] += xa; | |
} | |
if(xb != 0) { | |
// 4種類のペア全ての b_i の値に加算 | |
for(int i=0; i<2; ++i) for(int j=0; j<2; ++j) { | |
if(max_p[k][i][j].b != -inf) max_p[k][i][j].b += xb; | |
} | |
max_vb[k] += xb; | |
if(smax_vb[k] != -inf) smax_vb[k] += xb; | |
ladd_b[k] += xb; | |
} | |
} | |
// pushdown: 親ノードから子ノードへの伝搬 | |
void push(int k) { | |
if(ladd_a[k] != 0 || ladd_b[k] != 0) { | |
addall(2*k+1, ladd_a[k], ladd_b[k]); | |
addall(2*k+2, ladd_a[k], ladd_b[k]); | |
ladd_a[k] = ladd_b[k] = 0; | |
} | |
update_node_max(2*k+1, max_va[k], max_vb[k]); | |
update_node_max(2*k+2, max_va[k], max_vb[k]); | |
} | |
// ノードk の区間最大値と2番目の区間最大値の情報を更新 | |
inline void _update_max_v(int k, ll *max_v, ll *smax_v) { | |
if(max_v[2*k+1] < max_v[2*k+2]) { | |
max_v[k] = max_v[2*k+2]; | |
smax_v[k] = max(max_v[2*k+1], smax_v[2*k+2]); | |
} else if(max_v[2*k+1] > max_v[2*k+2]) { | |
max_v[k] = max_v[2*k+1]; | |
smax_v[k] = max(smax_v[2*k+1], max_v[2*k+2]); | |
} else { | |
max_v[k] = max_v[2*k+1]; | |
smax_v[k] = max(smax_v[2*k+1], smax_v[2*k+2]); | |
} | |
} | |
// update: 子ノードから親ノードの情報を更新 | |
void update(int k) { | |
// a_i と b_i の区間最大値の情報を更新 | |
_update_max_v(k, max_va, smax_va); | |
_update_max_v(k, max_vb, smax_vb); | |
// 更新する前に、親ノードの4種類のペアを一度初期化 | |
for(int i=0; i<2; ++i) for(int j=0; j<2; ++j) { | |
max_p[k][i][j] = Pair{-inf, -inf}; | |
} | |
// 左子ノード2*k+1 から 親ノードk の情報を更新 | |
for(int i=0; i<2; ++i) for(int j=0; j<2; ++j) { | |
Pair &p = max_p[2*k+1][i][j]; | |
Pair &e = max_p[k][p.a == max_va[k]][p.b == max_vb[k]]; | |
e = max(e, p); | |
} | |
// 右子ノード2*k+2 から 親ノードk の情報を更新 | |
for(int i=0; i<2; ++i) for(int j=0; j<2; ++j) { | |
Pair &p = max_p[2*k+2][i][j]; | |
Pair &e = max_p[k][p.a == max_va[k]][p.b == max_vb[k]]; | |
e = max(e, p); | |
} | |
} | |
void _add_val(ll xa, ll xb, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return; | |
} | |
if(a <= l && r <= b) { | |
addall(k, xa, xb); | |
return; | |
} | |
push(k); | |
_add_val(xa, xb, a, b, 2*k+1, l, (l+r)/2); | |
_add_val(xa, xb, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
void _update_min_a(ll xa, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a || max_va[k] <= xa) { | |
return; | |
} | |
if(a <= l && r <= b && smax_va[k] < xa) { | |
update_node_max(k, xa, inf); | |
return; | |
} | |
push(k); | |
_update_min_a(xa, a, b, 2*k+1, l, (l+r)/2); | |
_update_min_a(xa, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
void _update_min_b(ll xb, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a || max_vb[k] <= xb) { | |
return; | |
} | |
if(a <= l && r <= b && smax_vb[k] < xb) { | |
update_node_max(k, inf, xb); | |
return; | |
} | |
push(k); | |
_update_min_b(xb, a, b, 2*k+1, l, (l+r)/2); | |
_update_min_b(xb, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
ll _query_max(int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return -inf; | |
} | |
if(a <= l && r <= b) { | |
auto &mp = max_p[k]; | |
// 4種類のペアのうち a_i + b_i の最大値を返す | |
return max(max(mp[0][0], mp[0][1]), max(mp[1][0], mp[1][1])).sum(); | |
} | |
push(k); | |
ll lv = _query_max(a, b, 2*k+1, l, (l+r)/2); | |
ll rv = _query_max(a, b, 2*k+2, (l+r)/2, r); | |
return max(lv, rv); | |
} | |
public: | |
SegmentTree(int n, ll *a, ll *b) : n(n) { | |
n0 = 1; | |
while(n0 < n) n0 <<= 1; | |
for(int i=0; i<n; ++i) { | |
ll va = (a != nullptr ? a[i] : 0); | |
ll vb = (b != nullptr ? b[i] : 0); | |
max_va[n0-1+i] = va; | |
max_vb[n0-1+i] = vb; | |
smax_va[n0-1+i] = smax_vb[n0-1+i] = -inf; | |
// 一度全てのペアを (-∞, -∞) で初期化 | |
for(int p=0; p<2; ++p) for(int q=0; q<2; ++q) { | |
max_p[n0-1+i][p][q] = Pair{-inf, -inf}; | |
} | |
// P_{11} のみ ペア(va, vb) を持つ | |
max_p[n0-1+i][1][1] = Pair{va, vb}; | |
} | |
for(int i=n; i<n0; ++i) { | |
max_va[n0-1+i] = smax_va[n0-1+i] = -inf; | |
max_vb[n0-1+i] = smax_vb[n0-1+i] = -inf; | |
for(int p=0; p<2; ++p) for(int q=0; q<2; ++q) { | |
max_p[n0-1+i][p][q] = Pair{-inf, -inf}; | |
} | |
} | |
for(int i=n0-2; i>=0; i--) update(i); | |
} | |
// 区間[a, b) について a_i の値を min(a_i, x) に更新 | |
void update_min_a(int a, int b, ll x) { | |
_update_min_a(x, a, b, 0, 0, n0); | |
} | |
// 区間[a, b) について b_i の値を min(b_i, x) に更新 | |
void update_min_b(int a, int b, ll x) { | |
_update_min_b(x, a, b, 0, 0, n0); | |
} | |
// 区間[a, b) について a_i の値を a_i + x に更新 | |
void add_val_a(int a, int b, ll x) { | |
_add_val(x, 0, a, b, 0, 0, n0); | |
} | |
// 区間[a, b) について b_i の値を b_i + x に更新 | |
void add_val_b(int a, int b, ll x) { | |
_add_val(0, x, a, b, 0, 0, n0); | |
} | |
// 区間[a, b) の中の max(a_i + b_i) の値を求める | |
ll query_max(int a, int b) { | |
return _query_max(a, b, 0, 0, n0); | |
} | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment