@SansPapyrus683
Last active December 24, 2021 01:33
Solution for "Help Yourself" (2020 USACO Gold)

usaco time
i've been doing some old problems that i failed at before, and one of them was this problem. i thought the solution was hot garbage (no offense benq), so i decided to write my own that actually makes sense

so
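we have up to N = 10^5 line segments with integer endpoints in [1, 2N], standard stuff
and we want the sum, over all 2^N subsets of segments, of each subset's complexity (mod 10^9 + 7), where the complexity of a subset is the number of connected components in the union of its segments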
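before doing anything clever, here's a hypothetical brute-force checker (`brute_force` is my name, not part of the real solution) that literally follows the definition: it tries all 2^N subsets, so it's only for tiny N. it assumes non-negative integer endpoints and doubles every coordinate so that two segments like [1, 3] and [4, 5], which don't actually touch, leave an uncovered cell between them.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// hypothetical brute-force checker: sums the complexity over all 2^N subsets
// only for tiny N; assumes non-negative integer endpoints
long long brute_force(const std::vector<std::pair<int, int>>& segs) {
    int n = static_cast<int>(segs.size());
    int max_coord = 0;
    for (const auto& s : segs) max_coord = std::max(max_coord, s.second);

    long long sum = 0;
    for (int mask = 0; mask < (1 << n); mask++) {
        // cell 2x is the point x, cell 2x+1 is the gap between x and x+1
        std::vector<bool> covered(2 * max_coord + 1, false);
        for (int i = 0; i < n; i++) {
            if (mask & (1 << i)) {
                for (int x = 2 * segs[i].first; x <= 2 * segs[i].second; x++) {
                    covered[x] = true;
                }
            }
        }
        // each maximal run of covered cells is one connected component
        for (int x = 0; x <= 2 * max_coord; x++) {
            if (covered[x] && (x == 0 || !covered[x - 1])) sum++;
        }
    }
    return sum;
}
```

for example, with the three segments [1, 6], [2, 3], [4, 5] it returns 8. it's handy for cross-checking the fast solution on small random inputs.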

first, because i hate unordered stuff, let's sort the segments by starting point and then process them in that order, keeping a running total of the answer so far (the sum of complexities over all subsets of the segments processed so far)

now, let's see... the current segment starts after every previous segment starts (that's what the sort buys us), while the previous segments' endpoints may or may not be less than the current start point

so say we're processing the last segment
adding it doubles the # of subsets, since each previous subset can either leave the new segment out or add it
but let's just talk about the part where we add the new segment
NOTICE! adding this new segment to a subset can only either

  1. add 1 to the complexity (it starts a new component), which happens iff the segment starts strictly after the latest endpoint of all the segments in the subset
  2. leave the complexity the same (i.e. it can't "merge" any previous components) otherwise

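(quick proof: we sorted by start, so every component of the previous union starts before the new segment does. a component that starts before the new start and reaches into the new segment has to contain the new start point, and two disjoint components can't both contain it. so the new segment can extend at most one component, but it can never bridge two.)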
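the two cases are exactly what a sweep does when it counts the components of a single subset. here's a sketch (`complexity` is a hypothetical helper, not from the solution below): visit segments by start, remember the furthest endpoint seen, and count a new component whenever a segment starts past it.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <utility>
#include <vector>

// hypothetical helper: number of connected components in the union of segs
int complexity(std::vector<std::pair<int, int>> segs) {
    std::sort(segs.begin(), segs.end());  // by start point, like the solution
    int components = 0;
    int furthest_end = INT_MIN;  // nothing seen yet
    for (const auto& s : segs) {
        // case 1: starts strictly after everything so far -> new component
        if (s.first > furthest_end) components++;
        // otherwise (case 2) it just joins the component that's still open;
        // either way that component now reaches at least s.second
        furthest_end = std::max(furthest_end, s.second);
    }
    return components;
}
```

note that there's no "merge" step anywhere in the loop, which is the claim above in code form.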

so when we process a new segment, we first double the current total, then add
2^(number of segments that end strictly before the current one starts)
why? uuuuhh

look: the subsets that leave the new segment out contribute total, and the subsets that include it contribute total again, plus 1 for every subset where case 1 happens. case 1 happens only when every segment in the subset ends before the current start, so the # of case 1's is the # of subsets made only of those segments, which is that 2^ thing
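written out as a recurrence (the names are just notation i'm picking here: after sorting by start, [l_i, r_i] is the i-th segment, T_i is the running total after i segments, and k_i counts the segments ending strictly before l_i):

```latex
T_0 = 0, \qquad
T_i = \underbrace{T_{i-1}}_{\text{subsets without segment } i}
    + \underbrace{T_{i-1} + 2^{k_i}}_{\text{subsets with segment } i},
\qquad
k_i = \left|\{\, j : r_j < l_i \,\}\right|
```

the answer is T_N mod 10^9 + 7. any j with r_j < l_i also has l_j < r_j < l_i, so it's automatically one of the earlier segments, which is why k_i can be counted over all segments at once.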

actually lemme just give an example here:

[[[--]][--[----]]---]
segments: [[0, 20], [1, 6], [2, 5], [7, 16], [10, 15]]

opening brackets represent the start of a segment & closing brackets represent the end (each character is one unit, starting at 0)
now say we're processing the last segment in that list
the previous total is... actually that doesn't matter, let's just call it total
so first we double total, and then we count the # of case 1's
[10, 15] starts strictly after [1, 6] and [2, 5] end, so we add 2^2 = 4 to the already doubled total
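to double-check the whole example, here's a hypothetical `trace_totals` helper that replays the double-then-add step and records the running total after each segment. it counts k with a plain O(N^2) loop and skips the modulus, so it's only meant for hand-sized inputs.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// hypothetical helper: running total after each segment (no modulus, tiny N only)
std::vector<long long> trace_totals(std::vector<std::pair<int, int>> segs) {
    std::sort(segs.begin(), segs.end());  // process by start point
    std::vector<long long> totals;
    long long total = 0;
    for (const auto& cur : segs) {
        int k = 0;  // segments that end strictly before cur starts
        for (const auto& other : segs) {
            if (other.second < cur.first) k++;
        }
        total = total * 2 + (1LL << k);  // double, then add the case 1 subsets
        totals.push_back(total);
    }
    return totals;
}
```

on the five segments above the totals come out as 1, 3, 7, 18, 40; the last two steps are the k = 2 ones. that 40 also matches counting by hand: the 16 subsets containing [0, 20] each have complexity 1, and the other 16 contribute 12 + 12 from the {[1, 6], [2, 5]} and {[7, 16], [10, 15]} groups.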

phew!
now it remains to find the # of segments that end strictly before the current one starts
you can do this w/ prefix sums (every endpoint is known up front and never changes), but my dumb self thought that i had to do this w/ a goddamn BIT
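here's roughly what the prefix-sum version could look like, as a sketch (`solve_with_prefix_sums` and `ends_upto` are my names). it assumes integer endpoints in [1, 2N] like the problem gives you: count right endpoints per coordinate, prefix-sum that array once, and then each "how many end before l" query is a single lookup.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// hypothetical prefix-sum alternative to the BIT; assumes endpoints in [1, 2N]
long long solve_with_prefix_sums(std::vector<std::pair<int, int>> segs) {
    const long long MOD = 1000000007LL;
    int n = static_cast<int>(segs.size());

    // ends_upto[x] = # of segments whose right endpoint is <= x
    std::vector<int> ends_upto(2 * n + 1, 0);
    for (const auto& s : segs) ends_upto[s.second]++;
    for (int x = 1; x <= 2 * n; x++) ends_upto[x] += ends_upto[x - 1];

    std::vector<long long> two_pows(n + 1, 1);  // two_pows[i] = 2^i mod MOD
    for (int i = 1; i <= n; i++) two_pows[i] = two_pows[i - 1] * 2 % MOD;

    std::sort(segs.begin(), segs.end());  // process by start point
    long long total = 0;
    for (const auto& s : segs) {
        int k = ends_upto[s.first - 1];  // ends strictly before this start
        total = (total * 2 + two_pows[k]) % MOD;
    }
    return total;
}
```

no updates ever happen after the counts are built, which is why a static array is enough and the BIT is overkill.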

but anyways, if you want the solution, here it is

```cpp
#include <iostream>
#include <fstream>
#include <vector>
#include <algorithm>
#include <utility>

using std::cout;
using std::endl;
using std::vector;

constexpr int MOD = 1e9 + 7;

// i mean a prefix sum would also work but i'm too lazy
class BITree {
    private:
        vector<int> bit;
        int size;
    public:
        BITree(int size) : bit(size + 1), size(size) { }

        void increment(int updateAt, int val) {
            updateAt++;  // have the driver code not worry about 1-indexing
            for (; updateAt <= size; updateAt += updateAt & -updateAt) {
                bit[updateAt] += val;
            }
        }

        int query(int ind) {  // sum of elements in [0, ind]
            ind++;
            int sum = 0;
            for (; ind > 0; ind -= ind & -ind) {
                sum += bit[ind];
            }
            return sum;
        }
};

// 2020 feb gold
int main() {
    std::ifstream read("help.in");
    int range_num;
    read >> range_num;

    // two_pows[i] = 2^i mod MOD; 2 * (MOD - 1) still fits in an int
    vector<int> two_pows(range_num + 1);
    two_pows[0] = 1;
    for (int i = 1; i <= range_num; i++) {
        two_pows[i] = (two_pows[i - 1] * 2) % MOD;
    }

    vector<std::pair<int, int>> ranges(range_num);
    // going to assume valid input (endpoints in [1, 2N]), screw you validate your own input
    BITree prev_ranges(2 * range_num);  // counts right endpoints, 0-indexed
    for (std::pair<int, int>& r : ranges) {
        read >> r.first >> r.second;
        r.first--;
        r.second--;
        prev_ranges.increment(r.second, 1);
    }
    std::sort(ranges.begin(), ranges.end());  // by start point

    long long total = 0;
    for (int r = 0; r < range_num; r++) {
        int start = ranges[r].first;
        // # of segments that end strictly before this one starts
        int prev = start != 0 ? prev_ranges.query(start - 1) : 0;
        total = (total * 2 + two_pows[prev]) % MOD;
    }
    cout << total << endl;
    std::ofstream("help.out") << total << endl;
}
```