Last active
June 5, 2019 16:35
-
-
Save tjkendev/ad154c067a02648f2c2d8194abe6022d to your computer and use it in GitHub Desktop.
Segment Tree Beats (Historic Informationまわり) の実装
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include<algorithm> | |
using namespace std; | |
using ll = long long; | |
// Segment Tree Beats (Historic Information) | |
// - l<=i<r について a_i の値に x を加える | |
// - l<=i<r の中の a_i の最大値を求める | |
// - l<=i<r の中の b_i の総和を求める | |
// - l<=i<r の中の b_i の最大値を求める | |
// - (各クエリ後、全てのiについて b_i = max(a_i, b_i)) | |
#define N 10003 | |
class SegmentTree { | |
static const ll inf = 1e18; | |
int n0; | |
// cur_s : a_i の区間総和 | |
// cur_ma: a_i の区間最大値 | |
ll cur_s[4*N], cur_ma[4*N]; | |
// 区間加算クエリ用 (ladd: lazy tag, len: 区間要素数) | |
ll ladd[4*N], len[4*N]; | |
// d_i の最小値まわりの情報を持つ構造体 | |
struct HistVal { | |
// min_v: d_i の最小値 | |
// smin_v: 二番目の最小値 | |
// min_c: 最小値の個数 | |
// sum: d_iの総和 | |
ll min_v, smin_v, min_c, sum; | |
// m_hmax: d_i が最小値の a_i + d_i の最大値 | |
// nm_hmax: d_i が非最小値の a_i + d_i の最大値 | |
ll m_hmax, nm_hmax; | |
// a_i の初期値が x の初期化処理 | |
void init(ll x) { | |
min_v = 0; smin_v = inf; min_c = 1; sum = 0; | |
m_hmax = x; nm_hmax = -inf; | |
} | |
// a_i を使わない場合の初期化処理 | |
void init_empty() { | |
min_v = smin_v = inf; min_c = 0; sum = 0; | |
m_hmax = nm_hmax = -inf; | |
} | |
// d_i の最小値を x に更新 | |
inline void update_min(ll x) { | |
if(min_v < x) { | |
sum += (x - min_v) * min_c; | |
// a_i + d_i の最大値はここのみで変化する | |
m_hmax += (x - min_v); | |
min_v = x; | |
} | |
} | |
// l と r の情報をマージして更新 | |
// (片方に自身を指定して使っても問題ないようになっている) | |
inline void merge(HistVal &l, HistVal &r) { | |
sum = l.sum + r.sum; | |
nm_hmax = max(l.nm_hmax, r.nm_hmax); | |
if(l.min_v < r.min_v) { | |
smin_v = min(l.smin_v, r.min_v); | |
min_v = l.min_v; | |
min_c = l.min_c; | |
nm_hmax = max(nm_hmax, r.m_hmax); | |
m_hmax = l.m_hmax; | |
} else if(l.min_v > r.min_v) { | |
smin_v = min(l.min_v, r.smin_v); | |
min_v = r.min_v; | |
min_c = r.min_c; | |
nm_hmax = max(nm_hmax, l.m_hmax); | |
m_hmax = r.m_hmax; | |
} else { | |
min_v = l.min_v; | |
smin_v = min(l.smin_v, r.smin_v); | |
min_c = l.min_c + r.min_c; | |
m_hmax = max(l.m_hmax, r.m_hmax); | |
} | |
} | |
// 最小値・二番目の最小値への加算処理 | |
void add(ll x) { | |
if(min_v != inf) min_v += x; | |
if(smin_v != inf) smin_v += x; | |
} | |
// a_i + d_i の最大値を返す | |
ll hmax() const { | |
return max(m_hmax, nm_hmax); | |
} | |
}; | |
// 各ノードごとに d_i の最小値情報を持つ | |
HistVal val_d[4*N]; | |
// ノードk に含まれる各 a_i に a を加算 | |
void addall(int k, ll a) { | |
// a_i の区間総和に加算 | |
cur_s[k] += a * len[k]; | |
cur_ma[k] += a; | |
// a_i + d_i の最大値に加算されるが、この後相殺されるため加算は不要 | |
//val_d[k].m_hmax += a; val_d[k].nm_hmax += a; | |
// d_i の最小値・二番目の最小値と区間総和に加算 | |
val_d[k].add(-a); | |
val_d[k].sum -= a * len[k]; | |
//val_d[k].m_hmax -= a; val_d[k].nm_hmax -= a; | |
ladd[k] += a; | |
} | |
void push(int k) { | |
if(k >= n0-1) return; | |
if(ladd[k] != 0) { | |
addall(2*k+1, ladd[k]); | |
addall(2*k+2, ladd[k]); | |
ladd[k] = 0; | |
} | |
val_d[2*k+1].update_min(val_d[k].min_v); | |
val_d[2*k+2].update_min(val_d[k].min_v); | |
} | |
void update(int k) { | |
cur_s[k] = cur_s[2*k+1] + cur_s[2*k+2]; | |
cur_ma[k] = max(cur_ma[2*k+1], cur_ma[2*k+2]); | |
val_d[k].merge(val_d[2*k+1], val_d[2*k+2]); | |
} | |
// (内部用) d_i <- max(d_i, 0) で更新 | |
void _update_dmax(int k, int l, int r) { | |
// break condition: d_i の最小値(val_d[k].min_v)が 0 以上 | |
if(l == r || 0 <= val_d[k].min_v) { | |
return; | |
} | |
// tag condition: d_i の二番目の最小値(val_d[k].smin_v)が 0 より大きい | |
if(0 < val_d[k].smin_v) { | |
// d_i の最小値を 0 に更新 | |
val_d[k].update_min(0); | |
return; | |
} | |
push(k); | |
_update_dmax(2*k+1, l, (l+r)/2); | |
_update_dmax(2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
// 区間[a, b) の a_i に x を加算する | |
void _add_val(ll x, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return; | |
} | |
if(a <= l && r <= b) { | |
// a_i に x を加算 | |
addall(k, x); | |
// 区間[l, r) の d_i を max(d_i, 0) に更新 | |
_update_dmax(k, l, r); | |
return; | |
} | |
push(k); | |
_add_val(x, a, b, 2*k+1, l, (l+r)/2); | |
_add_val(x, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
// 区間[a, b) における a_i の区間最大値を返す | |
ll _query_max(int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return -inf; | |
} | |
if(a <= l && r <= b) { | |
return cur_ma[k]; | |
} | |
push(k); | |
ll lv = _query_max(a, b, 2*k+1, l, (l+r)/2); | |
ll rv = _query_max(a, b, 2*k+2, (l+r)/2, r); | |
return max(lv, rv); | |
} | |
// 区間[a, b) における historic maximal value の区間最大値を返す | |
ll _query_hmax_max(int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return -inf; | |
} | |
if(a <= l && r <= b) { | |
return val_d[k].hmax(); | |
} | |
push(k); | |
ll lv = _query_hmax_max(a, b, 2*k+1, l, (l+r)/2); | |
ll rv = _query_hmax_max(a, b, 2*k+2, (l+r)/2, r); | |
return max(lv, rv); | |
} | |
// 区間[a, b) における historic maximal value の区間総和を返す | |
ll _query_hmax_sum(int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return 0; | |
} | |
if(a <= l && r <= b) { | |
return cur_s[k] + val_d[k].sum; | |
} | |
push(k); | |
ll lv = _query_hmax_sum(a, b, 2*k+1, l, (l+r)/2); | |
ll rv = _query_hmax_sum(a, b, 2*k+2, (l+r)/2, r); | |
return lv + rv; | |
} | |
public: | |
SegmentTree(int n, ll *a) { | |
n0 = 1; | |
while(n0 < n) n0 <<= 1; | |
len[0] = n0; | |
for(int i=0; i<2*n0-1; ++i) len[2*i+1] = len[2*i+2] = (len[i] >> 1); | |
for(int i=0; i<n; ++i) { | |
cur_s[n0-1+i] = cur_ma[n0-1+i] = a[i]; | |
val_d[n0-1+i].init(a[i]); | |
} | |
for(int i=n; i<n0; ++i) { | |
cur_s[n0-1+i] = cur_ma[n0-1+i] = 0; | |
val_d[n0-1+i].init_empty(); | |
} | |
for(int i=n0-2; i>=0; i--) update(i); | |
} | |
// l<=i<r について a_i の値を a_i + x に更新 | |
void add_val(int a, int b, ll x) { | |
_add_val(x, a, b, 0, 0, n0); | |
} | |
// l<=i<r の中の b_i の区間総和 | |
ll query_hmax_sum(int a, int b) { | |
return _query_hmax_sum(a, b, 0, 0, n0); | |
} | |
// l<=i<r の中の b_i の区間最大値 | |
ll query_hmax_max(int a, int b) { | |
return _query_hmax_max(a, b, 0, 0, n0); | |
} | |
// l<=i<r の中の a_i の区間最大値 | |
ll query_max(int a, int b) { | |
return _query_max(a, b, 0, 0, n0); | |
} | |
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using namespace std; | |
using ll = long long; | |
// Segment Tree with Lazy Propagation | |
// - l<=i<r について a_i を a_i + x に更新 | |
// - l<=i<r の s_i の区間総和を求める | |
#define N 10004 | |
class SegmentTree { | |
ll n0; | |
ll va[4*N], vb[4*N]; | |
ll ladd_a[4*N], ladd_b[4*N]; | |
ll len[4*N]; | |
void addall(int k, ll x, ll s) { | |
va[k] += x * len[k]; | |
vb[k] += s * len[k]; | |
ladd_a[k] += x; | |
ladd_b[k] += s; | |
} | |
void push(int k) { | |
addall(2*k+1, ladd_a[k], ladd_b[k]); | |
addall(2*k+2, ladd_a[k], ladd_b[k]); | |
ladd_a[k] = ladd_b[k] = 0; | |
} | |
void update(int k) { | |
va[k] = va[2*k+1] + va[2*k+2]; | |
vb[k] = vb[2*k+1] + vb[2*k+2]; | |
} | |
void _add_val(ll x, ll t, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return; | |
} | |
if(a <= l && r <= b) { | |
addall(k, x, -t*x); | |
return; | |
} | |
push(k); | |
_add_val(x, t, a, b, 2*k+1, l, (l+r)/2); | |
_add_val(x, t, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
ll _query_sum(ll t, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return 0; | |
} | |
if(a <= l && r <= b) { | |
return va[k] * t + vb[k]; | |
} | |
push(k); | |
ll lv = _query_sum(t, a, b, 2*k+1, l, (l+r)/2); | |
ll rv = _query_sum(t, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
return lv + rv; | |
} | |
public: | |
SegmentTree(int n, ll *a) { | |
n0 = 1; | |
while(n0 < n) n0 <<= 1; | |
len[0] = n0; | |
for(int i=0; i<n0-1; ++i) len[2*i+1] = len[2*i+2] = (len[i] >> 1); | |
for(int i=0; i<n; ++i) { | |
va[n0-1+i] = vb[n0-1+i] = a[i]; | |
} | |
for(int i=n; i<n0; ++i) { | |
va[n0-1+i] = vb[n0-1+i] = 0; | |
} | |
for(int i=n0-2; i>=0; --i) update(i); | |
} | |
void add_val(int a, int b, ll x, ll t) { | |
_add_val(x, t, a, b, 0, 0, n0); | |
} | |
ll query_sum(int a, int b, ll t) { | |
return _query_sum(t, a, b, 0, 0, n0); | |
} | |
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include<algorithm> | |
using namespace std; | |
using ll = long long; | |
// Segment Tree Beats (Historic Information) | |
// - i<=i<r について a_i の値を max(a_i, x) に更新 | |
// - l<=i<r について a_i の値に x を加える | |
// - l<=i<r の中の a_i の最小値を求める | |
// - l<=i<r の中の b_i の総和を求める | |
// - l<=i<r の中の b_i の最大値を求める | |
// - (各クエリ後、全てのiについて b_i = max(a_i, b_i)) | |
#define N 10003 | |
class SegmentTree { | |
const static ll inf = 1e18; | |
int n0; | |
// 区間加算クエリ用 (ladd: lazy tag, len: 区間要素数) | |
ll len[4*N], ladd[4*N]; | |
// a_i の最小値・二番目の最小値・最小値の個数 | |
ll min_v[4*N], smin_v[4*N], min_c[4*N]; | |
// a_i の区間総和 | |
ll sum[4*N]; | |
struct HistVal { | |
// min_v: d_i の最小値 | |
// smin_v: 二番目の最小値 | |
// min_c: 最小値の個数 | |
// sum: d_iの総和 | |
ll min_v, smin_v, min_c, sum; | |
// m_hmax: d_i が最小値の a_i + d_i の最大値 | |
// nm_hmax: d_i が非最小値の a_i + d_i の最大値 | |
ll m_hmax, nm_hmax; | |
// a_i = x でノードを初期化 | |
void init(ll x) { | |
min_v = 0; smin_v = inf; min_c = 1; sum = 0; | |
m_hmax = x; nm_hmax = -inf; | |
} | |
// 要素なしでノードを初期化 | |
void init_empty() { | |
min_v = smin_v = inf; min_c = 0; sum = 0; | |
m_hmax = nm_hmax = -inf; | |
} | |
// d_i の最小値を x に更新 | |
// pushdown, chmaxの際に使う | |
inline void update_min(ll x) { | |
if(min_v < x) { | |
sum += (x - min_v) * min_c; | |
m_hmax += (x - min_v); | |
min_v = x; | |
} | |
} | |
// l と r の情報をマージして更新 | |
// updateの際に使う | |
inline void merge(HistVal &l, HistVal &r) { | |
sum = l.sum + r.sum; | |
nm_hmax = max(l.nm_hmax, r.nm_hmax); | |
if(l.min_v < r.min_v) { | |
smin_v = min(l.smin_v, r.min_v); | |
min_v = l.min_v; | |
min_c = l.min_c; | |
nm_hmax = max(nm_hmax, r.m_hmax); | |
m_hmax = l.m_hmax; | |
} else if(l.min_v > r.min_v) { | |
smin_v = min(l.min_v, r.smin_v); | |
min_v = r.min_v; | |
min_c = r.min_c; | |
nm_hmax = max(nm_hmax, l.m_hmax); | |
m_hmax = r.m_hmax; | |
} else { | |
min_v = l.min_v; | |
smin_v = min(l.smin_v, r.smin_v); | |
min_c = l.min_c + r.min_c; | |
m_hmax = max(l.m_hmax, r.m_hmax); | |
} | |
} | |
// 要素にxを加算 | |
void add(ll x) { | |
if(min_v != inf) min_v += x; | |
if(smin_v != inf) smin_v += x; | |
} | |
// a_i + d_i の最大値を返す | |
ll hmax() const { | |
return max(m_hmax, nm_hmax); | |
} | |
}; | |
// min_d: a_i が最小値の i における d_i まわりの情報 | |
// nmin_d: a_i が非最小値の i における d_i まわりの情報 | |
HistVal min_d[4*N], nmin_d[4*N]; | |
// a_i の最小値を x に更新する | |
void update_node_min(int k, ll x) { | |
sum[k] += (x - min_v[k]) * min_c[k]; | |
min_d[k].add(min_v[k] - x); | |
min_d[k].sum += (min_v[k] - x) * min_c[k]; | |
min_v[k] = x; | |
} | |
void addall(int k, ll a) { | |
min_v[k] += a; | |
if(smin_v[k] != inf) smin_v[k] += a; | |
sum[k] += a * len[k]; | |
min_d[k].add(-a); nmin_d[k].add(-a); | |
min_d[k].sum -= a * min_c[k]; | |
nmin_d[k].sum -= a * (len[k] - min_c[k]); | |
ladd[k] += a; | |
} | |
void push(int k) { | |
if(k >= n0-1) return; | |
if(ladd[k] != 0) { | |
addall(2*k+1, ladd[k]); | |
addall(2*k+2, ladd[k]); | |
ladd[k] = 0; | |
} | |
if(min_v[2*k+1] < min_v[k]) { | |
update_node_min(2*k+1, min_v[k]); | |
} | |
if(min_v[2*k+2] < min_v[k]) { | |
update_node_min(2*k+2, min_v[k]); | |
} | |
// a_i の最小値に d_i の最小値情報を伝搬 | |
if(min_v[2*k+1] < min_v[2*k+2]) { | |
min_d[2*k+1].update_min(min_d[k].min_v); | |
min_d[2*k+2].update_min(nmin_d[k].min_v); | |
} else if(min_v[2*k+1] > min_v[2*k+2]) { | |
min_d[2*k+1].update_min(nmin_d[k].min_v); | |
min_d[2*k+2].update_min(min_d[k].min_v); | |
} else { | |
min_d[2*k+1].update_min(min_d[k].min_v); | |
min_d[2*k+2].update_min(min_d[k].min_v); | |
} | |
// a_i の非最小値に d_i の最小値情報を伝搬 | |
nmin_d[2*k+1].update_min(nmin_d[k].min_v); | |
nmin_d[2*k+2].update_min(nmin_d[k].min_v); | |
} | |
void update(int k) { | |
sum[k] = sum[2*k+1] + sum[2*k+2]; | |
nmin_d[k].merge(nmin_d[2*k+1], nmin_d[2*k+2]); | |
if(min_v[2*k+1] < min_v[2*k+2]) { | |
min_v[k] = min_v[2*k+1]; | |
min_c[k] = min_c[2*k+1]; | |
smin_v[k] = min(smin_v[2*k+1], min_v[2*k+2]); | |
min_d[k] = min_d[2*k+1]; | |
nmin_d[k].merge(nmin_d[k], min_d[2*k+2]); | |
} else if(min_v[2*k+1] > min_v[2*k+2]) { | |
min_v[k] = min_v[2*k+2]; | |
min_c[k] = min_c[2*k+2]; | |
smin_v[k] = min(min_v[2*k+1], smin_v[2*k+2]); | |
min_d[k] = min_d[2*k+2]; | |
nmin_d[k].merge(nmin_d[k], min_d[2*k+1]); | |
} else { | |
min_v[k] = min_v[2*k+1]; | |
min_c[k] = min_c[2*k+1] + min_c[2*k+2]; | |
smin_v[k] = min(smin_v[2*k+1], smin_v[2*k+2]); | |
min_d[k].merge(min_d[2*k+1], min_d[2*k+2]); | |
} | |
} | |
// (内部用) d_i <- max(d_i, 0) で更新 | |
void _update_dmax(int k, int l, int r) { | |
if(l == r || (0 <= min_d[k].min_v && 0 <= nmin_d[k].min_v)) { | |
return; | |
} | |
if(0 < min_d[k].smin_v && 0 < nmin_d[k].smin_v) { | |
min_d[k].update_min(0); | |
nmin_d[k].update_min(0); | |
return; | |
} | |
push(k); | |
_update_dmax(2*k+1, l, (l+r)/2); | |
_update_dmax(2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
// 区間[a, b) の a_i を max(a_i, x) で更新 | |
void _update_max(ll x, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a || x <= min_v[k]) { | |
return; | |
} | |
if(a <= l && r <= b && x < smin_v[k]) { | |
update_node_min(k, x); | |
_update_dmax(k, l, r); | |
return; | |
} | |
push(k); | |
_update_max(x, a, b, 2*k+1, l, (l+r)/2); | |
_update_max(x, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
// 区間[a, b) の a_i に x を加算する | |
void _add_val(ll x, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return; | |
} | |
if(a <= l && r <= b) { | |
addall(k, x); | |
_update_dmax(k, l, r); | |
return; | |
} | |
push(k); | |
_add_val(x, a, b, 2*k+1, l, (l+r)/2); | |
_add_val(x, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
// 区間[a, b) における a_i の区間最小値を返す | |
ll _query_min(int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return inf; | |
} | |
if(a <= l && r <= b) { | |
return min_v[k]; | |
} | |
push(k); | |
ll lv = _query_min(a, b, 2*k+1, l, (l+r)/2); | |
ll rv = _query_min(a, b, 2*k+2, (l+r)/2, r); | |
return min(lv, rv); | |
} | |
// 区間[a, b) における historic maximal value の区間総和を返す | |
ll _query_hmax_sum(int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return 0; | |
} | |
if(a <= l && r <= b) { | |
return sum[k] + min_d[k].sum + nmin_d[k].sum; | |
} | |
push(k); | |
ll lv = _query_hmax_sum(a, b, 2*k+1, l, (l+r)/2); | |
ll rv = _query_hmax_sum(a, b, 2*k+2, (l+r)/2, r); | |
return lv + rv; | |
} | |
// 区間[a, b) における historic maximal value の区間最大値を返す | |
ll _query_hmax_max(int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return -inf; | |
} | |
if(a <= l && r <= b) { | |
return max(min_d[k].hmax(), nmin_d[k].hmax()); | |
} | |
push(k); | |
ll lv = _query_hmax_max(a, b, 2*k+1, l, (l+r)/2); | |
ll rv = _query_hmax_max(a, b, 2*k+2, (l+r)/2, r); | |
return max(lv, rv); | |
} | |
public: | |
SegmentTree(int n, ll *a) { | |
n0 = 1; | |
while(n0 < n) n0 <<= 1; | |
len[0] = n0; | |
for(int i=0; i<2*n0-1; ++i) { | |
ladd[i] = 0; | |
len[2*i+1] = len[2*i+2] = (len[i] >> 1); | |
} | |
for(int i=0; i<n; ++i) { | |
min_v[n0-1+i] = sum[n0-1+i] = a[i]; | |
smin_v[n0-1+i] = inf; | |
min_c[n0-1+i] = 1; | |
min_d[n0-1+i].init(a[i]); | |
nmin_d[n0-1+i].init_empty(); | |
} | |
for(int i=n; i<n0; ++i) { | |
sum[n0-1+i] = min_v[n0-1+i] = 0; | |
smin_v[n0-1+i] = inf; | |
min_c[n0-1+i] = 0; | |
min_d[n0-1+i].init_empty(); | |
nmin_d[n0-1+i].init_empty(); | |
} | |
for(int i=n0-2; i>=0; i--) update(i); | |
} | |
// i<=i<r について a_i の値を max(a_i, x) に更新 | |
void update_max(int a, int b, ll x) { | |
_update_max(x, a, b, 0, 0, n0); | |
} | |
// l<=i<r について a_i の値に x を加える | |
void add_val(int a, int b, ll x) { | |
_add_val(x, a, b, 0, 0, n0); | |
} | |
// l<=i<r の中の b_i の区間総和を求める | |
ll query_hmax_sum(int a, int b) { | |
return _query_hmax_sum(a, b, 0, 0, n0); | |
} | |
// l<=i<r の中の b_i の区間最大値を求める | |
ll query_hmax_max(int a, int b) { | |
return _query_hmax_max(a, b, 0, 0, n0); | |
} | |
// l<=i<r の中の a_i の区間最小値を求める | |
ll query_min(int a, int b) { | |
return _query_min(a, b, 0, 0, n0); | |
} | |
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include<algorithm> | |
using namespace std; | |
using ll = long long; | |
// Segment Tree Beats (Historic Information) | |
// - i<=i<r について a_i の値を max(a_i, x) に更新 | |
// - i<=i<r について a_i の値を min(a_i, x) に更新 | |
// - l<=i<r について a_i の値に x を加える | |
// - l<=i<r について a_i の値を x に更新 | |
// - l<=i<r の中の b_i の最大値を求める | |
// - l<=i<r の中の b_i の総和を求める | |
// - (各クエリ後、全てのiについて b_i = max(a_i, b_i)) | |
#define N 5003 | |
class SegmentTree { | |
const static ll inf = 1e18; | |
struct HistVal { | |
ll min_v, smin_v, min_c, min_s; | |
ll m_hmax, nm_hmax; | |
// a_i = x でノードを初期化 | |
void init(ll x) { | |
min_v = 0; smin_v = inf; min_c = 1; | |
min_s = 0; | |
m_hmax = x; nm_hmax = -inf; | |
} | |
// 要素なしでノードを初期化 | |
void init_empty() { | |
min_v = smin_v = inf; min_c = 0; | |
min_s = 0; | |
m_hmax = nm_hmax = -inf; | |
} | |
inline void update_min(ll x) { | |
if(min_v < x) { | |
min_s += (x - min_v) * min_c; | |
m_hmax += (x - min_v); | |
min_v = x; | |
} | |
} | |
// l と r の情報をマージして更新 | |
void merge(HistVal &l, HistVal &r) { | |
min_s = l.min_s + r.min_s; | |
nm_hmax = max(l.nm_hmax, r.nm_hmax); | |
if(l.min_v < r.min_v) { | |
smin_v = min(l.smin_v, r.min_v); | |
min_v = l.min_v; | |
min_c = l.min_c; | |
nm_hmax = max(nm_hmax, r.m_hmax); | |
m_hmax = l.m_hmax; | |
} else if(l.min_v > r.min_v) { | |
smin_v = min(l.min_v, r.smin_v); | |
min_v = r.min_v; | |
min_c = r.min_c; | |
nm_hmax = max(nm_hmax, l.m_hmax); | |
m_hmax = r.m_hmax; | |
} else { | |
smin_v = min(l.smin_v, r.smin_v); | |
min_v = l.min_v; | |
min_c = l.min_c + r.min_c; | |
m_hmax = max(l.m_hmax, r.m_hmax); | |
} | |
} | |
// 要素c個にxを加算 | |
void add(ll x, ll c) { | |
if(min_v != inf) min_v += x; | |
if(smin_v != inf) smin_v += x; | |
min_s += x * c; | |
} | |
// a_i + d_i の最大値を返す | |
ll hmax() const { | |
return max(m_hmax, nm_hmax); | |
} | |
}; | |
int n0; | |
ll len[4*N], ladd[4*N]; | |
HistVal max_d[4*N], nval_d[4*N], min_d[4*N]; | |
ll max_v[4*N], smax_v[4*N], max_c[4*N]; | |
ll min_v[4*N], smin_v[4*N], min_c[4*N]; | |
ll sum[4*N]; | |
void update_node_max(int k, ll x) { | |
sum[k] += (x - max_v[k]) * max_c[k]; | |
max_d[k].add(max_v[k] - x, max_c[k]); | |
if(max_v[k] == min_v[k]) { | |
min_d[k].add(min_v[k] - x, min_c[k]); | |
max_v[k] = min_v[k] = x; | |
} else if(max_v[k] == smin_v[k]) { | |
max_v[k] = smin_v[k] = x; | |
} else { | |
max_v[k] = x; | |
} | |
} | |
void update_node_min(int k, ll x) { | |
sum[k] += (x - min_v[k]) * min_c[k]; | |
min_d[k].add(min_v[k] - x, min_c[k]); | |
if(min_v[k] == max_v[k]) { | |
max_d[k].add(max_v[k] - x, max_c[k]); | |
min_v[k] = max_v[k] = x; | |
} else if(min_v[k] == smax_v[k]) { | |
min_v[k] = smax_v[k] = x; | |
} else { | |
min_v[k] = x; | |
} | |
} | |
void addall(int k, ll a) { | |
sum[k] += a * len[k]; | |
max_v[k] += a; | |
if(smax_v[k] != -inf) smax_v[k] += a; | |
min_v[k] += a; | |
if(smin_v[k] != inf) smin_v[k] += a; | |
max_d[k].add(-a, max_c[k]); | |
if(max_v[k] != min_v[k]) { | |
nval_d[k].add(-a, len[k] - min_c[k] - max_c[k]); | |
} | |
min_d[k].add(-a, min_c[k]); | |
ladd[k] += a; | |
} | |
// ノードk のhistoric informationを ノードt へ伝搬 | |
void _push_hist(int t, int k) { | |
// a_i が最小値の d_i の最小値に対する伝搬 | |
if(min_v[k] == min_v[t]) { | |
min_d[t].update_min(min_d[k].min_v); | |
} else if(max_v[k] == min_v[t]) { | |
// min_v[t] == max_v[t] == max_v[k] | |
min_d[t].update_min(max_d[k].min_v); | |
} else { | |
min_d[t].update_min(nval_d[k].min_v); | |
} | |
// a_i が最大値の d_i の最小値に対する伝搬 | |
if(max_v[k] == max_v[t]) { | |
max_d[t].update_min(max_d[k].min_v); | |
} else if(min_v[k] == max_v[t]) { | |
// max_v[t] == min_v[t] == min_v[k] | |
max_d[t].update_min(min_d[k].min_v); | |
} else { | |
max_d[t].update_min(nval_d[k].min_v); | |
} | |
// a_i が非最大値・非最小値の d_i の最小値に対する伝搬 | |
nval_d[t].update_min(nval_d[k].min_v); | |
} | |
void push(int k) { | |
if(k >= n0-1) return; | |
if(ladd[k] != 0) { | |
addall(2*k+1, ladd[k]); | |
addall(2*k+2, ladd[k]); | |
ladd[k] = 0; | |
} | |
if(max_v[k] < max_v[2*k+1]) { | |
update_node_max(2*k+1, max_v[k]); | |
} | |
if(max_v[k] < max_v[2*k+2]) { | |
update_node_max(2*k+2, max_v[k]); | |
} | |
if(min_v[2*k+1] < min_v[k]) { | |
update_node_min(2*k+1, min_v[k]); | |
} | |
if(min_v[2*k+2] < min_v[k]) { | |
update_node_min(2*k+2, min_v[k]); | |
} | |
_push_hist(2*k+1, k); | |
_push_hist(2*k+2, k); | |
} | |
void update(int k) { | |
sum[k] = sum[2*k+1] + sum[2*k+2]; | |
nval_d[k].merge(nval_d[2*k+1], nval_d[2*k+2]); | |
// a_i の最大値情報のマージ | |
if(max_v[2*k+1] > max_v[2*k+2]) { | |
max_v[k] = max_v[2*k+1]; | |
max_c[k] = max_c[2*k+1]; | |
smax_v[k] = max(smax_v[2*k+1], max_v[2*k+2]); | |
max_d[k] = max_d[2*k+1]; | |
} else if(max_v[2*k+1] < max_v[2*k+2]) { | |
max_v[k] = max_v[2*k+2]; | |
max_c[k] = max_c[2*k+2]; | |
smax_v[k] = max(max_v[2*k+1], smax_v[2*k+2]); | |
max_d[k] = max_d[2*k+2]; | |
} else { | |
max_v[k] = max_v[2*k+1]; | |
max_c[k] = max_c[2*k+1] + max_c[2*k+2]; | |
smax_v[k] = max(smax_v[2*k+1], smax_v[2*k+2]); | |
max_d[k].merge(max_d[2*k+1], max_d[2*k+2]); | |
} | |
// a_i の最小値情報のマージ | |
if(min_v[2*k+1] < min_v[2*k+2]) { | |
min_v[k] = min_v[2*k+1]; | |
min_c[k] = min_c[2*k+1]; | |
smin_v[k] = min(smin_v[2*k+1], min_v[2*k+2]); | |
min_d[k] = min_d[2*k+1]; | |
} else if(min_v[2*k+1] > min_v[2*k+2]) { | |
min_v[k] = min_v[2*k+2]; | |
min_c[k] = min_c[2*k+2]; | |
smin_v[k] = min(min_v[2*k+1], smin_v[2*k+2]); | |
min_d[k] = min_d[2*k+2]; | |
} else { | |
min_v[k] = min_v[2*k+1]; | |
min_c[k] = min_c[2*k+1] + min_c[2*k+2]; | |
smin_v[k] = min(smin_v[2*k+1], smin_v[2*k+2]); | |
min_d[k].merge(min_d[2*k+1], min_d[2*k+2]); | |
} | |
// ノードk において ノード2*k+1 の最大値・最小値にならない情報をマージ | |
if(min_v[2*k+1] == max_v[2*k+1]) { | |
if(min_v[k] < min_v[2*k+1] && max_v[2*k+1] < max_v[k]) { | |
nval_d[k].merge(nval_d[k], max_d[2*k+1]); | |
} | |
} else { // min_v[2*k+1] < max_v[2*k+1] | |
if(max_v[2*k+1] < max_v[k]) { | |
nval_d[k].merge(nval_d[k], max_d[2*k+1]); | |
} | |
if(min_v[k] < min_v[2*k+1]) { | |
nval_d[k].merge(nval_d[k], min_d[2*k+1]); | |
} | |
} | |
// ノードk において ノード2*k+2 の最大値・最小値にならない情報をマージ | |
if(min_v[2*k+2] == max_v[2*k+2]) { | |
if(min_v[k] < min_v[2*k+2] && max_v[2*k+2] < max_v[k]) { | |
nval_d[k].merge(nval_d[k], max_d[2*k+2]); | |
} | |
} else { // min_v[2*k+2] < max_v[2*k+2] | |
if(max_v[2*k+2] < max_v[k]) { | |
nval_d[k].merge(nval_d[k], max_d[2*k+2]); | |
} | |
if(min_v[k] < min_v[2*k+2]) { | |
nval_d[k].merge(nval_d[k], min_d[2*k+2]); | |
} | |
} | |
} | |
// (内部用) d_i の range maximize query | |
void _update_dmax(int k, int l, int r) { | |
// a_i が最大値・最小値・それ以外で分けた d_i について d_i の最小値が 0 以上であれば終了 | |
if(l == r || (0 <= max_d[k].min_v && 0 <= nval_d[k].min_v && 0 <= min_d[k].min_v)) { | |
return; | |
} | |
// a_i が最大値・最小値・それ以外で分けた d_i について | |
// d_i の2番目の最小値が 0 より大きければ d_i の最大値を更新 | |
if(0 < max_d[k].smin_v && 0 < nval_d[k].smin_v && 0 < min_d[k].smin_v) { | |
max_d[k].update_min(0); | |
nval_d[k].update_min(0); | |
min_d[k].update_min(0); | |
return; | |
} | |
push(k); | |
_update_dmax(2*k+1, l, (l+r)/2); | |
_update_dmax(2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
// a_i の range minimize query | |
void _update_min(ll x, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a || max_v[k] <= x) { | |
return; | |
} | |
if(a <= l && r <= b && smax_v[k] < x) { | |
update_node_max(k, x); | |
_update_dmax(k, l, r); | |
return; | |
} | |
push(k); | |
_update_min(x, a, b, 2*k+1, l, (l+r)/2); | |
_update_min(x, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
// a_i の range maximize query | |
void _update_max(ll x, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a || x <= min_v[k]) { | |
return; | |
} | |
if(a <= l && r <= b && x < smin_v[k]) { | |
update_node_min(k, x); | |
_update_dmax(k, l, r); | |
return; | |
} | |
push(k); | |
_update_max(x, a, b, 2*k+1, l, (l+r)/2); | |
_update_max(x, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
// a_i の range add query | |
void _add_val(ll x, int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return; | |
} | |
if(a <= l && r <= b) { | |
addall(k, x); | |
_update_dmax(k, l, r); | |
return; | |
} | |
push(k); | |
_add_val(x, a, b, 2*k+1, l, (l+r)/2); | |
_add_val(x, a, b, 2*k+2, (l+r)/2, r); | |
update(k); | |
} | |
// b_i の range maximum query | |
ll _query_hmax_max(int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return -inf; | |
} | |
if(a <= l && r <= b) { | |
// a_i が最大値・最小値・それ以外のうちの a_i + d_i の最大値を返す | |
return max(max(max_d[k].hmax(), min_d[k].hmax()), nval_d[k].hmax()); | |
} | |
push(k); | |
ll lv = _query_hmax_max(a, b, 2*k+1, l, (l+r)/2); | |
ll rv = _query_hmax_max(a, b, 2*k+2, (l+r)/2, r); | |
return max(lv, rv); | |
} | |
// b_i の range sum query | |
ll _query_hmax_sum(int a, int b, int k, int l, int r) { | |
if(b <= l || r <= a) { | |
return 0; | |
} | |
if(a <= l && r <= b) { | |
if(max_v[k] == min_v[k]) { | |
// 同じ要素が a_i が最大値の要素、a_i が最小値の要素両方に含まれている | |
// 片方の総和のみ足して返す | |
return sum[k] + max_d[k].min_s; | |
} | |
// a_i の総和と a_i が最小値、最大値、それ以外の d_i の総和を合わせる | |
return sum[k] + max_d[k].min_s + nval_d[k].min_s + min_d[k].min_s; | |
} | |
push(k); | |
ll lv = _query_hmax_sum(a, b, 2*k+1, l, (l+r)/2); | |
ll rv = _query_hmax_sum(a, b, 2*k+2, (l+r)/2, r); | |
return lv + rv; | |
} | |
public: | |
SegmentTree(int n, ll *a) { | |
n0 = 1; | |
while(n0 < n) n0 <<= 1; | |
for(int i=0; i<2*n0-1; ++i) ladd[i] = 0; | |
len[0] = n0; | |
for(int i=0; i<n0-1; ++i) len[2*i+1] = len[2*i+2] = (len[i] >> 1); | |
for(int i=0; i<n; ++i) { | |
max_v[n0-1+i] = min_v[n0-1+i] = sum[n0-1+i] = a[i]; | |
smax_v[n0-1+i] = -inf; smin_v[n0-1+i] = inf; | |
max_c[n0-1+i] = min_c[n0-1+i] = 1; | |
max_d[n0-1+i].init(a[i]); | |
nval_d[n0-1+i].init_empty(); | |
min_d[n0-1+i].init(a[i]); | |
} | |
for(int i=n; i<n0; ++i) { | |
sum[n0-1+i] = 0; | |
max_v[n0-1+i] = smax_v[n0-1+i] = -inf; | |
min_v[n0-1+i] = smin_v[n0-1+i] = inf; | |
max_c[n0-1+i] = min_c[n0-1+i] = 0; | |
max_d[n0-1+i].init_empty(); | |
nval_d[n0-1+i].init_empty(); | |
min_d[n0-1+i].init_empty(); | |
} | |
for(int i=n0-2; i>=0; i--) update(i); | |
} | |
// i<=i<r について A_i の値を max(A_i, x) に更新 | |
ll update_max(int a, int b, ll x) { | |
_update_max(x, a, b, 0, 0, n0); | |
} | |
// i<=i<r について A_i の値を min(A_i, x) に更新 | |
ll update_min(int a, int b, ll x) { | |
_update_min(x, a, b, 0, 0, n0); | |
} | |
// l<=i<r について A_i の値に x を加える | |
void add_val(int a, int b, ll x) { | |
_add_val(x, a, b, 0, 0, n0); | |
} | |
// l<=i<r について A_i の値を x に更新 | |
void update_val(int a, int b, ll x) { | |
_update_min(x, a, b, 0, 0, n0); | |
_update_max(x, a, b, 0, 0, n0); | |
} | |
// l<=i<r の中の B_i の最大値を求める | |
ll query_hmax_max(int a, int b) { | |
return _query_hmax_max(a, b, 0, 0, n0); | |
} | |
// l<=i<r の中の B_i の総和を求める | |
ll query_hmax_sum(int a, int b) { | |
return _query_hmax_sum(a, b, 0, 0, n0); | |
} | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment