Skip to content

Instantly share code, notes, and snippets.

@AhmetCanSolak
Last active May 26, 2018 08:45
Show Gist options
  • Save AhmetCanSolak/59c8df5e7dd8bb22925ccb354280b596 to your computer and use it in GitHub Desktop.
Save AhmetCanSolak/59c8df5e7dd8bb22925ccb354280b596 to your computer and use it in GitHub Desktop.
My solution to the "Kundu and Tree" problem on HackerRank (Data Structures › Disjoint Set).
// Problem statement: https://www.hackerrank.com/challenges/kundu-and-tree/problem
#include <bits/stdc++.h>
typedef long long int LT; // suitable long type
using namespace std;
/// Disjoint-set (union-find) forest over the elements 0 .. n-1.
///
/// find() uses path compression and unite() uses union by size. The value
/// reported by getRank(x) is:
///   * the number of elements in x's set, when x is the set's representative;
///   * 0, when x has been merged under another representative;
///   * 1, when x has never been merged (a singleton).
template<typename T>
class disjointset
{
public:
    /// Creates n singleton sets {0}, {1}, ..., {n-1}.
    disjointset(int n) : size(n) {
        makeSet();
    }

    /// Resets every element to its own singleton set of size 1.
    void makeSet() {
        // resize/assign (not reserve): operator[] is only valid for indices < size().
        arr.resize(static_cast<std::size_t>(size));
        rank.assign(static_cast<std::size_t>(size), static_cast<T>(1));
        for (int i = 0; i < size; ++i)
            arr[i] = static_cast<T>(i);
        setcount = size;
    }

    /// Returns the representative (root) of the set containing a.
    /// Iterative, so long parent chains cannot overflow the call stack.
    T find(T a) {
        T root = a;
        while (arr[root] != root)
            root = arr[root];
        // Path compression: re-point every node on the walked path at the root.
        while (arr[a] != root) {
            T next = arr[a];
            arr[a] = root;
            a = next;
        }
        return root;
    }

    /// Merges the sets containing a and b; does nothing if already merged.
    void unite(T a, T b) {
        T arep = find(a);
        T brep = find(b);
        if (arep == brep)
            return;
        // Union by size: compare the ROOTS' sizes and hang the smaller tree
        // under the larger one so tree height stays logarithmic.
        if (rank[arep] < rank[brep])
            std::swap(arep, brep);
        arr[brep] = arep;
        rank[arep] += rank[brep];
        rank[brep] = 0;
        --setcount;
    }

    /// Size of a's set if a is a representative, otherwise 0 (see class doc).
    T getRank(T a) const {
        return rank[a];
    }

    /// Number of disjoint sets currently in the structure.
    int getSetCount() const {
        return setcount;
    }

private:
    int size = 0;          // number of elements
    int setcount = 0;      // number of disjoint sets
    std::vector<T> arr;    // parent links; arr[i] == i marks a representative
    std::vector<T> rank;   // set size at representatives, 0 at merged elements
};
int main() {
/* Enter your code here. Read input from STDIN. Print output to STDOUT */
int n;
cin >> n;
disjointset<int> ds(n);
int numoflines = n-1, a, b;
char ch;
for(;numoflines;--numoflines) {
cin >> a >> b >> ch;
a--;
b--;
if(ch=='b') {
ds.unite(a,b);
}
}
vector<long long int> vec;
for(int i=0;i<n;i++) {
if(ds.getRank(i)>1)
vec.push_back((LT)ds.getRank(i));
}
// For math behind check the link: https://math.stackexchange.com/questions/838792/counting-triplets-with-red-edges-in-each-pair?newreg=60eee35f0b3844de852bda39f6dfec88
LT mod = 1000000007;
LT result = 0;
LT nCom3 = (((LT)n*(n - 1)*(n - 2)) / 6);
LT binom2 = 0, binom3 = 0;
for (auto &elem:vec) {
binom3 = 0;
binom2 = 0;
if (elem > 2)
binom3 += ((LT)elem * (elem - 1)*(elem - 2)) / 6;
binom2 = ((LT)elem * (elem - 1)) / 2;
result += (binom3 + binom2 * ((LT)n - elem));
}
cout << (nCom3 - result)%mod << endl;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment