Skip to content

Instantly share code, notes, and snippets.

@tjkendev
Last active June 5, 2019 16:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tjkendev/ad154c067a02648f2c2d8194abe6022d to your computer and use it in GitHub Desktop.
Save tjkendev/ad154c067a02648f2c2d8194abe6022d to your computer and use it in GitHub Desktop.
Segment Tree Beats (Historic Informationまわり) の実装
#include<algorithm>
using namespace std;
using ll = long long;
// Segment Tree Beats (Historic Information)
// - l<=i<r について a_i の値に x を加える
// - l<=i<r の中の a_i の最大値を求める
// - l<=i<r の中の b_i の総和を求める
// - l<=i<r の中の b_i の最大値を求める
// - (各クエリ後、全てのiについて b_i = max(a_i, b_i))
#define N 10003
class SegmentTree {
static const ll inf = 1e18;
int n0;
// cur_s : a_i の区間総和
// cur_ma: a_i の区間最大値
ll cur_s[4*N], cur_ma[4*N];
// 区間加算クエリ用 (ladd: lazy tag, len: 区間要素数)
ll ladd[4*N], len[4*N];
// d_i の最小値まわりの情報を持つ構造体
struct HistVal {
// min_v: d_i の最小値
// smin_v: 二番目の最小値
// min_c: 最小値の個数
// sum: d_iの総和
ll min_v, smin_v, min_c, sum;
// m_hmax: d_i が最小値の a_i + d_i の最大値
// nm_hmax: d_i が非最小値の a_i + d_i の最大値
ll m_hmax, nm_hmax;
// a_i の初期値が x の初期化処理
void init(ll x) {
min_v = 0; smin_v = inf; min_c = 1; sum = 0;
m_hmax = x; nm_hmax = -inf;
}
// a_i を使わない場合の初期化処理
void init_empty() {
min_v = smin_v = inf; min_c = 0; sum = 0;
m_hmax = nm_hmax = -inf;
}
// d_i の最小値を x に更新
inline void update_min(ll x) {
if(min_v < x) {
sum += (x - min_v) * min_c;
// a_i + d_i の最大値はここのみで変化する
m_hmax += (x - min_v);
min_v = x;
}
}
// l と r の情報をマージして更新
// (片方に自身を指定して使っても問題ないようになっている)
inline void merge(HistVal &l, HistVal &r) {
sum = l.sum + r.sum;
nm_hmax = max(l.nm_hmax, r.nm_hmax);
if(l.min_v < r.min_v) {
smin_v = min(l.smin_v, r.min_v);
min_v = l.min_v;
min_c = l.min_c;
nm_hmax = max(nm_hmax, r.m_hmax);
m_hmax = l.m_hmax;
} else if(l.min_v > r.min_v) {
smin_v = min(l.min_v, r.smin_v);
min_v = r.min_v;
min_c = r.min_c;
nm_hmax = max(nm_hmax, l.m_hmax);
m_hmax = r.m_hmax;
} else {
min_v = l.min_v;
smin_v = min(l.smin_v, r.smin_v);
min_c = l.min_c + r.min_c;
m_hmax = max(l.m_hmax, r.m_hmax);
}
}
// 最小値・二番目の最小値への加算処理
void add(ll x) {
if(min_v != inf) min_v += x;
if(smin_v != inf) smin_v += x;
}
// a_i + d_i の最大値を返す
ll hmax() const {
return max(m_hmax, nm_hmax);
}
};
// 各ノードごとに d_i の最小値情報を持つ
HistVal val_d[4*N];
// ノードk に含まれる各 a_i に a を加算
void addall(int k, ll a) {
// a_i の区間総和に加算
cur_s[k] += a * len[k];
cur_ma[k] += a;
// a_i + d_i の最大値に加算されるが、この後相殺されるため加算は不要
//val_d[k].m_hmax += a; val_d[k].nm_hmax += a;
// d_i の最小値・二番目の最小値と区間総和に加算
val_d[k].add(-a);
val_d[k].sum -= a * len[k];
//val_d[k].m_hmax -= a; val_d[k].nm_hmax -= a;
ladd[k] += a;
}
void push(int k) {
if(k >= n0-1) return;
if(ladd[k] != 0) {
addall(2*k+1, ladd[k]);
addall(2*k+2, ladd[k]);
ladd[k] = 0;
}
val_d[2*k+1].update_min(val_d[k].min_v);
val_d[2*k+2].update_min(val_d[k].min_v);
}
void update(int k) {
cur_s[k] = cur_s[2*k+1] + cur_s[2*k+2];
cur_ma[k] = max(cur_ma[2*k+1], cur_ma[2*k+2]);
val_d[k].merge(val_d[2*k+1], val_d[2*k+2]);
}
// (内部用) d_i <- max(d_i, 0) で更新
void _update_dmax(int k, int l, int r) {
// break condition: d_i の最小値(val_d[k].min_v)が 0 以上
if(l == r || 0 <= val_d[k].min_v) {
return;
}
// tag condition: d_i の二番目の最小値(val_d[k].smin_v)が 0 より大きい
if(0 < val_d[k].smin_v) {
// d_i の最小値を 0 に更新
val_d[k].update_min(0);
return;
}
push(k);
_update_dmax(2*k+1, l, (l+r)/2);
_update_dmax(2*k+2, (l+r)/2, r);
update(k);
}
// 区間[a, b) の a_i に x を加算する
void _add_val(ll x, int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return;
}
if(a <= l && r <= b) {
// a_i に x を加算
addall(k, x);
// 区間[l, r) の d_i を max(d_i, 0) に更新
_update_dmax(k, l, r);
return;
}
push(k);
_add_val(x, a, b, 2*k+1, l, (l+r)/2);
_add_val(x, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
// 区間[a, b) における a_i の区間最大値を返す
ll _query_max(int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return -inf;
}
if(a <= l && r <= b) {
return cur_ma[k];
}
push(k);
ll lv = _query_max(a, b, 2*k+1, l, (l+r)/2);
ll rv = _query_max(a, b, 2*k+2, (l+r)/2, r);
return max(lv, rv);
}
// 区間[a, b) における historic maximal value の区間最大値を返す
ll _query_hmax_max(int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return -inf;
}
if(a <= l && r <= b) {
return val_d[k].hmax();
}
push(k);
ll lv = _query_hmax_max(a, b, 2*k+1, l, (l+r)/2);
ll rv = _query_hmax_max(a, b, 2*k+2, (l+r)/2, r);
return max(lv, rv);
}
// 区間[a, b) における historic maximal value の区間総和を返す
ll _query_hmax_sum(int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return 0;
}
if(a <= l && r <= b) {
return cur_s[k] + val_d[k].sum;
}
push(k);
ll lv = _query_hmax_sum(a, b, 2*k+1, l, (l+r)/2);
ll rv = _query_hmax_sum(a, b, 2*k+2, (l+r)/2, r);
return lv + rv;
}
public:
SegmentTree(int n, ll *a) {
n0 = 1;
while(n0 < n) n0 <<= 1;
len[0] = n0;
for(int i=0; i<2*n0-1; ++i) len[2*i+1] = len[2*i+2] = (len[i] >> 1);
for(int i=0; i<n; ++i) {
cur_s[n0-1+i] = cur_ma[n0-1+i] = a[i];
val_d[n0-1+i].init(a[i]);
}
for(int i=n; i<n0; ++i) {
cur_s[n0-1+i] = cur_ma[n0-1+i] = 0;
val_d[n0-1+i].init_empty();
}
for(int i=n0-2; i>=0; i--) update(i);
}
// l<=i<r について a_i の値を a_i + x に更新
void add_val(int a, int b, ll x) {
_add_val(x, a, b, 0, 0, n0);
}
// l<=i<r の中の b_i の区間総和
ll query_hmax_sum(int a, int b) {
return _query_hmax_sum(a, b, 0, 0, n0);
}
// l<=i<r の中の b_i の区間最大値
ll query_hmax_max(int a, int b) {
return _query_hmax_max(a, b, 0, 0, n0);
}
// l<=i<r の中の a_i の区間最大値
ll query_max(int a, int b) {
return _query_max(a, b, 0, 0, n0);
}
};
using namespace std;
using ll = long long;
// Segment Tree with Lazy Propagation
// - l<=i<r について a_i を a_i + x に更新
// - l<=i<r の s_i の区間総和を求める
#define N 10004
class SegmentTree {
ll n0;
ll va[4*N], vb[4*N];
ll ladd_a[4*N], ladd_b[4*N];
ll len[4*N];
void addall(int k, ll x, ll s) {
va[k] += x * len[k];
vb[k] += s * len[k];
ladd_a[k] += x;
ladd_b[k] += s;
}
void push(int k) {
addall(2*k+1, ladd_a[k], ladd_b[k]);
addall(2*k+2, ladd_a[k], ladd_b[k]);
ladd_a[k] = ladd_b[k] = 0;
}
void update(int k) {
va[k] = va[2*k+1] + va[2*k+2];
vb[k] = vb[2*k+1] + vb[2*k+2];
}
void _add_val(ll x, ll t, int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return;
}
if(a <= l && r <= b) {
addall(k, x, -t*x);
return;
}
push(k);
_add_val(x, t, a, b, 2*k+1, l, (l+r)/2);
_add_val(x, t, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
ll _query_sum(ll t, int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return 0;
}
if(a <= l && r <= b) {
return va[k] * t + vb[k];
}
push(k);
ll lv = _query_sum(t, a, b, 2*k+1, l, (l+r)/2);
ll rv = _query_sum(t, a, b, 2*k+2, (l+r)/2, r);
update(k);
return lv + rv;
}
public:
SegmentTree(int n, ll *a) {
n0 = 1;
while(n0 < n) n0 <<= 1;
len[0] = n0;
for(int i=0; i<n0-1; ++i) len[2*i+1] = len[2*i+2] = (len[i] >> 1);
for(int i=0; i<n; ++i) {
va[n0-1+i] = vb[n0-1+i] = a[i];
}
for(int i=n; i<n0; ++i) {
va[n0-1+i] = vb[n0-1+i] = 0;
}
for(int i=n0-2; i>=0; --i) update(i);
}
void add_val(int a, int b, ll x, ll t) {
_add_val(x, t, a, b, 0, 0, n0);
}
ll query_sum(int a, int b, ll t) {
return _query_sum(t, a, b, 0, 0, n0);
}
};
#include<algorithm>
using namespace std;
using ll = long long;
// Segment Tree Beats (Historic Information)
// - i<=i<r について a_i の値を max(a_i, x) に更新
// - l<=i<r について a_i の値に x を加える
// - l<=i<r の中の a_i の最小値を求める
// - l<=i<r の中の b_i の総和を求める
// - l<=i<r の中の b_i の最大値を求める
// - (各クエリ後、全てのiについて b_i = max(a_i, b_i))
#define N 10003
class SegmentTree {
const static ll inf = 1e18;
int n0;
// 区間加算クエリ用 (ladd: lazy tag, len: 区間要素数)
ll len[4*N], ladd[4*N];
// a_i の最小値・二番目の最小値・最小値の個数
ll min_v[4*N], smin_v[4*N], min_c[4*N];
// a_i の区間総和
ll sum[4*N];
struct HistVal {
// min_v: d_i の最小値
// smin_v: 二番目の最小値
// min_c: 最小値の個数
// sum: d_iの総和
ll min_v, smin_v, min_c, sum;
// m_hmax: d_i が最小値の a_i + d_i の最大値
// nm_hmax: d_i が非最小値の a_i + d_i の最大値
ll m_hmax, nm_hmax;
// a_i = x でノードを初期化
void init(ll x) {
min_v = 0; smin_v = inf; min_c = 1; sum = 0;
m_hmax = x; nm_hmax = -inf;
}
// 要素なしでノードを初期化
void init_empty() {
min_v = smin_v = inf; min_c = 0; sum = 0;
m_hmax = nm_hmax = -inf;
}
// d_i の最小値を x に更新
// pushdown, chmaxの際に使う
inline void update_min(ll x) {
if(min_v < x) {
sum += (x - min_v) * min_c;
m_hmax += (x - min_v);
min_v = x;
}
}
// l と r の情報をマージして更新
// updateの際に使う
inline void merge(HistVal &l, HistVal &r) {
sum = l.sum + r.sum;
nm_hmax = max(l.nm_hmax, r.nm_hmax);
if(l.min_v < r.min_v) {
smin_v = min(l.smin_v, r.min_v);
min_v = l.min_v;
min_c = l.min_c;
nm_hmax = max(nm_hmax, r.m_hmax);
m_hmax = l.m_hmax;
} else if(l.min_v > r.min_v) {
smin_v = min(l.min_v, r.smin_v);
min_v = r.min_v;
min_c = r.min_c;
nm_hmax = max(nm_hmax, l.m_hmax);
m_hmax = r.m_hmax;
} else {
min_v = l.min_v;
smin_v = min(l.smin_v, r.smin_v);
min_c = l.min_c + r.min_c;
m_hmax = max(l.m_hmax, r.m_hmax);
}
}
// 要素にxを加算
void add(ll x) {
if(min_v != inf) min_v += x;
if(smin_v != inf) smin_v += x;
}
// a_i + d_i の最大値を返す
ll hmax() const {
return max(m_hmax, nm_hmax);
}
};
// min_d: a_i が最小値の i における d_i まわりの情報
// nmin_d: a_i が非最小値の i における d_i まわりの情報
HistVal min_d[4*N], nmin_d[4*N];
// a_i の最小値を x に更新する
void update_node_min(int k, ll x) {
sum[k] += (x - min_v[k]) * min_c[k];
min_d[k].add(min_v[k] - x);
min_d[k].sum += (min_v[k] - x) * min_c[k];
min_v[k] = x;
}
void addall(int k, ll a) {
min_v[k] += a;
if(smin_v[k] != inf) smin_v[k] += a;
sum[k] += a * len[k];
min_d[k].add(-a); nmin_d[k].add(-a);
min_d[k].sum -= a * min_c[k];
nmin_d[k].sum -= a * (len[k] - min_c[k]);
ladd[k] += a;
}
void push(int k) {
if(k >= n0-1) return;
if(ladd[k] != 0) {
addall(2*k+1, ladd[k]);
addall(2*k+2, ladd[k]);
ladd[k] = 0;
}
if(min_v[2*k+1] < min_v[k]) {
update_node_min(2*k+1, min_v[k]);
}
if(min_v[2*k+2] < min_v[k]) {
update_node_min(2*k+2, min_v[k]);
}
// a_i の最小値に d_i の最小値情報を伝搬
if(min_v[2*k+1] < min_v[2*k+2]) {
min_d[2*k+1].update_min(min_d[k].min_v);
min_d[2*k+2].update_min(nmin_d[k].min_v);
} else if(min_v[2*k+1] > min_v[2*k+2]) {
min_d[2*k+1].update_min(nmin_d[k].min_v);
min_d[2*k+2].update_min(min_d[k].min_v);
} else {
min_d[2*k+1].update_min(min_d[k].min_v);
min_d[2*k+2].update_min(min_d[k].min_v);
}
// a_i の非最小値に d_i の最小値情報を伝搬
nmin_d[2*k+1].update_min(nmin_d[k].min_v);
nmin_d[2*k+2].update_min(nmin_d[k].min_v);
}
void update(int k) {
sum[k] = sum[2*k+1] + sum[2*k+2];
nmin_d[k].merge(nmin_d[2*k+1], nmin_d[2*k+2]);
if(min_v[2*k+1] < min_v[2*k+2]) {
min_v[k] = min_v[2*k+1];
min_c[k] = min_c[2*k+1];
smin_v[k] = min(smin_v[2*k+1], min_v[2*k+2]);
min_d[k] = min_d[2*k+1];
nmin_d[k].merge(nmin_d[k], min_d[2*k+2]);
} else if(min_v[2*k+1] > min_v[2*k+2]) {
min_v[k] = min_v[2*k+2];
min_c[k] = min_c[2*k+2];
smin_v[k] = min(min_v[2*k+1], smin_v[2*k+2]);
min_d[k] = min_d[2*k+2];
nmin_d[k].merge(nmin_d[k], min_d[2*k+1]);
} else {
min_v[k] = min_v[2*k+1];
min_c[k] = min_c[2*k+1] + min_c[2*k+2];
smin_v[k] = min(smin_v[2*k+1], smin_v[2*k+2]);
min_d[k].merge(min_d[2*k+1], min_d[2*k+2]);
}
}
// (内部用) d_i <- max(d_i, 0) で更新
void _update_dmax(int k, int l, int r) {
if(l == r || (0 <= min_d[k].min_v && 0 <= nmin_d[k].min_v)) {
return;
}
if(0 < min_d[k].smin_v && 0 < nmin_d[k].smin_v) {
min_d[k].update_min(0);
nmin_d[k].update_min(0);
return;
}
push(k);
_update_dmax(2*k+1, l, (l+r)/2);
_update_dmax(2*k+2, (l+r)/2, r);
update(k);
}
// 区間[a, b) の a_i を max(a_i, x) で更新
void _update_max(ll x, int a, int b, int k, int l, int r) {
if(b <= l || r <= a || x <= min_v[k]) {
return;
}
if(a <= l && r <= b && x < smin_v[k]) {
update_node_min(k, x);
_update_dmax(k, l, r);
return;
}
push(k);
_update_max(x, a, b, 2*k+1, l, (l+r)/2);
_update_max(x, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
// 区間[a, b) の a_i に x を加算する
void _add_val(ll x, int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return;
}
if(a <= l && r <= b) {
addall(k, x);
_update_dmax(k, l, r);
return;
}
push(k);
_add_val(x, a, b, 2*k+1, l, (l+r)/2);
_add_val(x, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
// 区間[a, b) における a_i の区間最小値を返す
ll _query_min(int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return inf;
}
if(a <= l && r <= b) {
return min_v[k];
}
push(k);
ll lv = _query_min(a, b, 2*k+1, l, (l+r)/2);
ll rv = _query_min(a, b, 2*k+2, (l+r)/2, r);
return min(lv, rv);
}
// 区間[a, b) における historic maximal value の区間総和を返す
ll _query_hmax_sum(int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return 0;
}
if(a <= l && r <= b) {
return sum[k] + min_d[k].sum + nmin_d[k].sum;
}
push(k);
ll lv = _query_hmax_sum(a, b, 2*k+1, l, (l+r)/2);
ll rv = _query_hmax_sum(a, b, 2*k+2, (l+r)/2, r);
return lv + rv;
}
// 区間[a, b) における historic maximal value の区間最大値を返す
ll _query_hmax_max(int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return -inf;
}
if(a <= l && r <= b) {
return max(min_d[k].hmax(), nmin_d[k].hmax());
}
push(k);
ll lv = _query_hmax_max(a, b, 2*k+1, l, (l+r)/2);
ll rv = _query_hmax_max(a, b, 2*k+2, (l+r)/2, r);
return max(lv, rv);
}
public:
SegmentTree(int n, ll *a) {
n0 = 1;
while(n0 < n) n0 <<= 1;
len[0] = n0;
for(int i=0; i<2*n0-1; ++i) {
ladd[i] = 0;
len[2*i+1] = len[2*i+2] = (len[i] >> 1);
}
for(int i=0; i<n; ++i) {
min_v[n0-1+i] = sum[n0-1+i] = a[i];
smin_v[n0-1+i] = inf;
min_c[n0-1+i] = 1;
min_d[n0-1+i].init(a[i]);
nmin_d[n0-1+i].init_empty();
}
for(int i=n; i<n0; ++i) {
sum[n0-1+i] = min_v[n0-1+i] = 0;
smin_v[n0-1+i] = inf;
min_c[n0-1+i] = 0;
min_d[n0-1+i].init_empty();
nmin_d[n0-1+i].init_empty();
}
for(int i=n0-2; i>=0; i--) update(i);
}
// i<=i<r について a_i の値を max(a_i, x) に更新
void update_max(int a, int b, ll x) {
_update_max(x, a, b, 0, 0, n0);
}
// l<=i<r について a_i の値に x を加える
void add_val(int a, int b, ll x) {
_add_val(x, a, b, 0, 0, n0);
}
// l<=i<r の中の b_i の区間総和を求める
ll query_hmax_sum(int a, int b) {
return _query_hmax_sum(a, b, 0, 0, n0);
}
// l<=i<r の中の b_i の区間最大値を求める
ll query_hmax_max(int a, int b) {
return _query_hmax_max(a, b, 0, 0, n0);
}
// l<=i<r の中の a_i の区間最小値を求める
ll query_min(int a, int b) {
return _query_min(a, b, 0, 0, n0);
}
};
#include<algorithm>
using namespace std;
using ll = long long;
// Segment Tree Beats (Historic Information)
// - i<=i<r について a_i の値を max(a_i, x) に更新
// - i<=i<r について a_i の値を min(a_i, x) に更新
// - l<=i<r について a_i の値に x を加える
// - l<=i<r について a_i の値を x に更新
// - l<=i<r の中の b_i の最大値を求める
// - l<=i<r の中の b_i の総和を求める
// - (各クエリ後、全てのiについて b_i = max(a_i, b_i))
#define N 5003
class SegmentTree {
const static ll inf = 1e18;
struct HistVal {
ll min_v, smin_v, min_c, min_s;
ll m_hmax, nm_hmax;
// a_i = x でノードを初期化
void init(ll x) {
min_v = 0; smin_v = inf; min_c = 1;
min_s = 0;
m_hmax = x; nm_hmax = -inf;
}
// 要素なしでノードを初期化
void init_empty() {
min_v = smin_v = inf; min_c = 0;
min_s = 0;
m_hmax = nm_hmax = -inf;
}
inline void update_min(ll x) {
if(min_v < x) {
min_s += (x - min_v) * min_c;
m_hmax += (x - min_v);
min_v = x;
}
}
// l と r の情報をマージして更新
void merge(HistVal &l, HistVal &r) {
min_s = l.min_s + r.min_s;
nm_hmax = max(l.nm_hmax, r.nm_hmax);
if(l.min_v < r.min_v) {
smin_v = min(l.smin_v, r.min_v);
min_v = l.min_v;
min_c = l.min_c;
nm_hmax = max(nm_hmax, r.m_hmax);
m_hmax = l.m_hmax;
} else if(l.min_v > r.min_v) {
smin_v = min(l.min_v, r.smin_v);
min_v = r.min_v;
min_c = r.min_c;
nm_hmax = max(nm_hmax, l.m_hmax);
m_hmax = r.m_hmax;
} else {
smin_v = min(l.smin_v, r.smin_v);
min_v = l.min_v;
min_c = l.min_c + r.min_c;
m_hmax = max(l.m_hmax, r.m_hmax);
}
}
// 要素c個にxを加算
void add(ll x, ll c) {
if(min_v != inf) min_v += x;
if(smin_v != inf) smin_v += x;
min_s += x * c;
}
// a_i + d_i の最大値を返す
ll hmax() const {
return max(m_hmax, nm_hmax);
}
};
int n0;
ll len[4*N], ladd[4*N];
HistVal max_d[4*N], nval_d[4*N], min_d[4*N];
ll max_v[4*N], smax_v[4*N], max_c[4*N];
ll min_v[4*N], smin_v[4*N], min_c[4*N];
ll sum[4*N];
void update_node_max(int k, ll x) {
sum[k] += (x - max_v[k]) * max_c[k];
max_d[k].add(max_v[k] - x, max_c[k]);
if(max_v[k] == min_v[k]) {
min_d[k].add(min_v[k] - x, min_c[k]);
max_v[k] = min_v[k] = x;
} else if(max_v[k] == smin_v[k]) {
max_v[k] = smin_v[k] = x;
} else {
max_v[k] = x;
}
}
void update_node_min(int k, ll x) {
sum[k] += (x - min_v[k]) * min_c[k];
min_d[k].add(min_v[k] - x, min_c[k]);
if(min_v[k] == max_v[k]) {
max_d[k].add(max_v[k] - x, max_c[k]);
min_v[k] = max_v[k] = x;
} else if(min_v[k] == smax_v[k]) {
min_v[k] = smax_v[k] = x;
} else {
min_v[k] = x;
}
}
void addall(int k, ll a) {
sum[k] += a * len[k];
max_v[k] += a;
if(smax_v[k] != -inf) smax_v[k] += a;
min_v[k] += a;
if(smin_v[k] != inf) smin_v[k] += a;
max_d[k].add(-a, max_c[k]);
if(max_v[k] != min_v[k]) {
nval_d[k].add(-a, len[k] - min_c[k] - max_c[k]);
}
min_d[k].add(-a, min_c[k]);
ladd[k] += a;
}
// ノードk のhistoric informationを ノードt へ伝搬
void _push_hist(int t, int k) {
// a_i が最小値の d_i の最小値に対する伝搬
if(min_v[k] == min_v[t]) {
min_d[t].update_min(min_d[k].min_v);
} else if(max_v[k] == min_v[t]) {
// min_v[t] == max_v[t] == max_v[k]
min_d[t].update_min(max_d[k].min_v);
} else {
min_d[t].update_min(nval_d[k].min_v);
}
// a_i が最大値の d_i の最小値に対する伝搬
if(max_v[k] == max_v[t]) {
max_d[t].update_min(max_d[k].min_v);
} else if(min_v[k] == max_v[t]) {
// max_v[t] == min_v[t] == min_v[k]
max_d[t].update_min(min_d[k].min_v);
} else {
max_d[t].update_min(nval_d[k].min_v);
}
// a_i が非最大値・非最小値の d_i の最小値に対する伝搬
nval_d[t].update_min(nval_d[k].min_v);
}
void push(int k) {
if(k >= n0-1) return;
if(ladd[k] != 0) {
addall(2*k+1, ladd[k]);
addall(2*k+2, ladd[k]);
ladd[k] = 0;
}
if(max_v[k] < max_v[2*k+1]) {
update_node_max(2*k+1, max_v[k]);
}
if(max_v[k] < max_v[2*k+2]) {
update_node_max(2*k+2, max_v[k]);
}
if(min_v[2*k+1] < min_v[k]) {
update_node_min(2*k+1, min_v[k]);
}
if(min_v[2*k+2] < min_v[k]) {
update_node_min(2*k+2, min_v[k]);
}
_push_hist(2*k+1, k);
_push_hist(2*k+2, k);
}
void update(int k) {
sum[k] = sum[2*k+1] + sum[2*k+2];
nval_d[k].merge(nval_d[2*k+1], nval_d[2*k+2]);
// a_i の最大値情報のマージ
if(max_v[2*k+1] > max_v[2*k+2]) {
max_v[k] = max_v[2*k+1];
max_c[k] = max_c[2*k+1];
smax_v[k] = max(smax_v[2*k+1], max_v[2*k+2]);
max_d[k] = max_d[2*k+1];
} else if(max_v[2*k+1] < max_v[2*k+2]) {
max_v[k] = max_v[2*k+2];
max_c[k] = max_c[2*k+2];
smax_v[k] = max(max_v[2*k+1], smax_v[2*k+2]);
max_d[k] = max_d[2*k+2];
} else {
max_v[k] = max_v[2*k+1];
max_c[k] = max_c[2*k+1] + max_c[2*k+2];
smax_v[k] = max(smax_v[2*k+1], smax_v[2*k+2]);
max_d[k].merge(max_d[2*k+1], max_d[2*k+2]);
}
// a_i の最小値情報のマージ
if(min_v[2*k+1] < min_v[2*k+2]) {
min_v[k] = min_v[2*k+1];
min_c[k] = min_c[2*k+1];
smin_v[k] = min(smin_v[2*k+1], min_v[2*k+2]);
min_d[k] = min_d[2*k+1];
} else if(min_v[2*k+1] > min_v[2*k+2]) {
min_v[k] = min_v[2*k+2];
min_c[k] = min_c[2*k+2];
smin_v[k] = min(min_v[2*k+1], smin_v[2*k+2]);
min_d[k] = min_d[2*k+2];
} else {
min_v[k] = min_v[2*k+1];
min_c[k] = min_c[2*k+1] + min_c[2*k+2];
smin_v[k] = min(smin_v[2*k+1], smin_v[2*k+2]);
min_d[k].merge(min_d[2*k+1], min_d[2*k+2]);
}
// ノードk において ノード2*k+1 の最大値・最小値にならない情報をマージ
if(min_v[2*k+1] == max_v[2*k+1]) {
if(min_v[k] < min_v[2*k+1] && max_v[2*k+1] < max_v[k]) {
nval_d[k].merge(nval_d[k], max_d[2*k+1]);
}
} else { // min_v[2*k+1] < max_v[2*k+1]
if(max_v[2*k+1] < max_v[k]) {
nval_d[k].merge(nval_d[k], max_d[2*k+1]);
}
if(min_v[k] < min_v[2*k+1]) {
nval_d[k].merge(nval_d[k], min_d[2*k+1]);
}
}
// ノードk において ノード2*k+2 の最大値・最小値にならない情報をマージ
if(min_v[2*k+2] == max_v[2*k+2]) {
if(min_v[k] < min_v[2*k+2] && max_v[2*k+2] < max_v[k]) {
nval_d[k].merge(nval_d[k], max_d[2*k+2]);
}
} else { // min_v[2*k+2] < max_v[2*k+2]
if(max_v[2*k+2] < max_v[k]) {
nval_d[k].merge(nval_d[k], max_d[2*k+2]);
}
if(min_v[k] < min_v[2*k+2]) {
nval_d[k].merge(nval_d[k], min_d[2*k+2]);
}
}
}
// (内部用) d_i の range maximize query
void _update_dmax(int k, int l, int r) {
// a_i が最大値・最小値・それ以外で分けた d_i について d_i の最小値が 0 以上であれば終了
if(l == r || (0 <= max_d[k].min_v && 0 <= nval_d[k].min_v && 0 <= min_d[k].min_v)) {
return;
}
// a_i が最大値・最小値・それ以外で分けた d_i について
// d_i の2番目の最小値が 0 より大きければ d_i の最大値を更新
if(0 < max_d[k].smin_v && 0 < nval_d[k].smin_v && 0 < min_d[k].smin_v) {
max_d[k].update_min(0);
nval_d[k].update_min(0);
min_d[k].update_min(0);
return;
}
push(k);
_update_dmax(2*k+1, l, (l+r)/2);
_update_dmax(2*k+2, (l+r)/2, r);
update(k);
}
// a_i の range minimize query
void _update_min(ll x, int a, int b, int k, int l, int r) {
if(b <= l || r <= a || max_v[k] <= x) {
return;
}
if(a <= l && r <= b && smax_v[k] < x) {
update_node_max(k, x);
_update_dmax(k, l, r);
return;
}
push(k);
_update_min(x, a, b, 2*k+1, l, (l+r)/2);
_update_min(x, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
// a_i の range maximize query
void _update_max(ll x, int a, int b, int k, int l, int r) {
if(b <= l || r <= a || x <= min_v[k]) {
return;
}
if(a <= l && r <= b && x < smin_v[k]) {
update_node_min(k, x);
_update_dmax(k, l, r);
return;
}
push(k);
_update_max(x, a, b, 2*k+1, l, (l+r)/2);
_update_max(x, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
// a_i の range add query
void _add_val(ll x, int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return;
}
if(a <= l && r <= b) {
addall(k, x);
_update_dmax(k, l, r);
return;
}
push(k);
_add_val(x, a, b, 2*k+1, l, (l+r)/2);
_add_val(x, a, b, 2*k+2, (l+r)/2, r);
update(k);
}
// b_i の range maximum query
ll _query_hmax_max(int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return -inf;
}
if(a <= l && r <= b) {
// a_i が最大値・最小値・それ以外のうちの a_i + d_i の最大値を返す
return max(max(max_d[k].hmax(), min_d[k].hmax()), nval_d[k].hmax());
}
push(k);
ll lv = _query_hmax_max(a, b, 2*k+1, l, (l+r)/2);
ll rv = _query_hmax_max(a, b, 2*k+2, (l+r)/2, r);
return max(lv, rv);
}
// b_i の range sum query
ll _query_hmax_sum(int a, int b, int k, int l, int r) {
if(b <= l || r <= a) {
return 0;
}
if(a <= l && r <= b) {
if(max_v[k] == min_v[k]) {
// 同じ要素が a_i が最大値の要素、a_i が最小値の要素両方に含まれている
// 片方の総和のみ足して返す
return sum[k] + max_d[k].min_s;
}
// a_i の総和と a_i が最小値、最大値、それ以外の d_i の総和を合わせる
return sum[k] + max_d[k].min_s + nval_d[k].min_s + min_d[k].min_s;
}
push(k);
ll lv = _query_hmax_sum(a, b, 2*k+1, l, (l+r)/2);
ll rv = _query_hmax_sum(a, b, 2*k+2, (l+r)/2, r);
return lv + rv;
}
public:
SegmentTree(int n, ll *a) {
n0 = 1;
while(n0 < n) n0 <<= 1;
for(int i=0; i<2*n0-1; ++i) ladd[i] = 0;
len[0] = n0;
for(int i=0; i<n0-1; ++i) len[2*i+1] = len[2*i+2] = (len[i] >> 1);
for(int i=0; i<n; ++i) {
max_v[n0-1+i] = min_v[n0-1+i] = sum[n0-1+i] = a[i];
smax_v[n0-1+i] = -inf; smin_v[n0-1+i] = inf;
max_c[n0-1+i] = min_c[n0-1+i] = 1;
max_d[n0-1+i].init(a[i]);
nval_d[n0-1+i].init_empty();
min_d[n0-1+i].init(a[i]);
}
for(int i=n; i<n0; ++i) {
sum[n0-1+i] = 0;
max_v[n0-1+i] = smax_v[n0-1+i] = -inf;
min_v[n0-1+i] = smin_v[n0-1+i] = inf;
max_c[n0-1+i] = min_c[n0-1+i] = 0;
max_d[n0-1+i].init_empty();
nval_d[n0-1+i].init_empty();
min_d[n0-1+i].init_empty();
}
for(int i=n0-2; i>=0; i--) update(i);
}
// i<=i<r について A_i の値を max(A_i, x) に更新
ll update_max(int a, int b, ll x) {
_update_max(x, a, b, 0, 0, n0);
}
// i<=i<r について A_i の値を min(A_i, x) に更新
ll update_min(int a, int b, ll x) {
_update_min(x, a, b, 0, 0, n0);
}
// l<=i<r について A_i の値に x を加える
void add_val(int a, int b, ll x) {
_add_val(x, a, b, 0, 0, n0);
}
// l<=i<r について A_i の値を x に更新
void update_val(int a, int b, ll x) {
_update_min(x, a, b, 0, 0, n0);
_update_max(x, a, b, 0, 0, n0);
}
// l<=i<r の中の B_i の最大値を求める
ll query_hmax_max(int a, int b) {
return _query_hmax_max(a, b, 0, 0, n0);
}
// l<=i<r の中の B_i の総和を求める
ll query_hmax_sum(int a, int b) {
return _query_hmax_sum(a, b, 0, 0, n0);
}
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment