Skip to content

Instantly share code, notes, and snippets.

@Hegdahl
Last active May 10, 2022 21:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Hegdahl/f582ad0446db5c10b623073322e037c9 to your computer and use it in GitHub Desktop.
Save Hegdahl/f582ad0446db5c10b623073322e037c9 to your computer and use it in GitHub Desktop.
#include <bits/stdc++.h>
using namespace std;
/// Segment-tree skeleton over a perfect binary tree in heap layout: the root is
/// values[1], node v has children 2v and 2v+1, and position p's leaf is
/// values[offset + p]. The structure performs no merging itself; callers
/// supply a callback that is applied to the nodes visited by rootpath or
/// rangecover.
template <class T>
struct CommutativeSegmentTree {
  // Number of leaves (a power of two, at least 2).
  const int offset;
  // One T per tree node; size 2 * offset.
  vector<T> values;

  // For n >= 1: the smallest power of two that is >= n, but never less than 2.
  static constexpr int round_up_to_power_of_two(int n) {
    const int top_bit = __lg((n - 1) | 1);
    return 2 << top_bit;
  }

  CommutativeSegmentTree(int n)
      : offset(round_up_to_power_of_two(n)), values(2 * offset) {}

  /// Calls f on the leaf of position p, then on each ancestor up to and
  /// including the root (index 1).
  template <class F>
  void rootpath(int p, F &&f) {
    int node = p + offset;
    while (true) {
      f(values[node]);
      node /= 2;
      if (node == 0) break;
    }
  }

  /// Calls f once on each node of the canonical decomposition of the
  /// inclusive position range [l, r].
  template <class F>
  void rangecover(int l, int r, F &&f) {
    // Exclusive bounds (lo, hi) start one leaf outside each end of [l, r].
    int lo = l + offset - 1;
    int hi = r + offset + 1;
    for (; lo + 1 < hi; lo /= 2, hi /= 2) {
      // lo is a left child: its right sibling lies fully inside the range.
      if (!(lo % 2)) f(values[lo + 1]);
      // hi is a right child: its left sibling lies fully inside the range.
      if (hi % 2 == 1) f(values[hi - 1]);
    }
  }
};
// I would usually use a more generic template here,
// but since it is not the focus, I won't put a big
// merging segment tree template in the example code.
/// Point-assign / range-minimum segment tree over pairs<int, int>, in heap
/// layout (root at values[1], leaf p at values[offset + p]). Empty slots hold
/// the sentinel (1e9, 1e9), which is also the result for a range with no
/// smaller value.
struct MinSegmentTree {
  // Number of leaves (a power of two, at least 2).
  const int offset;
  // Per-node minimum of its subtree; size 2 * offset.
  vector<pair<int, int>> values;

  // For n >= 1: the smallest power of two that is >= n, but never less than 2.
  static constexpr int round_up_to_power_of_two(int n) {
    return 2 << __lg((n - 1) | 1);
  }

  MinSegmentTree(int n)
      : offset(round_up_to_power_of_two(n)),
        values(2 * offset, pair<int, int>((int)1e9, (int)1e9)) {}

  /// Sets position p to v and recomputes the minimum of every ancestor.
  void point_set(int p, pair<int, int> v) {
    int node = p + offset;
    values[node] = v;
    for (node /= 2; node != 0; node /= 2) {
      const pair<int, int> &left = values[2 * node];
      const pair<int, int> &right = values[2 * node + 1];
      // Same tie-breaking as min(left, right): prefer left unless right < left.
      values[node] = right < left ? right : left;
    }
  }

  /// Minimum pair over the inclusive position range [l, r].
  pair<int, int> range_query(int l, int r) {
    pair<int, int> best((int)1e9, (int)1e9);
    for (int lo = l + offset - 1, hi = r + offset + 1; lo + 1 < hi;
         lo /= 2, hi /= 2) {
      if (lo % 2 == 0) best = min(best, values[lo + 1]);
      if (hi % 2 == 1) best = min(best, values[hi - 1]);
    }
    return best;
  }
};
/// Fenwick (binary indexed) tree over a sparse key set. Usage: call prepare()
/// for every key that will ever be updated, then init() once, then update()
/// and query().
template <class K, class V>
struct compressed_fenwick_tree {
  // Sorted, deduplicated keys after init(); keys[i] owns Fenwick slot i.
  vector<K> keys;
  // 0-indexed Fenwick array, resized to keys.size() by init().
  vector<V> s;

  // Adds dif to slot i and to every Fenwick node that covers it.
  void _update(int i, V dif) {
    const int slots = (int)s.size();
    while (i < slots) {
      s[i] += dif;
      i |= i + 1;
    }
  }

  // Sum of slots [0, i).
  V _query(int i) {
    V res = 0;
    while (i > 0) {
      res += s[i - 1];
      i &= i - 1;
    }
    return res;
  }

  // Slot of the first key that is not less than key (keys.size() if none).
  int _get_index(K key) {
    auto first_not_less = lower_bound(keys.begin(), keys.end(), key);
    return int(first_not_less - keys.begin());
  }

  void prepare(K key) { keys.push_back(key); }

  void init() {
    sort(keys.begin(), keys.end());
    auto new_end = unique(keys.begin(), keys.end());
    keys.erase(new_end, keys.end());
    s.resize(keys.size());
  }

  // a[key] += dif (key should have been passed to prepare()).
  void update(K key, V dif) {
    const int slot = _get_index(key);
    _update(slot, dif);
  }

  // sum(a[i] for i in (-inf, key))
  V query(K key) {
    const int slot = _get_index(key);
    return _query(slot);
  }
};
constexpr int mxN = 2e5;
// Adjacency lists of the (undirected) tree.
vector<int> g[mxN];
// Per node: first/last Euler-tour index of its subtree, depth; who[t] is the
// node whose first Euler index is t.
int node_first_x[mxN], node_last_x[mxN], node_y[mxN], who[mxN];

// Preorder traversal from cur (whose parent is prv) that assigns consecutive
// Euler indices starting at t. For every visited node v it sets
// node_first_x[v], node_last_x[v] (last index inside v's subtree),
// node_y[v] = node_y[parent] + 1 for non-root nodes, and who[index] = v.
// On return, t is one past the last assigned index.
//
// Implemented iteratively with an explicit stack: a recursive version needs
// one call frame per tree level, and a path-shaped tree of mxN nodes can
// overflow a default-sized thread stack.
void dfs(int cur, int prv, int &t) {
  struct Frame {
    int node;
    int parent;
    size_t next_child;  // index into g[node] of the next neighbour to visit
  };
  vector<Frame> stack;
  // Pre-order work for a node, then schedule its children.
  auto enter = [&](int node, int parent) {
    who[t] = node;
    node_first_x[node] = t++;
    stack.push_back({node, parent, 0});
  };
  enter(cur, prv);
  while (!stack.empty()) {
    Frame &top = stack.back();
    if (top.next_child == g[top.node].size()) {
      // All children done: post-order work.
      node_last_x[top.node] = t - 1;
      stack.pop_back();
      continue;
    }
    const int node = top.node;
    const int nxt = g[node][top.next_child++];
    if (nxt == top.parent) continue;
    node_y[nxt] = node_y[node] + 1;
    enter(nxt, node);  // may reallocate `stack`; `top` is not used afterwards
  }
}
// Reads a tree of n nodes (edges given 1-indexed), a 0/1 value for each node,
// then q queries "t i k" (i is 1-indexed):
//   t == 1: every 0-valued node in the subtree of i whose depth is at most
//           depth(i) + k becomes 1.
//   t == 2: prints how many 1-valued nodes in the subtree of i have depth at
//           most depth(i) + k.
// Each node is treated as a point (x, y) = (Euler-tour index, depth), so the
// subtree of i is the x-range [node_first_x[i], node_last_x[i]].
int main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  int n;
  cin >> n;
  for (int nn = 0; nn < n - 1; ++nn) {
    int i, j;
    cin >> i >> j;
    --i, --j;
    g[i].push_back(j);
    g[j].push_back(i);
  }
  {
    int t = 0;
    dfs(0, -1, t);
  }
  // Depth limit depth(i) + k, clamped to n - 1 (no node can be deeper).
  // The sum is done in long long so a large k cannot overflow int. The clamp
  // also keeps the limit below MinSegmentTree's 1e9 sentinel: otherwise a k
  // near 1e9 makes the sentinel pass the "<= limit" test in a type-1 query,
  // and point_set / who[] get indexed with x = 1e9.
  auto depth_limit = [&](int i, int k) {
    return (int)min<long long>((long long)node_y[i] + k, n - 1);
  };
  // 0-valued nodes: leaf x holds (depth, x); range_query yields the
  // shallowest remaining 0-node in an x-range.
  MinSegmentTree st_0(n);
  // 1-valued nodes: each segment-tree node holds a Fenwick tree over the
  // depths of the points in its x-range.
  CommutativeSegmentTree<compressed_fenwick_tree<int, int>> st_1(n);
  // Registers node i's depth as a key in every Fenwick tree on its root path;
  // done for all nodes before init() so later updates always find their key.
  auto prepare = [&](int i) {
    int x = node_first_x[i];
    int y = node_y[i];
    st_1.rootpath(x, [&](auto &s) {
      s.prepare(y);
    });
  };
  // Records node i as 1-valued in st_1.
  auto insert1 = [&](int i) {
    int x = node_first_x[i];
    int y = node_y[i];
    st_1.rootpath(x, [&](auto &s) {
      s.update(y, 1);
    });
  };
  // Number of 1-valued nodes in the subtree of i with depth <= depth(i) + k.
  auto count1s = [&](int i, int k) {
    int x0 = node_first_x[i];
    int x1 = node_last_x[i];
    int y1 = depth_limit(i, k);
    int result = 0;
    st_1.rangecover(x0, x1, [&](auto &s) {
      result += s.query(y1 + 1);  // keys < y1 + 1, i.e. depth <= y1
    });
    return result;
  };
  for (int i = 0; i < n; ++i) prepare(i);
  for (auto &s : st_1.values) s.init();
  for (int i = 0; i < n; ++i) {
    int value;
    cin >> value;
    if (value)
      insert1(i);
    else
      st_0.point_set(node_first_x[i], {node_y[i], node_first_x[i]});
  }
  int q;
  cin >> q;
  while (q--) {
    int t, i, k;
    cin >> t >> i >> k;
    --i;
    if (t == 1) {
      int x0 = node_first_x[i];
      int x1 = node_last_x[i];
      int y1 = depth_limit(i, k);
      // Repeatedly take the shallowest 0-node in the subtree; stop once it is
      // too deep (or only sentinels remain). Each node is flipped at most once.
      for (;;) {
        const auto [y, x] = st_0.range_query(x0, x1);
        if (y > y1) break;
        st_0.point_set(x, {(int)1e9, (int)1e9});
        insert1(who[x]);
      }
    } else if (t == 2) {
      cout << count1s(i, k) << '\n';
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment