// Source: GitHub Gist 7baf81e9e612b9fc040cb784eda9abde by @Gabrielgtt
// (last active November 22, 2019).
// Prints the probability that two distinct vertices of a tree, chosen
// uniformly at random, are at a prime distance from each other.
#include <bits/stdc++.h>
// Exclusive upper bound for vertex ids, depths and the sieved primes.
const int MAXN = 51010;
// NOTE(review): not referenced anywhere in this file.
const int MAXLOG = 23;
typedef long long ll;
// Container size as a signed int. NOTE(review): not referenced in this file.
template <class C> int sz(const C& a) { return (int) a.size(); }
using namespace std;

vector <int> grafo[MAXN];  // adjacency lists of the tree, indexed by vertex id
vector <int> primos;       // every prime below MAXN, ascending (filled by prePrimos)
bool crivo[MAXN];          // sieve marks: true for composite numbers after prePrimos
ll N;                      // number of vertices, read in main
void prePrimos(){
for (int i=2; i<MAXN; i++){
if (!crivo[i]){
primos.emplace_back(i);
for (int j=i+i; j<MAXN; j+=i) crivo[j] = true;
}
}
}
// dists[d]: how many already-recorded vertices lie at depth d from the
// current centroid (dists[0] is the centroid itself).
int dists[MAXN];
// sub[v]: size of v's subtree, as last computed by dfs(u, p).
int sub[MAXN];
// Saturated vertices: already used as centroids, must not be visited again.
bool sat[MAXN];
// NOTE(review): not referenced anywhere in this file.
int root;
// Fills sub[] for the part of the tree reachable from u without stepping
// back into p or into any saturated vertex, rooted at u.
// Returns sub[u], the number of vertices reached.
int dfs(int u, int p) {
    int nodes = 1;  // u itself
    for (int v : grafo[u]) {
        if (v == p || sat[v]) continue;
        nodes += dfs(v, u);
    }
    sub[u] = nodes;
    return nodes;
}
// Starting at u (entered from p), repeatedly steps into the first
// unsaturated neighbour, other than the one we came from, whose sub[] exceeds
// n/2. Returns the first vertex with no such neighbour.
// With sub[] freshly filled by dfs(u, p) and n the component size, that
// vertex is the component's centroid.
int dfs(int u, int p, int n) {
    int cur = u;
    int from = p;
    for (;;) {
        int heavy = -1;  // -1 is never a vertex id
        for (int v : grafo[cur]) {
            if (v != from && !sat[v] && sub[v] > n / 2) {
                heavy = v;
                break;
            }
        }
        if (heavy == -1) return cur;
        from = cur;
        cur = heavy;
    }
}
void removeDists(int node, int pai, ll prof = 1){
if (sat[node]) return;
dists[prof]--;
for (int filho : grafo[node])
if (filho != pai)
removeDists(filho, node, prof+1);
}
void addDists(int node, int pai, ll prof = 1) {
if (sat[node]) return;
dists[prof]++;
for (int filho : grafo[node])
if (filho != pai)
addDists(filho, node, prof+1);
}
// Counts pairs (x, y) at prime distance, where:
//   - x lies in the subtree entered at `node` (coming from `pai`, at depth
//     `prof` from the current centroid), and
//   - y is one of the vertices already recorded in dists.
// Their distance is depth(x) + depth(y). Saturated vertices (earlier
// centroids) bound the walk: reaching one contributes 0.
//
// Precondition, as established by solve/addDists/removeDists: the non-zero
// entries of dists are exactly the contiguous prefix [0, D].
//   - dists[0] is the centroid.
//   - addDists records each vertex together with every vertex on its path
//     back to the centroid.
ll computa(int node, int pai, ll prof = 1) {
    if (sat[node]) return 0;
    ll res = 0;
    // primos is ascending, so tam = primo - prof is ascending as well. By the
    // prefix precondition, the first tam >= 0 with dists[tam] == 0 means
    // tam > D, and so is every later tam. Stopping there avoids scanning
    // all primes below MAXN for every vertex.
    for (int primo : primos) {
        const ll tam = primo - prof;
        if (tam < 0) continue;
        if (dists[tam] == 0) break;
        res += static_cast<ll>(dists[tam]);
    }
    for (int filho : grafo[node])
        if (filho != pai) res += computa(filho, node, prof + 1);
    return res;
}
// Centroid decomposition over the component containing u (entered from p).
// Returns how many unordered vertex pairs in that component are at a prime
// distance. Each pair is counted at the first centroid whose removal
// separates it (or that is one of its endpoints).
ll solve(int u, int p) {
    const int compSize = dfs(u, p);
    const int centroid = dfs(u, p, compSize);
    sat[centroid] = true;
    dists[0] = 1;  // the centroid itself, at depth 0

    ll pairs = 0;
    const vector<int>& adj = grafo[centroid];
    // Match each branch against the centroid plus the branches already
    // recorded, then record it, so every cross-branch pair is counted once.
    for (int v : adj) {
        pairs += computa(v, centroid);
        addDists(v, centroid);
    }
    // Reset dists (back to just dists[0]) before recursing.
    for (int v : adj) removeDists(v, centroid);
    for (int v : adj) {
        if (sat[v]) continue;
        pairs += solve(v, centroid);
    }
    return pairs;
}
int main() {
#ifdef LOCAL
freopen("input", "r", stdin);
#endif
prePrimos();
scanf("%lld", &N);
int de, para;
for (int i=0; i<N-1; i++){
scanf("%d %d", &de, &para);
grafo[de].emplace_back(para);
grafo[para].emplace_back(de);
}
ll distPrimas = solve(1, -1);
ll total = (N * (N-1LL)) / 2LL;
printf("%.10f\n", ((double) distPrimas) / ((double) total) );
return 0;
}
// (end of gist)