Skip to content

Instantly share code, notes, and snippets.

@meooow25
Created February 26, 2019 15:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save meooow25/e31bcd754c07e94187fc86811a0752b8 to your computer and use it in GitHub Desktop.
Save meooow25/e31bcd754c07e94187fc86811a0752b8 to your computer and use it in GitHub Desktop.
// CLTREE
// Tester: sarthakmanna
import java.io.*;
import java.util.*;
class CP {
    /**
     * Stack size requested for the solver thread. The solver recurses once per tree level
     * (Solver.dfs, Solver.eulerTour and the centroid-decomposition walks), so a path-shaped
     * tree needs a recursion depth close to N, which may exceed the default main-thread stack.
     * The JVM treats this value as a hint; some platforms may ignore it.
     */
    static final long SOLVER_STACK_BYTES = 1L << 28;

    /**
     * Entry point: runs {@code new Solver().solve()} on a dedicated thread with a large stack,
     * waits for it, and rethrows whatever it threw so failures still surface from {@code main}.
     */
    public static void main(String[] args) throws Exception {
        final Throwable[] failure = new Throwable[1];
        Thread worker = new Thread(null, new Runnable() {
            @Override
            public void run() {
                try {
                    new Solver().solve();
                } catch (Throwable t) {
                    // Captured and rethrown on the main thread after join().
                    failure[0] = t;
                }
            }
        }, "Solver", SOLVER_STACK_BYTES);
        worker.start();
        worker.join(); // join() also makes failure[0] visible to this thread
        Throwable t = failure[0];
        if (t instanceof Exception) {
            throw (Exception) t;
        }
        if (t instanceof Error) {
            throw (Error) t;
        }
    }
}
class Solver {
    IO io = new IO(System.in, System.out);
    int N, X, Y, Z;
    ArrayList<Integer>[] graph;

    /** Preorder depth of each position (root depth 7), filled by dfs. */
    int[] ar, in, out, counts;
    int A, B, C, K;
    long ans;
    SegmentTree st;
    CentroidDecomposition cd;

    /**
     * Reads the number of test cases, then for each one reads N, X, Y, Z and the N - 1
     * (1-based) edges, and prints the count accumulated by {@link #eulerTour} over all nodes.
     * Output is flushed (and the writer closed) once after the last test case.
     */
    void solve() throws Exception {
        for (int tc = io.nextInt(); tc > 0; --tc) {
            N = io.nextInt(); X = io.nextInt(); Y = io.nextInt(); Z = io.nextInt();
            graph = newAdjacency(N);
            for (int e = 0; e < N - 1; ++e) {
                int a = io.nextInt() - 1, b = io.nextInt() - 1;
                graph[a].add(b);
                graph[b].add(a);
            }
            // After halving: A + B = X, A + C = Y, B + C = Z.
            A = X + Y - Z; B = X - Y + Z; C = -X + Y + Z;
            // A negative value violates the triangle inequality of tree distances, and an odd
            // value means X + Y + Z is odd, which cannot happen in a tree; either way the
            // answer is 0. (The old `% 2 > 0` test missed negative odd values because Java's %
            // keeps the dividend's sign.) Checked before building the heavy structures.
            if (A < 0 || B < 0 || C < 0 || ((A | B | C) & 1) != 0) {
                io.println(0);
                continue;
            }
            A >>= 1; B >>= 1; C >>= 1;
            ar = new int[N]; in = new int[N]; out = new int[N];
            K = 0;
            dfs(0, -7, 7);
            st = new SegmentTree(ar);
            cd = new CentroidDecomposition(N, graph, buildQueries());
            counts = cd.decompose(3 * N);
            ans = 0;
            eulerTour(0, -7);
            io.println(ans);
        }
        io.flush();
    }

    /** Creates an adjacency array of {@code n} empty neighbour lists. */
    @SuppressWarnings("unchecked") // generic array creation; every slot is filled below
    private static ArrayList<Integer>[] newAdjacency(int n) {
        ArrayList<Integer>[] g = new ArrayList[n];
        for (int v = 0; v < n; ++v) {
            g[v] = new ArrayList<>();
        }
        return g;
    }

    /**
     * Builds per-node distance queries {distance, answerIndex}: node v asks for the number of
     * nodes at distance A - 1, B - 1 and C - 1, stored at indices 3v, 3v + 1 and 3v + 2.
     */
    @SuppressWarnings("unchecked") // generic array creation; every slot is filled below
    private ArrayList<int[]>[] buildQueries() {
        ArrayList<int[]>[] queries = new ArrayList[N];
        for (int v = 0; v < N; ++v) {
            int base = 3 * v;
            queries[v] = new ArrayList<>();
            queries[v].add(new int[] {A - 1, base});
            queries[v].add(new int[] {B - 1, base + 1});
            queries[v].add(new int[] {C - 1, base + 2});
        }
        return queries;
    }

    /**
     * Preorder walk: records entry time in[node], exit time out[node] (last preorder index
     * in the subtree) and stores the node's depth at ar[in[node]]. The starting depth is an
     * arbitrary offset; countUnder only uses depths relative to ar[in[node]].
     */
    void dfs(int node, int par, int dep) {
        in[node] = K;
        ar[K++] = dep++;
        for (int itr : graph[node]) {
            if (itr != par) {
                dfs(itr, node, dep);
            }
        }
        out[node] = K - 1;
    }

    /**
     * Treats every node as the meeting point of the three paths and adds to {@link #ans} the
     * number of ways to place nodes at distances A, B and C in pairwise distinct incident
     * branches (or on the node itself when the required distance is 0). dp[mask] counts the
     * ways to place the subset {@code mask} of {A, B, C} among the branches seen so far, with
     * at most one role per branch.
     */
    void eulerTour(int node, int prev) {
        long[] dp = new long[8];
        dp[0] = 1;
        // NOTE(review): when two of A, B, C are 0 at once, dp[3]/dp[5]/dp[6] are not seeded, so
        // the node cannot fill two zero-distance roles simultaneously. Confirm the constraints
        // rule out X, Y or Z == 0.
        if (A == 0) dp[1 << 0] = 1;
        if (B == 0) dp[1 << 1] = 1;
        if (C == 0) dp[1 << 2] = 1;
        for (int itr : graph[node]) {
            int[] count;
            if (itr == prev) {
                // Parent side: nodes at distance d - 1 from the parent, minus those of them that
                // lie back inside this node's subtree (distance d - 2 below this node).
                count = new int[] {allAround(itr, A - 1) - countUnder(node, A - 2),
                        allAround(itr, B - 1) - countUnder(node, B - 2),
                        allAround(itr, C - 1) - countUnder(node, C - 2)};
            } else {
                // Child side: nodes at distance d - 1 below the child.
                count = new int[] {countUnder(itr, A - 1), countUnder(itr, B - 1),
                        countUnder(itr, C - 1)};
            }
            long[] newDP = Arrays.copyOf(dp, 8);
            for (int j = 0; j < 3; ++j) {
                for (int mask = 0; mask < 8; ++mask) {
                    if (((mask >> j) & 1) == 0) {
                        newDP[mask | (1 << j)] += count[j] * dp[mask];
                    }
                }
            }
            dp = newDP;
        }
        ans += dp[7];
        for (int itr : graph[node]) {
            if (itr != prev) {
                eulerTour(itr, node);
            }
        }
    }

    /** Number of nodes in {@code node}'s subtree whose depth is exactly depth(node) + dist. */
    int countUnder(int node, int dist) {
        return st.rangeQuery(0, 0, st.N - 1, in[node], out[node], ar[in[node]] + dist);
    }

    /**
     * Number of nodes anywhere in the tree at distance {@code d + 1}... more precisely, the
     * precomputed centroid-decomposition count of nodes at distance {@code d} from {@code node};
     * {@code d} must be one of A - 1, B - 1, C - 1.
     *
     * @throws IllegalArgumentException if {@code d} was not one of the precomputed distances
     */
    int allAround(int node, int d) {
        int base = node * 3;
        if (d == A - 1) return counts[base];
        if (d == B - 1) return counts[base + 1];
        if (d == C - 1) return counts[base + 2];
        throw new IllegalArgumentException("no precomputed count for distance " + d);
    }
}
class SegmentTree {
    /**
     * Merge-sort tree in implicit heap layout: node i has children 2i + 1 and 2i + 2, leaves
     * start at index N - 1, and every node stores the values of its range in sorted order.
     */
    myArrayList[] tree;
    /** Number of leaves: ar.length rounded up to a power of two. */
    int N;

    /** Builds the tree over {@code ar}; leaves beyond ar.length stay empty. */
    SegmentTree(int[] ar) {
        N = 1;
        while (N < ar.length) {
            N <<= 1;
        }
        tree = new myArrayList[(N << 1) - 1];
        for (int i = 0; i < N; ++i) {
            tree[N - 1 + i] = new myArrayList(1);
        }
        for (int i = 0; i < ar.length; ++i) {
            tree[N - 1 + i].add(ar[i]);
        }
        for (int i = N - 2; i >= 0; --i) {
            tree[i] = merge(tree[(i << 1) + 1], tree[(i << 1) + 2]);
        }
    }

    /**
     * Two-pointer merge of two sorted lists into a new sorted list. The result is already
     * sorted, so no extra sort pass is needed (one used to run here at O(n log n) per node).
     */
    myArrayList merge(myArrayList m1, myArrayList m2) {
        myArrayList ret = new myArrayList(m1.size + m2.size);
        int p1 = 0, p2 = 0;
        while (p1 < m1.size || p2 < m2.size) {
            if (p1 == m1.size) ret.add(m2.get(p2++));
            else if (p2 == m2.size) ret.add(m1.get(p1++));
            else if (m1.get(p1) < m2.get(p2)) ret.add(m1.get(p1++));
            else ret.add(m2.get(p2++));
        }
        return ret;
    }

    /**
     * Counts positions in [ql, qr] whose value equals {@code key}.
     *
     * @param i node index; call with 0 for the root
     * @param l first position covered by node i (0 for the root)
     * @param r last position covered by node i (N - 1 for the root)
     */
    int rangeQuery(int i, int l, int r, int ql, int qr, int key) {
        if (l > qr || r < ql) {
            return 0;
        }
        if (l >= ql && r <= qr) {
            return tree[i].count(key);
        }
        int mid = (l + r) >> 1;
        return rangeQuery((i << 1) + 1, l, mid, ql, qr, key)
                + rangeQuery((i << 1) + 2, mid + 1, r, ql, qr, key);
    }
}
// Offline distance-count queries via centroid decomposition. Each query {dd, idx} attached to
// node v (queries[v]) ends up with ans[idx] = number of nodes at distance exactly dd from v
// (v itself counts at distance 0, when v is processed as a centroid).
class CentroidDecomposition {
int N;
ArrayList<Integer>[] graph;
ArrayList<int[]>[] queries;
CentroidDecomposition(int n, ArrayList<Integer>[] g, ArrayList<int[]>[] q) {
N = n; graph = g; queries = q;
// Both maps are keyed by values < N (node ids for size, distances for count).
size = new myHashMap(N + 7);
count = new myHashMap(N + 7);
}
// https://codeforces.com/blog/entry/58025
// size: subtree sizes for the current oneCentroid call (overwritten, never cleared).
// count: distance -> number of nodes at that distance from the current centroid.
myHashMap size, count;
// Fills size[] for the component containing node, rooted at node, without crossing dead
// nodes. The `root` parameter is passed along but not otherwise used.
void setSize(int node, int par, int root, boolean[] dead) {
size.put(node, 1);
for (int itr : graph[node]) if (itr != par && !dead[itr]) {
setSize(itr, node, root, dead);
size.put(node, size.getOrDefault(node, 0) + size.getOrDefault(itr, 0));
}
}
// Walks towards any live child whose subtree holds more than n/2 nodes; the node where no
// such child exists is returned as the centroid of a component of n nodes.
int dfs(int node, int par, int n, boolean[] dead) {
for (int itr : graph[node]) if (itr != par && !dead[itr]) {
if (size.getOrDefault(itr, 0) > n >> 1)
return dfs(itr, node, n, dead);
}
return node;
}
// Returns a centroid of the live component containing root.
int oneCentroid(int root, boolean[] dead) {
setSize(root, -7, root, dead);
return dfs(root, -7, size.getOrDefault(root, 0), dead);
}
// Decomposes the live component containing start. Sub-components are handled FIRST
// (post-order); every call clears its own centroid's dead flag before returning, so when the
// recursive loop finishes, the whole component except c is live again and the counting pass
// below sees exactly this component. Do not reorder these steps.
void rec(int start, boolean[] dead) {
int c = oneCentroid(start, dead);
dead[c] = true;
for (int itr : graph[c]) if (!dead[itr])
rec(itr, dead);
count.clear();
// Distances from c to every node of the component (c itself at distance 0).
addCount(c, -7, 0, true, count, dead);
for (int[] itr : queries[c]) {
int dd = itr[0], idx = itr[1];
ans[idx] += count.getOrDefault(dd, 0);
}
// For each branch of c: remove its own nodes from count, answer its queries with paths
// that pass through c, then add its nodes back.
for (Integer itr : graph[c]) if (!dead[itr]) {
addCount(itr, c, 1, false, count, dead);
calc(itr, c, 1, count, dead);
addCount(itr, c, 1, true, count, dead);
}
dead[c] = false;
}
// Adds (add == true) or removes one occurrence of each node reachable from node (not going
// back through prev or through dead nodes), keyed by its distance d from the centroid.
void addCount(int node, int prev, int d, boolean add, myHashMap count, boolean[] dead) {
count.put(d, count.getOrDefault(d, 0) + (add ? 1 : -1));
for (int itr : graph[node]) if (itr != prev && !dead[itr]) {
addCount(itr, node, d + 1, add, count, dead);
}
}
// For every node of this branch at distance d from the centroid, a query for distance dd is
// satisfied by the count of nodes at distance dd - d from the centroid in other branches.
void calc(int node, int prev, int d, myHashMap count, boolean[] dead) {
for (int[] itr : queries[node]) {
int dd = itr[0], idx = itr[1];
if (dd >= d && count.containsKey(dd - d))
ans[idx] += count.get(dd - d);
}
for (int itr : graph[node]) if (itr != prev && !dead[itr]) {
calc(itr, node, d + 1, count, dead);
}
}
// Per-query answers, indexed by the idx stored in each query.
int[] ans;
// Runs the decomposition from node 0 and returns an array of q answers.
int[] decompose(int q) {
ans = new int[q];
boolean[] dead = new boolean[N];
rec(0, dead);
return ans;
}
}
/**
 * Fixed-capacity int -> int map over keys [0, capacity). Integer.MIN_VALUE marks an absent
 * key, so it cannot be stored as a value. Inserted keys are remembered on a stack so that
 * clear() costs O(number of keys used) rather than O(capacity).
 */
class myHashMap {
    int[] freq;
    myArrayList stack;

    myHashMap(int N) {
        freq = new int[N];
        Arrays.fill(freq, Integer.MIN_VALUE);
        stack = new myArrayList(N);
    }

    /** Stores val under key; key must lie in [0, capacity). */
    void put(int key, int val) {
        if (freq[key] == Integer.MIN_VALUE) {
            stack.add(key);
        }
        freq[key] = val;
    }

    /** Raw lookup (Integer.MIN_VALUE if absent); key must lie in [0, capacity). */
    int get(int key) {
        return freq[key];
    }

    /** Value stored under key, or defaultVal if key is absent or outside [0, capacity). */
    int getOrDefault(int key, int defaultVal) {
        if (!inRange(key)) {
            return defaultVal;
        }
        int ret = freq[key];
        return ret == Integer.MIN_VALUE ? defaultVal : ret;
    }

    /** Removes every key inserted since the last clear. */
    void clear() {
        while (!stack.isEmpty()) {
            freq[stack.pop()] = Integer.MIN_VALUE;
        }
    }

    /** True if key is in range and has a stored value; out-of-range keys are simply absent. */
    boolean containsKey(int key) {
        return inRange(key) && freq[key] != Integer.MIN_VALUE;
    }

    private boolean inRange(int key) {
        return key >= 0 && key < freq.length;
    }
}
/** Minimal int list backed by a fixed array of maxSize slots; it never grows. */
class myArrayList {
    int[] arr;
    int size;

    myArrayList(int maxSize) {
        arr = new int[maxSize];
        size = 0;
    }

    void add(int ele) {
        arr[size++] = ele;
    }

    int get(int ind) {
        return arr[ind];
    }

    boolean isEmpty() {
        return size == 0;
    }

    int pop() {
        return arr[--size];
    }

    void sort() {
        Arrays.sort(arr, 0, size);
    }

    /**
     * Number of elements equal to {@code val}. Requires the list to be sorted ascending.
     * Returns 0 for an empty list (previously this read arr[-1]).
     */
    int count(int val) {
        int upper = bs(0, size - 1, val);
        // Elements < val: avoid val - 1 overflowing when val == Integer.MIN_VALUE.
        int below = val == Integer.MIN_VALUE ? 0 : bs(0, size - 1, val - 1);
        return upper - below;
    }

    /**
     * Upper bound in sorted arr[l..r]: the first index in [l, r + 1] whose element is
     * greater than {@code val}. An empty range (r == l - 1) returns l.
     */
    int bs(int l, int r, int val) {
        int lo = l, hi = r + 1;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (arr[mid] <= val) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }
}
/**
 * Fast byte-level token reader and buffered writer. NOTE: the buffer, cursor and streams are
 * static, so all IO instances share one state; each constructor rebinds it and resets the
 * input buffer.
 */
class IO {
    static byte[] buf = new byte[2048];
    static int index, total;
    static InputStream in;
    static BufferedWriter bw;

    /** Reads from {@code is} and writes (buffered, platform charset) to {@code os}. */
    IO(InputStream is, OutputStream os) {
        in = is;
        bw = new BufferedWriter(new OutputStreamWriter(os));
        index = 0;
        total = 0;
    }

    /**
     * Reads from {@code inputFile} and writes to {@code outputFile}.
     *
     * @throws IllegalArgumentException if either file cannot be opened (cause preserved)
     */
    IO(String inputFile, String outputFile) {
        InputStream fileIn = null;
        try {
            fileIn = new FileInputStream(inputFile);
            bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputFile)));
        } catch (FileNotFoundException e) {
            if (fileIn != null) {
                try {
                    fileIn.close();
                } catch (IOException ignored) {
                    // Already failing; the original cause below is the useful one.
                }
            }
            throw new IllegalArgumentException(
                    "cannot open " + inputFile + " / " + outputFile, e);
        }
        in = fileIn;
        index = 0;
        total = 0;
    }

    /** Next input byte as an unsigned value 0..255, or -1 at end of input. */
    int scan() throws Exception {
        if (index >= total) {
            index = 0;
            total = in.read(buf);
            if (total <= 0) {
                return -1;
            }
        }
        // Mask so that byte 0xFF is not mistaken for the end-of-input marker.
        return buf[index++] & 0xFF;
    }

    /** Skips bytes <= ' ' and returns the first non-blank byte. */
    private int skipBlanks() throws Exception {
        int c = scan();
        while (c <= ' ') {
            if (c < 0) {
                // Previously this loop spun forever at end of input.
                throw new EOFException("unexpected end of input");
            }
            c = scan();
        }
        return c;
    }

    /** Next whitespace-delimited token; throws EOFException if input is exhausted. */
    String next() throws Exception {
        StringBuilder sb = new StringBuilder();
        for (int c = skipBlanks(); c > ' '; c = scan()) {
            sb.append((char) c);
        }
        return sb.toString();
    }

    /** Next optionally signed decimal int (wraps on overflow); throws EOFException at end. */
    int nextInt() throws Exception {
        int c = skipBlanks();
        boolean neg = c == '-';
        if (c == '-' || c == '+') {
            c = scan();
        }
        int val = 0;
        for (; c >= '0' && c <= '9'; c = scan()) {
            val = (val << 3) + (val << 1) + (c & 15);
        }
        return neg ? -val : val;
    }

    /** Next optionally signed decimal long (wraps on overflow); throws EOFException at end. */
    long nextLong() throws Exception {
        int c = skipBlanks();
        boolean neg = c == '-';
        if (c == '-' || c == '+') {
            c = scan();
        }
        long val = 0;
        for (; c >= '0' && c <= '9'; c = scan()) {
            val = (val << 3) + (val << 1) + (c & 15);
        }
        return neg ? -val : val;
    }

    void print(Object a) throws Exception {
        bw.write(a.toString());
    }

    void printsp(Object a) throws Exception {
        print(a);
        bw.write(" ");
    }

    void println() throws Exception {
        bw.write("\n");
    }

    void println(Object a) throws Exception {
        print(a);
        println();
    }

    /** Flushes AND closes the writer (and its underlying stream); call once, at the end. */
    void flush() throws Exception {
        bw.flush();
        bw.close();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment