Skip to content

Instantly share code, notes, and snippets.

@tjkendev
Last active May 18, 2019 12:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tjkendev/0965bdcf9e3c1063a7f2e76005d285ca to your computer and use it in GitHub Desktop.
Save tjkendev/0965bdcf9e3c1063a7f2e76005d285ca to your computer and use it in GitHub Desktop.
Segment Tree Beatsの例題実装 (Task 3, Task 4まわり)
#include<algorithm>
using namespace std;
using ll = long long;
// Segment Tree Beats
// - l<=i<r について、 a_i の値を min(a_i, x) に更新
// - l<=i<r の中の a_i の最大値を求める
// - l<=i<r の a_i の和を求める
#define N 10003
class SegmentTree {
const ll inf = 1e18;
int n, n0;
ll max_v[4*N], smax_v[4*N], max_c[4*N];
ll ladd[4*N]; // 区間加算クエリにおけるlazy tag
ll len[4*N]; // ノードの区間に含まれる要素の数
ll sum_a[4*N]; // 最大値の総和
ll sum_b[4*N]; // 非最大値の総和
void update_node_max(int k, ll x) {
// 最大値の総和だけ加算
sum_a[k] += (x - max_v[k]) * max_c[k];
max_v[k] = x;
}
void addall(int k, ll x) {
max_v[k] += x; smax_v[k] += x;
// 最大値の総和・非最大値の総和それぞれに加算
sum_a[k] += x * max_c[k];
sum_b[k] += x * (len[k] - max_c[k]);
ladd[k] += x;
}
void push(int k) {
if(ladd[k]) {
addall(2*k+1, ladd[k]);
addall(2*k+2, ladd[k]);
ladd[k] = 0;
}
if(max_v[k] < max_v[2*k+1]) {
update_node_max(2*k+1, max_v[k]);
}
if(max_v[k] < max_v[2*k+2]) {
update_node_max(2*k+2, max_v[k]);
}
}
void update(int k) {
// 最大値の総和と非最大値の総和を子ノードから更新
sum_b[k] = sum_b[2*k+1] + sum_b[2*k+2];
if(max_v[2*k+1] < max_v[2*k+2]) {
max_v[k] = max_v[2*k+2];
max_c[k] = max_c[2*k+2];
smax_v[k] = max(max_v[2*k+1], smax_v[2*k+2]);
sum_a[k] = sum_a[2*k+2];
sum_b[k] += sum_a[2*k+1];
} else if(max_v[2*k+1] > max_v[2*k+2]) {
max_v[k] = max_v[2*k+1];
max_c[k] = max_c[2*k+1];
smax_v[k] = max(smax_v[2*k+1], max_v[2*k+2]);
sum_a[k] = sum_a[2*k+1];
sum_b[k] += sum_a[2*k+2];
} else {
max_v[k] = max_v[2*k+1];
max_c[k] = max_c[2*k+1] + max_c[2*k+2];
smax_v[k] = max(smax_v[2*k+1], smax_v[2*k+2]);
sum_a[k] = sum_a[2*k+1] + sum_a[2*k+2];
}
}
void _update_min(ll x, int a, int b, int k, int l, int r) {
if(b <= l || r <= a || max_v[k] <= x) {
return;
}
if(a <= l && r <= b && smax_v[k] < x) {
update_node_max(k, x);
return;
}
push(k);
_update_min(x, a, b, 2*k+1, l, (l+r)/2);
_update_min(x, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
ll _query_max(int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return 0;
}
if(a <= l && r <= b) {
return max_v[k];
}
push(k);
ll lv = _query_max(a, b, 2*k+1, l, (l+r)/2);
ll rv = _query_max(a, b, 2*k+2, (l+r)/2, r);
return max(lv, rv);
}
ll _query_sum(int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return 0;
}
if(a <= l && r <= b) {
// 最大値の総和・非最大値の総和を合わせる
return sum_a[k] + sum_b[k];
}
push(k);
ll lv = _query_sum(a, b, 2*k+1, l, (l+r)/2);
ll rv = _query_sum(a, b, 2*k+2, (l+r)/2, r);
return lv + rv;
}
public:
SegmentTree(int n, ll *a) : n(n) {
n0 = 1;
while(n0 < n) n0 <<= 1;
len[0] = n0;
for(int i=0; i<n0-1; ++i) len[2*i+1] = len[2*i+2] = (len[i] >> 1);
for(int i=0; i<n; ++i) {
max_v[n0-1+i] = sum_a[n0-1+i] = a[i];
smax_v[n0-1+i] = -inf; sum_b[n0-1+i] = 0;
max_c[n0-1+i] = 1;
}
for(int i=n; i<n0; ++i) {
max_v[n0-1+i] = smax_v[n0-1+i] = -inf;
sum_a[n0-1+i] = sum_b[n0-1+i] = max_c[n0-1+i] = 0;
}
for(int i=n0-2; i>=0; i--) update(i);
}
void update_min(int a, int b, ll x) {
return _update_min(x, a, b, 0, 0, n0);
}
ll query_max(int a, int b) {
return _query_max(a, b, 0, 0, n0);
}
ll query_sum(int a, int b) {
return _query_sum(a, b, 0, 0, n0);
}
};
#include<algorithm>
using namespace std;
using ll = long long;
// Segment Tree Beats (Task 3)
// - l<=i<r について a_i の値を min(a_i, x) に更新
// - l<=i<r について a_i の値を max(a_i, x) に更新
// - l<=i<r について a_i の値を a_i+x に更新
// - l<=i<r の b_i の和を求める
//
// 各クエリで、a_i が変化した場合に b_i に1加算する
#define N 10003
class SegmentTree {
const ll inf = 1e18;
int n, n0;
ll max_v[4*N], smax_v[4*N], max_c[4*N];
ll min_v[4*N], smin_v[4*N], min_c[4*N];
ll len[4*N]; // len[k]: ノードkの区間に含まれる要素数
ll ladd[4*N]; // ladd[k]: RAQによって発生する、区間の a_i に加算するためのlazy tag
ll sum_b[4*N]; // b_i の区間総和
ll min_a[4*N]; // a_i が最小値の b_i に加算するlazy tag (区間chmaxクエリ用)
ll max_a[4*N]; // a_i が最大値の b_i に加算するlazy tag (区間chminクエリ用)
ll all_a[4*N]; // 区間内の b_i 全てに加算するlazy tag (区間加算クエリ用)
void update_node_max(int k, ll x) {
if(max_v[k] == min_v[k]) {
max_v[k] = min_v[k] = x;
} else if(max_v[k] == smin_v[k]) {
max_v[k] = smin_v[k] = x;
} else {
max_v[k] = x;
}
}
void update_node_min(int k, ll x) {
if(max_v[k] == min_v[k]) {
max_v[k] = min_v[k] = x;
} else if(smax_v[k] == min_v[k]) {
min_v[k] = smax_v[k] = x;
} else {
min_v[k] = x;
}
}
// 親ノードpの最大値・最小値の状態を元に子ノードkの b_i の区間総和を更新
// lazy valueの伝搬
void add_b(int p, int k, ll mi, ll ma, ll aa) {
// 親ノードpの最小値とノードkの最小値が等しい場合のみ更新と伝搬を行う
if(p == -1 || min_v[p] == min_v[k]) {
// S_B への加算
sum_b[k] += mi * min_c[k];
min_a[k] += mi;
}
// 親ノードpの最大値とノードkの最大値が等しい場合のみ更新と伝搬を行う
if(p == -1 || max_v[p] == max_v[k]) {
// S_A への加算
sum_b[k] += ma * max_c[k];
max_a[k] += ma;
}
// 全体(S_A + S_B + S_C)に加算する値の伝搬
sum_b[k] += aa * len[k];
all_a[k] += aa;
}
void addall(int k, ll x) {
max_v[k] += x; smax_v[k] += x;
min_v[k] += x; smin_v[k] += x;
ladd[k] += x;
}
void push(int k) {
if(ladd[k]) {
addall(2*k+1, ladd[k]);
addall(2*k+2, ladd[k]);
ladd[k] = 0;
}
if(max_v[k] < max_v[2*k+1]) {
update_node_max(2*k+1, max_v[k]);
}
if(min_v[2*k+1] < min_v[k]) {
update_node_min(2*k+1, min_v[k]);
}
if(max_v[k] < max_v[2*k+2]) {
update_node_max(2*k+2, max_v[k]);
}
if(min_v[2*k+2] < min_v[k]) {
update_node_min(2*k+2, min_v[k]);
}
// 子ノードの最小値/最大値を更新した後に区間総和を更新
if(min_a[k] != 0 || max_a[k] != 0 || all_a[k] != 0) {
add_b(k, 2*k+1, min_a[k], max_a[k], all_a[k]);
add_b(k, 2*k+2, min_a[k], max_a[k], all_a[k]);
min_a[k] = max_a[k] = all_a[k] = 0;
}
}
void update(int k) {
// b_i の区間総和の値を子ノードから計算し直す
sum_b[k] = sum_b[2*k+1] + sum_b[2*k+2];
if(max_v[2*k+1] < max_v[2*k+2]) {
max_v[k] = max_v[2*k+2];
max_c[k] = max_c[2*k+2];
smax_v[k] = max(max_v[2*k+1], smax_v[2*k+2]);
} else if(max_v[2*k+1] > max_v[2*k+2]) {
max_v[k] = max_v[2*k+1];
max_c[k] = max_c[2*k+1];
smax_v[k] = max(smax_v[2*k+1], max_v[2*k+2]);
} else {
max_v[k] = max_v[2*k+1];
max_c[k] = max_c[2*k+1] + max_c[2*k+2];
smax_v[k] = max(smax_v[2*k+1], smax_v[2*k+2]);
}
if(min_v[2*k+1] < min_v[2*k+2]) {
min_v[k] = min_v[2*k+1];
min_c[k] = min_c[2*k+1];
smin_v[k] = min(smin_v[2*k+1], min_v[2*k+2]);
} else if(min_v[2*k+1] > min_v[2*k+2]) {
min_v[k] = min_v[2*k+2];
min_c[k] = min_c[2*k+2];
smin_v[k] = min(min_v[2*k+1], smin_v[2*k+2]);
} else {
min_v[k] = min_v[2*k+1];
min_c[k] = min_c[2*k+1] + min_c[2*k+2];
smin_v[k] = min(smin_v[2*k+1], smin_v[2*k+2]);
}
}
void _update_min(ll x, int a, int b, int k, int l, int r) {
if(b <= l || r <= a || max_v[k] <= x) {
return;
}
if(a <= l && r <= b && smax_v[k] < x) {
add_b(-1, k, 0, 1, 0);
update_node_max(k, x);
return;
}
push(k);
_update_min(x, a, b, 2*k+1, l, (l+r)/2);
_update_min(x, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
void _update_max(ll x, int a, int b, int k, int l, int r) {
if(b <= l || r <= a || x <= min_v[k]) {
return;
}
if(a <= l && r <= b && x < smin_v[k]) {
add_b(-1, k, 1, 0, 0);
update_node_min(k, x);
return;
}
push(k);
_update_max(x, a, b, 2*k+1, l, (l+r)/2);
_update_max(x, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
void _add_val(ll x, int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return;
}
if(a <= l && r <= b) {
add_b(-1, k, 0, 0, 1);
addall(k, x);
return;
}
push(k);
_add_val(x, a, b, 2*k+1, l, (l+r)/2);
_add_val(x, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
ll _query_sum(int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return 0;
}
if(a <= l && r <= b) {
return sum_b[k];
}
push(k);
ll lv = _query_sum(a, b, 2*k+1, l, (l+r)/2);
ll rv = _query_sum(a, b, 2*k+2, (l+r)/2, r);
return lv + rv;
}
public:
SegmentTree(int n, ll *a) : n(n) {
n0 = 1;
while(n0 < n) n0 <<= 1;
len[0] = n0;
for(int i=0; i<n0-1; ++i) len[2*i+1] = len[2*i+2] = (len[i] >> 1);
for(int i=0; i<2*n0-1; ++i) {
min_a[i] = max_a[i] = all_a[i] = 0;
sum_b[i] = 0;
ladd[i] = 0;
}
for(int i=0; i<n; ++i) {
max_v[n0-1+i] = min_v[n0-1+i] = a[i];
smax_v[n0-1+i] = -inf;
smin_v[n0-1+i] = inf;
max_c[n0-1+i] = min_c[n0-1+i] = 1;
}
for(int i=n; i<n0; ++i) {
max_v[n0-1+i] = smax_v[n0-1+i] = -inf;
min_v[n0-1+i] = smin_v[n0-1+i] = inf;
max_c[n0-1+i] = min_c[n0-1+i] = 0;
}
for(int i=n0-2; i>=0; i--) update(i);
}
// 区間[a, b) について a_i の値を min(a_i, x) に更新
void update_min(int a, int b, ll x) {
return _update_min(x, a, b, 0, 0, n0);
}
// 区間[a, b) について a_i の値を max(a_i, x) に更新
void update_max(int a, int b, ll x) {
return _update_max(x, a, b, 0, 0, n0);
}
// 区間[a, b) について a_i の値を a_i+x に更新
void add_val(int a, int b, ll x) {
if(!x) return;
_add_val(x, a, b, 0, 0, n0);
}
// 区間[a, b) の b_i の総和を求める
ll query_sum(int a, int b) {
return _query_sum(a, b, 0, 0, n0);
}
};
#include<algorithm>
using namespace std;
using ll = long long;
// Segment Tree Beats
// - l<=i<r について A_i の値を min(A_i, x) に更新
// - l<=i<r について B_i の値を x に更新する
// - l<=i<r の中の A_i + B_i の最大値を求める
#define N 10003
class SegmentTree {
const static ll inf = 1e18;
struct Pair {
ll a, b;
ll sum() const { return a + b; }
// Pair同士を比較する時は a_i + b_i の値が大きい方を大きいとする
bool operator<(const Pair &other) const {
return this->sum() < other.sum();
}
};
int n, n0;
// a_i の最大値、二番目の最大値
ll max_v[4*N], smax_v[4*N];
// 遅延させている b_i の更新値
ll lval[4*N];
// a_i が最大値となる中で a_i + b_i が最大値となるペア P_A = (a_i, b_i)
Pair max_p[4*N];
// a_i が非最大値となる中で a_i + b_i が最大値となるペア P_B = (a_i, b_i)
Pair nmax_p[4*N];
void update_node_max(int k, ll x) {
// ペアP_A の a_i の値を x に更新
// ペアP_B の a_i は最大値未満であるため更新は必要ない
max_p[k].a = x;
max_v[k] = x;
}
void update_val(int k, ll x) {
// ペアP_A, P_Bに含まれる b_i の値を x に更新
max_p[k].b = x;
if(nmax_p[k].b != -inf) {
nmax_p[k].a = smax_v[k];
nmax_p[k].b = x;
}
lval[k] = x;
}
// 親ノードから子ノードに情報を伝搬
void push(int k) {
// a_i の更新
if(max_v[k] < max_v[2*k+1]) {
update_node_max(2*k+1, max_v[k]);
}
if(max_v[k] < max_v[2*k+2]) {
update_node_max(2*k+2, max_v[k]);
}
// b_i の更新
if(lval[k] != inf) {
update_val(2*k+1, lval[k]);
update_val(2*k+2, lval[k]);
lval[k] = inf;
}
}
// 子ノードから親ノードの情報を更新
void update(int k) {
// 2*k+1, 2*k+2 の非最大値のうちの最大値を 非最大値として更新
nmax_p[k] = max(nmax_p[2*k+1], nmax_p[2*k+2]);
if(max_v[2*k+1] < max_v[2*k+2]) {
max_v[k] = max_v[2*k+2];
smax_v[k] = max(max_v[2*k+1], smax_v[2*k+2]);
// 2*k+2 は最大値、2*k+1 は非最大値 として更新
max_p[k] = max_p[2*k+2];
nmax_p[k] = max(nmax_p[k], max_p[2*k+1]);
} else if(max_v[2*k+1] > max_v[2*k+2]) {
max_v[k] = max_v[2*k+1];
smax_v[k] = max(smax_v[2*k+1], max_v[2*k+2]);
// 2*k+1 は最大値、2*k+2 は非最大値 として更新
max_p[k] = max_p[2*k+1];
nmax_p[k] = max(nmax_p[k], max_p[2*k+2]);
} else {
max_v[k] = max_v[2*k+1];
smax_v[k] = max(smax_v[2*k+1], smax_v[2*k+2]);
// 2*k+1, 2*k+2 共に最大値 として更新
max_p[k] = max(max_p[2*k+1], max_p[2*k+2]);
}
}
// 区間[a, b)に含まれる a_i の値を min(a_i, x) に更新
void _update_min(ll x, int a, int b, int k, int l, int r) {
if(b <= l || r <= a || max_v[k] <= x) {
return;
}
if(a <= l && r <= b && smax_v[k] < x) {
update_node_max(k, x);
return;
}
push(k);
_update_min(x, a, b, 2*k+1, l, (l+r)/2);
_update_min(x, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
// 区間[a, b)に含まれる b_i の値を x に更新
void _update_val(ll x, int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return;
}
if(a <= l && r <= b) {
update_val(k, x);
return;
}
push(k);
_update_val(x, a, b, 2*k+1, l, (l+r)/2);
_update_val(x, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
// 区間[a, b)の中の a_i + b_i の最大値を求める
ll _query_max(int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return -inf;
}
if(a <= l && r <= b) {
// ペアP_A, P_Bのうち、a_i + b_i が最大な方の値を返す
return max(max_p[k], nmax_p[k]).sum();
}
push(k);
ll lv = _query_max(a, b, 2*k+1, l, (l+r)/2);
ll rv = _query_max(a, b, 2*k+2, (l+r)/2, r);
return max(lv, rv);
}
public:
SegmentTree(int n, ll *a, ll *b) : n(n) {
n0 = 1;
while(n0 < n) n0 <<= 1;
for(int i=0; i<2*n0-1; ++i) lval[i] = inf;
for(int i=0; i<n; ++i) {
max_v[n0-1+i] = a[i];
smax_v[n0-1+i] = -inf;
// [i:i+1)の最大値ペアは (a_i, b_i)
max_p[n0-1+i] = Pair{a[i], b[i]};
// [i:i+1)の2番目の最大値ペアは (-∞, -∞)
nmax_p[n0-1+i] = Pair{-inf, -inf};
}
for(int i=n; i<n0; ++i) {
max_v[n0-1+i] = smax_v[n0-1+i] = -inf;
max_p[n0-1+i] = nmax_p[n0-1+i] = Pair{-inf, -inf};
}
for(int i=n0-2; i>=0; i--) update(i);
}
// 区間[a, b) について a_i の値を min(a_i, x) に更新
void update_min(int a, int b, ll x) {
return _update_min(x, a, b, 0, 0, n0);
}
// 区間[a, b) について b_i の値を x に更新
void update_val(int a, int b, ll x) {
_update_val(x, a, b, 0, 0, n0);
}
// 区間[a, b) の中で max(a_i + b_i) を求める
ll query_max(int a, int b) {
return _query_max(a, b, 0, 0, n0);
}
};
#include<algorithm>
using namespace std;
using ll = long long;
// Segment Tree Beats (Task 4)
// - l<=i<r について、 A_i の値を min(A_i, x) に更新
// - l<=i<r について、 B_i の値を min(B_i, x) に更新
// - l<=i<r について、 A_i の値を A_i + x に更新
// - l<=i<r について、 B_i の値を B_i + x に更新
// - l<=i<r の A_i + B_i の最大値を求める
#define N 10003
class SegmentTree {
const static ll inf = 1e18;
struct Pair {
ll a, b;
ll sum() const { return a+b; }
bool operator<(const Pair& other) const {
return this->sum() < other.sum();
}
};
int n, n0;
// == ノードkにおける4種類のペアの情報
// max_p[k][1][1]: a_i と b_i が最大値となる中で a_i+b_i が最大となるペア P_{11} = (a_i, b_i)
// max_p[k][1][0]: a_i が最大値、 b_i が非最大値となる中で a_i+b_i が最大となるペア P_{10} = (a_i, b_i)
// max_p[k][0][1]: a_i が非最大値、 b_i が最大値となる中で a_i+b_i が最大となるペア P_{01} = (a_i, b_i)
// max_p[k][0][0]: a_i と b_i が非最大値となる中で a_i+b_i が最大となるペア P_{00} = (a_i, b_i)
Pair max_p[4*N][2][2];
// 数列AとBに関する最大値、二番目の最大値
ll max_va[4*N], smax_va[4*N];
ll max_vb[4*N], smax_vb[4*N];
// 区間加算クエリで 遅延させている加算値
ll ladd_a[4*N], ladd_b[4*N];
// ノードの a_i, b_i の区間最大値を親ノードから更新する
void update_node_max(int k, ll xa, ll xb) {
if(xa < max_va[k]) {
// ペア P_{11}, P_{10} の a_i の値を更新
if(max_p[k][1][1].a != -inf) max_p[k][1][1].a = xa;
if(max_p[k][1][0].a != -inf) max_p[k][1][0].a = xa;
max_va[k] = xa;
}
if(xb < max_vb[k]) {
// ペア P_{11}, P_{01} の b_i の値を更新
if(max_p[k][1][1].b != -inf) max_p[k][1][1].b = xb;
if(max_p[k][0][1].b != -inf) max_p[k][0][1].b = xb;
max_vb[k] = xb;
}
}
// 区間内の a_i に xa を加算し、区間内の b_i に xb を加算する
inline void addall(int k, ll xa, ll xb) {
if(xa != 0) {
// 4種類のペア全ての a_i の値に加算
for(int i=0; i<2; ++i) for(int j=0; j<2; ++j) {
if(max_p[k][i][j].a != -inf) max_p[k][i][j].a += xa;
}
max_va[k] += xa;
if(smax_va[k] != -inf) smax_va[k] += xa;
ladd_a[k] += xa;
}
if(xb != 0) {
// 4種類のペア全ての b_i の値に加算
for(int i=0; i<2; ++i) for(int j=0; j<2; ++j) {
if(max_p[k][i][j].b != -inf) max_p[k][i][j].b += xb;
}
max_vb[k] += xb;
if(smax_vb[k] != -inf) smax_vb[k] += xb;
ladd_b[k] += xb;
}
}
// pushdown: 親ノードから子ノードへの伝搬
void push(int k) {
if(ladd_a[k] != 0 || ladd_b[k] != 0) {
addall(2*k+1, ladd_a[k], ladd_b[k]);
addall(2*k+2, ladd_a[k], ladd_b[k]);
ladd_a[k] = ladd_b[k] = 0;
}
update_node_max(2*k+1, max_va[k], max_vb[k]);
update_node_max(2*k+2, max_va[k], max_vb[k]);
}
// ノードk の区間最大値と2番目の区間最大値の情報を更新
inline void _update_max_v(int k, ll *max_v, ll *smax_v) {
if(max_v[2*k+1] < max_v[2*k+2]) {
max_v[k] = max_v[2*k+2];
smax_v[k] = max(max_v[2*k+1], smax_v[2*k+2]);
} else if(max_v[2*k+1] > max_v[2*k+2]) {
max_v[k] = max_v[2*k+1];
smax_v[k] = max(smax_v[2*k+1], max_v[2*k+2]);
} else {
max_v[k] = max_v[2*k+1];
smax_v[k] = max(smax_v[2*k+1], smax_v[2*k+2]);
}
}
// update: 子ノードから親ノードの情報を更新
void update(int k) {
// a_i と b_i の区間最大値の情報を更新
_update_max_v(k, max_va, smax_va);
_update_max_v(k, max_vb, smax_vb);
// 更新する前に、親ノードの4種類のペアを一度初期化
for(int i=0; i<2; ++i) for(int j=0; j<2; ++j) {
max_p[k][i][j] = Pair{-inf, -inf};
}
// 左子ノード2*k+1 から 親ノードk の情報を更新
for(int i=0; i<2; ++i) for(int j=0; j<2; ++j) {
Pair &p = max_p[2*k+1][i][j];
Pair &e = max_p[k][p.a == max_va[k]][p.b == max_vb[k]];
e = max(e, p);
}
// 右子ノード2*k+2 から 親ノードk の情報を更新
for(int i=0; i<2; ++i) for(int j=0; j<2; ++j) {
Pair &p = max_p[2*k+2][i][j];
Pair &e = max_p[k][p.a == max_va[k]][p.b == max_vb[k]];
e = max(e, p);
}
}
void _add_val(ll xa, ll xb, int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return;
}
if(a <= l && r <= b) {
addall(k, xa, xb);
return;
}
push(k);
_add_val(xa, xb, a, b, 2*k+1, l, (l+r)/2);
_add_val(xa, xb, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
void _update_min_a(ll xa, int a, int b, int k, int l, int r) {
if(b <= l || r <= a || max_va[k] <= xa) {
return;
}
if(a <= l && r <= b && smax_va[k] < xa) {
update_node_max(k, xa, inf);
return;
}
push(k);
_update_min_a(xa, a, b, 2*k+1, l, (l+r)/2);
_update_min_a(xa, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
void _update_min_b(ll xb, int a, int b, int k, int l, int r) {
if(b <= l || r <= a || max_vb[k] <= xb) {
return;
}
if(a <= l && r <= b && smax_vb[k] < xb) {
update_node_max(k, inf, xb);
return;
}
push(k);
_update_min_b(xb, a, b, 2*k+1, l, (l+r)/2);
_update_min_b(xb, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
ll _query_max(int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return -inf;
}
if(a <= l && r <= b) {
auto &mp = max_p[k];
// 4種類のペアのうち a_i + b_i の最大値を返す
return max(max(mp[0][0], mp[0][1]), max(mp[1][0], mp[1][1])).sum();
}
push(k);
ll lv = _query_max(a, b, 2*k+1, l, (l+r)/2);
ll rv = _query_max(a, b, 2*k+2, (l+r)/2, r);
return max(lv, rv);
}
public:
SegmentTree(int n, ll *a, ll *b) : n(n) {
n0 = 1;
while(n0 < n) n0 <<= 1;
for(int i=0; i<n; ++i) {
ll va = (a != nullptr ? a[i] : 0);
ll vb = (b != nullptr ? b[i] : 0);
max_va[n0-1+i] = va;
max_vb[n0-1+i] = vb;
smax_va[n0-1+i] = smax_vb[n0-1+i] = -inf;
// 一度全てのペアを (-∞, -∞) で初期化
for(int p=0; p<2; ++p) for(int q=0; q<2; ++q) {
max_p[n0-1+i][p][q] = Pair{-inf, -inf};
}
// P_{11} のみ ペア(va, vb) を持つ
max_p[n0-1+i][1][1] = Pair{va, vb};
}
for(int i=n; i<n0; ++i) {
max_va[n0-1+i] = smax_va[n0-1+i] = -inf;
max_vb[n0-1+i] = smax_vb[n0-1+i] = -inf;
for(int p=0; p<2; ++p) for(int q=0; q<2; ++q) {
max_p[n0-1+i][p][q] = Pair{-inf, -inf};
}
}
for(int i=n0-2; i>=0; i--) update(i);
}
// 区間[a, b) について a_i の値を min(a_i, x) に更新
void update_min_a(int a, int b, ll x) {
_update_min_a(x, a, b, 0, 0, n0);
}
// 区間[a, b) について b_i の値を min(b_i, x) に更新
void update_min_b(int a, int b, ll x) {
_update_min_b(x, a, b, 0, 0, n0);
}
// 区間[a, b) について a_i の値を a_i + x に更新
void add_val_a(int a, int b, ll x) {
_add_val(x, 0, a, b, 0, 0, n0);
}
// 区間[a, b) について b_i の値を b_i + x に更新
void add_val_b(int a, int b, ll x) {
_add_val(0, x, a, b, 0, 0, n0);
}
// 区間[a, b) の中の max(a_i + b_i) の値を求める
ll query_max(int a, int b) {
return _query_max(a, b, 0, 0, n0);
}
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment