-
-
Save meooow25/e31bcd754c07e94187fc86811a0752b8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// CLTREE | |
// Tester: sarthakmanna | |
import java.io.*; | |
import java.util.*; | |
class CP { | |
public static void main(String[] args) throws Exception { | |
/*new Thread(null, new Runnable() { | |
@Override | |
public void run() { | |
try { | |
new Solver().solve(); | |
} catch (Exception e) { | |
e.printStackTrace(); | |
System.exit(1); | |
} | |
} | |
}, "Solver", 1l << 30).start();*/ | |
new Solver().solve(); | |
} | |
} | |
class Solver { | |
IO io = new IO(System.in, System.out); | |
int N, X, Y, Z; | |
ArrayList<Integer>[] graph; | |
void solve() throws Exception { | |
int i, j, k; | |
for (int tc = io.nextInt(); tc > 0; --tc) { | |
N = io.nextInt(); X= io.nextInt(); Y = io.nextInt(); Z = io.nextInt(); | |
graph = new ArrayList[N]; | |
for (i = 0; i < N; ++i) graph[i] = new ArrayList<>(); | |
for (i = 0; i < N - 1; ++i) { | |
int a = io.nextInt() - 1, b = io.nextInt() - 1; | |
graph[a].add(b); graph[b].add(a); | |
} | |
ar = new int[N]; in = new int[N]; out = new int[N]; | |
K = 0; dfs(0, -7, 7); | |
st = new SegmentTree(ar); | |
A = X + Y - Z; B = X - Y + Z; C = -X + Y + Z; | |
if (A % 2 > 0 || B % 2 > 0 || C % 2 > 0) { | |
io.println(0); | |
continue; | |
} | |
A >>= 1; B >>= 1; C >>= 1; | |
ArrayList<int[]>[] queries = new ArrayList[N]; | |
for (i = 0; i < N; ++i) { | |
j = 3 * i; | |
queries[i] = new ArrayList<>(); | |
queries[i].add(new int[] {A - 1, j + 0}); | |
queries[i].add(new int[] {B - 1, j + 1}); | |
queries[i].add(new int[] {C - 1, j + 2}); | |
} | |
cd = new CentroidDecomposition(N, graph, queries); | |
counts = cd.decompose(3 * N); | |
//System.out.println(Arrays.toString(counts)); | |
ans = 0; | |
eulerTour(0, -7); | |
io.println(ans); | |
} | |
io.flush(); | |
} | |
int[] ar, in, out, counts; | |
int A, B, C, K; | |
long ans; | |
SegmentTree st; | |
CentroidDecomposition cd; | |
void dfs(int node, int par, int dep) { | |
in[node] = K; ar[K++] = dep++; | |
for (int itr : graph[node]) if (itr != par) dfs(itr, node, dep); | |
out[node] = K - 1; | |
} | |
void eulerTour(int node, int prev) { | |
int i, j, k; | |
long[] dp = new long[8]; | |
dp[0] = 1; | |
if (A == 0) dp[1 << 0] = 1; | |
if (B == 0) dp[1 << 1] = 1; | |
if (C == 0) dp[1 << 2] = 1; | |
for (int itr : graph[node]) { | |
long[] newDP = new long[8]; | |
System.arraycopy(dp, 0, newDP, 0, 8); | |
int[] count; | |
if (itr == prev) { | |
count = new int[] {allAround(itr, A - 1) - countUnder(node, A - 2), | |
allAround(itr, B - 1) - countUnder(node, B - 2), | |
allAround(itr, C - 1) - countUnder(node, C - 2)}; | |
} else { | |
count = new int[] {countUnder(itr, A - 1), | |
countUnder(itr, B - 1), countUnder(itr, C - 1)}; | |
} | |
for (j = 0; j < 3; ++j) for (k = 0; k < 8; ++k) if (((k >> j) & 1) == 0) | |
newDP[(1 << j) | k] += count[j] * dp[k]; | |
dp = newDP; | |
} | |
ans += dp[7]; | |
for (int itr : graph[node]) if (itr != prev) eulerTour(itr, node); | |
} | |
int countUnder(int node, int dist) { | |
return st.rangeQuery(0, 0, st.N - 1, in[node], out[node], ar[in[node]] + dist); | |
} | |
int allAround(int node, int d) { | |
node *= 3; | |
if (d == A - 1) return counts[node + 0]; | |
else if (d == B - 1) return counts[node + 1]; | |
else if (d == C - 1) return counts[node + 2]; | |
else return 1/0; | |
} | |
} | |
class SegmentTree { | |
myArrayList[] tree; | |
int N; | |
SegmentTree(int[] ar) { | |
N = 1; while (N < ar.length) N <<= 1; | |
tree = new myArrayList[(N << 1) - 1]; | |
for (int i = 0; i < N; ++i) tree[N - 1 + i] = new myArrayList(1); | |
for (int i = 0; i < ar.length; ++i) tree[N - 1 + i].add(ar[i]); | |
for (int i = N - 2; i >= 0; --i) tree[i] = merge(tree[(i << 1) + 1], tree[(i << 1) + 2]); | |
} | |
myArrayList merge(myArrayList m1, myArrayList m2) { | |
myArrayList ret = new myArrayList(m1.size + m2.size); | |
int p1 = 0, p2 = 0; | |
while (p1 < m1.size || p2 < m2.size) { | |
if (p1 == m1.size) ret.add(m2.get(p2++)); | |
else if (p2 == m2.size) ret.add(m1.get(p1++)); | |
else if (m1.get(p1) < m2.get(p2)) ret.add(m1.get(p1++)); | |
else ret.add(m2.get(p2++)); | |
} | |
ret.sort(); | |
return ret; | |
} | |
int rangeQuery(int i, int l, int r, int ql, int qr, int key) { | |
int mid = (l + r) >> 1, i2 = i << 1; | |
if (l > qr || r < ql) return 0; | |
else if (l >= ql && r <= qr) return tree[i].count(key); | |
else return rangeQuery(i2 + 1, l, mid, ql, qr, key) | |
+ rangeQuery(i2 + 2, mid + 1, r, ql, qr, key); | |
} | |
} | |
class CentroidDecomposition { | |
int N; | |
ArrayList<Integer>[] graph; | |
ArrayList<int[]>[] queries; | |
CentroidDecomposition(int n, ArrayList<Integer>[] g, ArrayList<int[]>[] q) { | |
N = n; graph = g; queries = q; | |
size = new myHashMap(N + 7); | |
count = new myHashMap(N + 7); | |
} | |
// https://codeforces.com/blog/entry/58025 | |
myHashMap size, count; | |
void setSize(int node, int par, int root, boolean[] dead) { | |
size.put(node, 1); | |
for (int itr : graph[node]) if (itr != par && !dead[itr]) { | |
setSize(itr, node, root, dead); | |
size.put(node, size.getOrDefault(node, 0) + size.getOrDefault(itr, 0)); | |
} | |
} | |
int dfs(int node, int par, int n, boolean[] dead) { | |
for (int itr : graph[node]) if (itr != par && !dead[itr]) { | |
if (size.getOrDefault(itr, 0) > n >> 1) | |
return dfs(itr, node, n, dead); | |
} | |
return node; | |
} | |
int oneCentroid(int root, boolean[] dead) { | |
setSize(root, -7, root, dead); | |
return dfs(root, -7, size.getOrDefault(root, 0), dead); | |
} | |
void rec(int start, boolean[] dead) { | |
int c = oneCentroid(start, dead); | |
dead[c] = true; | |
for (int itr : graph[c]) if (!dead[itr]) | |
rec(itr, dead); | |
count.clear(); | |
addCount(c, -7, 0, true, count, dead); | |
for (int[] itr : queries[c]) { | |
int dd = itr[0], idx = itr[1]; | |
ans[idx] += count.getOrDefault(dd, 0); | |
} | |
for (Integer itr : graph[c]) if (!dead[itr]) { | |
addCount(itr, c, 1, false, count, dead); | |
calc(itr, c, 1, count, dead); | |
addCount(itr, c, 1, true, count, dead); | |
} | |
dead[c] = false; | |
} | |
void addCount(int node, int prev, int d, boolean add, myHashMap count, boolean[] dead) { | |
count.put(d, count.getOrDefault(d, 0) + (add ? 1 : -1)); | |
for (int itr : graph[node]) if (itr != prev && !dead[itr]) { | |
addCount(itr, node, d + 1, add, count, dead); | |
} | |
} | |
void calc(int node, int prev, int d, myHashMap count, boolean[] dead) { | |
for (int[] itr : queries[node]) { | |
int dd = itr[0], idx = itr[1]; | |
if (dd >= d && count.containsKey(dd - d)) | |
ans[idx] += count.get(dd - d); | |
} | |
for (int itr : graph[node]) if (itr != prev && !dead[itr]) { | |
calc(itr, node, d + 1, count, dead); | |
} | |
} | |
int[] ans; | |
int[] decompose(int q) { | |
ans = new int[q]; | |
boolean[] dead = new boolean[N]; | |
rec(0, dead); | |
return ans; | |
} | |
} | |
class myHashMap { | |
int[] freq; | |
myArrayList stack; | |
myHashMap(int N) { | |
freq = new int[N]; | |
Arrays.fill(freq, Integer.MIN_VALUE); | |
stack = new myArrayList(N); | |
} | |
void put(int key, int val) { | |
if (freq[key] == Integer.MIN_VALUE) stack.add(key); | |
freq[key] = val; | |
} | |
int get(int key) { return freq[key]; } | |
int getOrDefault(int key, int defaultVal) { | |
if (key < 0) return defaultVal; | |
int ret = get(key); | |
return ret == Integer.MIN_VALUE ? defaultVal : ret; | |
} | |
void clear() { | |
while (!stack.isEmpty()) freq[stack.pop()] = Integer.MIN_VALUE; | |
} | |
boolean containsKey(int key) { | |
return get(key) != Integer.MIN_VALUE; | |
} | |
} | |
class myArrayList { | |
int[] arr; | |
int size; | |
myArrayList(int maxSize) { arr = new int[maxSize]; size = 0; } | |
void add(int ele) { arr[size++] = ele; } | |
int get(int ind) { return arr[ind]; } | |
boolean isEmpty() { return size == 0; } | |
int pop() { return arr[--size]; } | |
void sort() { Arrays.sort(arr, 0, size); } | |
int count(int val) { return bs(0, size - 1, val) - bs(0, size - 1, val - 1); } | |
int bs(int l, int r, int val) { | |
int mid; | |
while (true) { | |
mid = (l + r) >> 1; | |
if (l + 1 >= r) { | |
if (arr[l] > val) return l; | |
else if (arr[r] <= val) return r + 1; | |
else return r; | |
} else if (arr[mid] <= val) l = mid + 1; | |
else r = mid; | |
} | |
/*int c = 0; | |
for ( ; l <= r; ++l) if (arr[l] <= val) ++c; | |
return c;*/ | |
} | |
} | |
/**
 * Byte-buffered token reader over an InputStream plus a BufferedWriter for output.
 *
 * <p>The buffer and both streams are static fields, so all instances share them and constructing
 * a new instance replaces them.
 */
class IO {
    static byte[] buf = new byte[2048];
    static int index, total;
    static InputStream in;
    static BufferedWriter bw;

    /** Reads tokens from is and writes output to os. */
    IO(InputStream is, OutputStream os) {
        in = is;
        bw = new BufferedWriter(new OutputStreamWriter(os));
        resetBuffer();
    }

    /**
     * Reads tokens from the file inputFile and writes output to the file outputFile.
     *
     * @throws IllegalArgumentException if either file cannot be opened
     */
    IO(String inputFile, String outputFile) {
        FileInputStream input;
        try {
            input = new FileInputStream(inputFile);
        } catch (FileNotFoundException e) {
            throw new IllegalArgumentException("cannot open input file " + inputFile, e);
        }
        try {
            bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputFile)));
        } catch (FileNotFoundException e) {
            try {
                input.close();
            } catch (IOException ignored) {
                // Already failing; the open error below is the one worth reporting.
            }
            throw new IllegalArgumentException("cannot open output file " + outputFile, e);
        }
        in = input;
        resetBuffer();
    }

    /** Discards bytes buffered from a previously installed input stream. */
    private static void resetBuffer() {
        index = 0;
        total = 0;
    }

    /** Returns the next input byte as 0..255, or -1 once the stream is exhausted. */
    int scan() throws Exception {
        if (index >= total) {
            index = 0;
            total = in.read(buf);
            if (total <= 0) return -1;
        }
        // Mask so a 0xFF byte is not mistaken for the -1 end-of-input marker.
        return buf[index++] & 0xFF;
    }

    /** Skips bytes <= ' ' and returns the first token byte; throws EOFException if input ends first. */
    private int firstTokenByte() throws Exception {
        int c = scan();
        while (c != -1 && c <= ' ') c = scan();
        if (c == -1) throw new EOFException("unexpected end of input");
        return c;
    }

    /** Next whitespace-delimited token, one char per byte. */
    String next() throws Exception {
        StringBuilder sb = new StringBuilder();
        for (int c = firstTokenByte(); c > ' '; c = scan()) sb.append((char) c);
        return sb.toString();
    }

    /** Next token parsed as an optionally signed decimal int (no overflow check). */
    int nextInt() throws Exception {
        int c = firstTokenByte();
        boolean neg = c == '-';
        if (c == '-' || c == '+') c = scan();
        int val = 0;
        for (; c >= '0' && c <= '9'; c = scan()) val = val * 10 + (c - '0');
        return neg ? -val : val;
    }

    /** Next token parsed as an optionally signed decimal long (no overflow check). */
    long nextLong() throws Exception {
        int c = firstTokenByte();
        boolean neg = c == '-';
        if (c == '-' || c == '+') c = scan();
        long val = 0;
        for (; c >= '0' && c <= '9'; c = scan()) val = val * 10 + (c - '0');
        return neg ? -val : val;
    }

    /** Writes String.valueOf(a) (so null prints as "null"). */
    void print(Object a) throws Exception {
        bw.write(String.valueOf(a));
    }

    /** Writes a followed by a single space. */
    void printsp(Object a) throws Exception {
        print(a);
        bw.write(" ");
    }

    void println() throws Exception {
        bw.write("\n");
    }

    void println(Object a) throws Exception {
        print(a);
        println();
    }

    /** Flushes and then CLOSES the writer; no further output is possible afterwards. */
    void flush() throws Exception {
        bw.flush();
        bw.close();
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment