hi it's me again
i should probably make a cf blog or something, but the people on there are more unpredictable than my history teacher
so i came across this problem, and god, the editorial was awful
so we have a tree where every node has a value, normal crap
and they want us to find how many subtrees (any connected group of nodes) have values that add up to t (t <= 100)
this is actually a pretty standard tree dp problem: we just keep a 2d array sum_ways, which i defined like this:
sum_ways[n][s] = number of subtrees whose topmost node is n (the node in the subtree closest to the root) and whose values sum to s
seems pretty simple to understand (the answer we output is just the sum of sum_ways[n][t] across all n, and since every subtree has exactly one topmost node, nothing gets counted twice)
to handle the base case for a leaf (and every node), we just increment sum_ways[node][node_value] by 1
now to actually merge a child's array into a parent's array, which is the actual dp part of tree dp
let's say our parent and child arrays were something like this (who cares if these are actually possible):
parent = [0, 0, 4, 1, 5, 0, 3] (let's just say t = 6)
child = [0, 1, 3, 3, 0, 1, 1]
how do we merge these?
let's just start with the target value because why not
there are 7 ways for the parent & child to come together and make a new subtree that sums to 6:
- parent contributes 0, child contributes 6
- parent contributes 1, child contributes 5
- parent contributes 2, child contributes 4
- ...
- parent contributes 6, child contributes 0
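and for each of these (p_contrib, c_contrib) pairs, we increment parent[t] by parent[p_contrib] * child[c_contrib], since every parent-side subtree with sum p_contrib can be glued (through the parent-child edge) to every child-side subtree with sum c_contrib
two things to be careful about: the parent values being multiplied have to be the ones from before this child was merged in, and the old parent[t] stays in the total, since not attaching this child at all is still an option
then do the same for every other sum from 0 to t (the smaller sums get used when the next child is merged in), and then we win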
"but wait, kevin!" i can hear you typing in the comment section already
"this dp relation only accounts for the current child adding on to what the parent already has with the others!"
"it doesn't account for just the parent node and the child's subtrees alone!"
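but note that our base case of incrementing sum_ways[node][node_value] by 1 accounted for that! the parent on its own is already sitting in parent[node_value] before any child gets merged, so "just the parent plus this child's subtrees" comes straight out of the merge. checkmate, liberal!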
but enough of me, you can just take the code and go now ig
```cpp
#include <iostream>
#include <cassert>
#include <vector>

using std::cout;
using std::endl;
using std::vector;

constexpr int MOD = 1e9 + 7;

// some of the values here have different names than in the editorial, i hope that doesn't ruin things for you
class Tree {
private:
    static const int ROOT = 0;
    const vector<vector<int>>& neighbors;
    const vector<int>& node_vals;
    int target; // declared before sum_ways so members initialize in the order the constructor lists them
    vector<vector<long long>> sum_ways;

    void process_sums(int at, int prev) {
        int val = node_vals[at]; // just a shorthand
        if (val <= target) {
            sum_ways[at][val]++; // base case: the node on its own
        }
        for (int n : neighbors[at]) {
            if (n == prev) {
                continue;
            }
            process_sums(n, at);
            if (val > target) {
                continue; // every subtree containing this node already sums past target
            }
            // going from high t to low t means sum_ways[at][a] for a <= t
            // still holds the value from before this child was merged in
            for (int t = target; t >= 0; t--) {
                long long new_val = sum_ways[at][t]; // the "don't attach this child" option
                for (int a = 0; a <= t; a++) {
                    int b = t - a;
                    new_val = (new_val + sum_ways[at][a] * sum_ways[n][b]) % MOD;
                }
                sum_ways[at][t] = new_val;
            }
        }
    }

public:
    Tree(const vector<vector<int>>& neighbors,
         const vector<int>& node_vals,
         int target)
        : neighbors(neighbors),
          node_vals(node_vals),
          target(target),
          sum_ways(neighbors.size(), vector<long long>(target + 1)) {
        assert(node_vals.size() == neighbors.size());
        process_sums(ROOT, ROOT);
    }

    long long sum_num(int n) const {
        return sum_ways[n][target];
    }
};

/**
 * https://bit.ly/3itQFzv (actual url way too long i promise it's nothing sus)
 * (input omitted due to length)
 */
int main() {
    int test_num;
    std::cin >> test_num;
    for (int t = 0; t < test_num; t++) {
        int node_num;
        int target;
        std::cin >> node_num >> target;
        vector<int> node_vals(node_num);
        for (int& v : node_vals) {
            std::cin >> v;
        }
        vector<vector<int>> neighbors(node_num);
        for (int e = 0; e < node_num - 1; e++) {
            int a;
            int b;
            std::cin >> a >> b;
            neighbors[--a].push_back(--b); // edges come in 1-indexed
            neighbors[b].push_back(a);
        }
        Tree tree(neighbors, node_vals, target);
        long long total_sums = 0;
        for (int n = 0; n < node_num; n++) {
            total_sums = (total_sums + tree.sum_num(n)) % MOD;
        }
        cout << total_sums << endl;
    }
}
```
wait
i just realized
values can only be 0 or 1
lmao
don't think that changes anything tho