
@SansPapyrus683
Last active November 26, 2021 16:31
sol for counting on tree on hackerearth

hi it's me again

i should probably make a cf blog or something but the people on there are more unpredictable than my history teacher

so i came upon this problem and god the editorial was awful

so we have a tree where every node has a value, normal crap
and they want us to find how many subtrees (connected groups of nodes, not just "a node and everything below it") have values summing to exactly t (t <= 100)

this is actually a pretty standard tree dp problem: we just keep a 2d array sum_ways, and i defined it like this:
sum_ways[n][s] = number of subtrees whose highest node (the one closest to the root) is n and whose values add up to s
seems pretty simple to understand (the answer we output is just the sum of sum_ways[n][t] across all n, mod 1e9 + 7; every subtree has exactly one highest node, so nothing gets counted twice)

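to handle the base case for a leaf (and every node), we just increment sum_ways[node][node_value] by 1, since the node by itself is a subtree (the code skips this when node_value > t: values are nonnegative, so a sum that's already past t can never come back down)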

now to actually merge a child's array into a parent's array, the actual dp part of tree dp
let's say our parent and child arrays were something like this (who cares if these are actually possible):

parent = [0, 0, 4, 1, 5, 0, 3]  (let's just say t = 6)
child = [0, 1, 3, 3, 0, 1, 1]

how do we merge these?
let's just start with the target value because why not
there are 7 ways to split a sum of 6 between the parent's side (the parent plus whatever it's already merged) and the child's side of a new subtree:

  • parent contributes 0, child contributes 6
  • parent contributes 1, child contributes 5
  • parent contributes 2, child contributes 4
  • ...
  • parent contributes 6, child contributes 0

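and for each of these (p_contrib, c_contrib) pairs, we increment parent[t] by parent[p_contrib] * child[c_contrib], since any subtree counted in parent[p_contrib] can be glued (through the parent-child edge) to any subtree counted in child[c_contrib]
the old parent[t] stays in the total too, since those are the subtrees that don't use this child at all
then do the same for every smaller sum; going from t down to 0 means every parent[p_contrib] we read hasn't been overwritten yet, so the whole update can happen in place
and then we win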
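plugging the example arrays in for t = 6 (the old parent[6] = 3 stays in, since those subtrees skip this child):

$$
\begin{aligned}
\text{parent}'[6] &= \text{parent}[6] + \sum_{a=0}^{6} \text{parent}[a] \cdot \text{child}[6-a] \\
&= 3 + (0 \cdot 1 + 0 \cdot 1 + 4 \cdot 0 + 1 \cdot 3 + 5 \cdot 3 + 0 \cdot 1 + 3 \cdot 0) \\
&= 3 + 18 = 21
\end{aligned}
$$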

"but wait, kevin!" i can hear you typing in the comment section already
"this dp relation only accounts for the current child adding on to what the parent already has with the others!"
"it doesn't account for just the parent node and the child's subtrees alone!"
but that's exactly what the base case is for: the parent node by itself is always one of the subtrees counted in parent[node_value], so the term parent[node_value] * child[t - node_value] already includes the parent node plus the child's subtrees and nothing else. checkmate, liberal!

but enough of me, you can just take the code and go now ig

#include <iostream>
#include <cassert>
#include <vector>

using std::cout;
using std::endl;
using std::vector;

constexpr int MOD = 1'000'000'007;

// some of the values here have different names than in the editorial, i hope that doesn't ruin things for you
class Tree {
    private:
        static const int ROOT = 0;

        const vector<vector<int>>& neighbors;
        const vector<int>& node_vals;
        // declared before sum_ways so members initialize in the same order as the constructor's list
        int target;
        vector<vector<long long>> sum_ways;

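        // fills sum_ways[at] from at's children (recursion depth = height of the tree)
        void process_sums(int at, int prev) {
            int val = node_vals[at];  // just a shorthand
            // base case: the subtree that's just this node
            if (val <= target) {
                sum_ways[at][val]++;
            }
            for (int n : neighbors[at]) {
                if (n == prev) {
                    continue;
                }
                // always recurse: the child is the highest node of its own subtrees
                process_sums(n, at);
                // values are nonnegative, so if this node alone overshoots, so does every subtree containing it
                if (val > target) {
                    continue;
                }
                // go from high sums to low so sum_ways[at][a] (a <= t) still holds the old value
                for (int t = target; t >= 0; t--) {
                    long long new_val = sum_ways[at][t];  // subtrees that skip this child
                    for (int a = 0; a <= t; a++) {
                        int b = t - a;
                        new_val = (
                            new_val + sum_ways[at][a] * sum_ways[n][b]
                        ) % MOD;
                    }
                    sum_ways[at][t] = new_val;
                }
            }
        }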
    public:
        Tree(const vector<vector<int>>& neighbors,
             const vector<int>& node_vals,
             int target)
                : neighbors(neighbors),
                  node_vals(node_vals),
                  target(target),
                  sum_ways(neighbors.size(), vector<long long>(target + 1)) {
            assert(node_vals.size() == neighbors.size());
            process_sums(ROOT, ROOT);
        }

        long long sum_num(int n) {
            return sum_ways[n][target];
        }
};

/**
 * https://bit.ly/3itQFzv (actual url way too long i promise it's nothing sus)
 * (input omitted due to length)
 */
int main() {
    int test_num;
    std::cin >> test_num;
    for (int t = 0; t < test_num; t++) {
        int node_num;
        int target;
        std::cin >> node_num >> target;
        vector<int> node_vals(node_num);
        for (int& v : node_vals) {
            std::cin >> v;
        }
        
        vector<vector<int>> neighbors(node_num);
        for (int e = 0; e < node_num - 1; e++) {
            int a;
            int b;
            std::cin >> a >> b;
            neighbors[--a].push_back(--b);
            neighbors[b].push_back(a);
        }

        Tree tree(neighbors, node_vals, target);
        long long total_sums = 0;
        for (int n = 0; n < node_num; n++) {
            total_sums = (total_sums + tree.sum_num(n)) % MOD;
        }
        cout << total_sums << endl;
    }
}
@SansPapyrus683
Author

wait
i just realized
values can only be 0 or 1
lmao
don't think that changes anything tho, the dp works for any nonnegative values
