Skip to content

Instantly share code, notes, and snippets.

@SansPapyrus683
Created May 19, 2022 04:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save SansPapyrus683/49e069ef13eae549f2e7ec1b4f1d5b22 to your computer and use it in GitHub Desktop.
Save SansPapyrus683/49e069ef13eae549f2e7ec1b4f1d5b22 to your computer and use it in GitHub Desktop.
Tree with LCA capabilities (C++ and Java)
#include <iostream>
#include <vector>
#include <cmath>
using std::vector;
/**
 * Tree class with LCA capabilities and some other handy functions.
 *
 * Nodes are labeled 0..N-1 and node 0 is the root. Ancestor queries use
 * binary lifting: pow2ends[n][p] stores the 2^p-th ancestor of node n, so a
 * k-th ancestor lookup takes one table jump per set bit of k.
 */
class Tree {
private:
	/** par[n] is the parent of node n; the root (node 0) has parent -1. */
	vector<int> par;
	/** pow2ends[n][p] is the 2^p-th ancestor of node n, or -1 if it doesn't exist. */
	vector<vector<int>> pow2ends;
	/** depth[n] is the number of edges on the path from the root to node n. */
	vector<int> depth;
	/** Largest jump exponent stored in pow2ends: the smallest p with 2^p >= N. */
	int log2dist;

	/** @return the smallest p with 2^p >= x (0 when x <= 1), using integer math only. */
	static int ceil_log2(int x) {
		int p = 0;
		while ((1 << p) < x) {
			p++;
		}
		return p;
	}

public:
	/**
	 * Constructs a tree based on a given parent array.
	 * @param parents
	 * The parent array. The root (assumed to be node 0) is not included,
	 * so parents[i] is the parent of node i + 1. Every entry is used as an
	 * index into per-node tables, so it must lie in [0, parents.size()].
	 */
	Tree(const vector<int>& parents)
		: par(parents.size() + 1),
		  log2dist(ceil_log2(static_cast<int>(parents.size()) + 1)) {
		const int node_num = static_cast<int>(par.size());
		par[0] = -1;
		for (int i = 0; i + 1 < node_num; i++) {
			par[i + 1] = parents[i];
		}

		// Level 0 of the lifting table is the parent itself; level p is
		// obtained by two consecutive level p-1 jumps.
		pow2ends.assign(node_num, vector<int>(log2dist + 1, -1));
		for (int n = 0; n < node_num; n++) {
			pow2ends[n][0] = par[n];
		}
		for (int p = 1; p <= log2dist; p++) {
			for (int n = 0; n < node_num; n++) {
				const int halfway = pow2ends[n][p - 1];
				pow2ends[n][p] = halfway == -1 ? -1 : pow2ends[halfway][p - 1];
			}
		}

		// Depths are assigned by walking down from the root with an explicit
		// stack, which avoids recursion-depth limits on long paths.
		vector<vector<int>> children(node_num);
		for (int n = 1; n < node_num; n++) {
			children[par[n]].push_back(n);
		}
		depth.assign(node_num, 0);
		vector<int> frontier{0};
		while (!frontier.empty()) {
			const int curr = frontier.back();
			frontier.pop_back();
			for (int child : children[curr]) {
				depth[child] = depth[curr] + 1;
				frontier.push_back(child);
			}
		}
	}

	/**
	 * @return the kth parent of node n (or -1 if it doesn't exist).
	 * k == 0 returns n itself; a negative k also yields -1.
	 */
	int kth_parent(int n, int k) const {
		if (k < 0 || k > static_cast<int>(par.size())) {
			return -1;
		}
		int at = n;
		// One jump per set bit of k; stop as soon as we've walked past the root.
		for (int pow = 0; pow <= log2dist && at != -1; pow++) {
			if ((k >> pow) & 1) {
				at = pow2ends[at][pow];
			}
		}
		return at;
	}

	/** @return the lowest common ancestor of n1 and n2. */
	int LCA(int n1, int n2) const {
		if (depth[n1] < depth[n2]) {
			return LCA(n2, n1);
		}
		// Lift the deeper node so both sit at the same depth.
		n1 = kth_parent(n1, depth[n1] - depth[n2]);
		if (n1 == n2) {
			return n1;
		}
		// Jump both nodes upward as far as possible while they stay distinct;
		// afterwards they are distinct children of the LCA.
		for (int i = log2dist; i >= 0; i--) {
			if (pow2ends[n1][i] != pow2ends[n2][i]) {
				n1 = pow2ends[n1][i];
				n2 = pow2ends[n2][i];
			}
		}
		// n1 is a proper descendant of the LCA here, so it is never the root.
		return pow2ends[n1][0];
	}

	/** @return the distance (number of edges) between n1 and n2. */
	int dist(int n1, int n2) const {
		return depth[n1] + depth[n2] - 2 * depth[LCA(n1, n2)];
	}
};
import java.util.*;
/**
 * Tree class with LCA capabilities and some other handy functions.
 *
 * Nodes are labeled 0..N-1 and node 0 is the root. Ancestor queries use
 * binary lifting: pow2Ends[n][p] stores the 2^p-th ancestor of node n, so a
 * k-th ancestor lookup takes one table jump per set bit of k.
 */
public class LCATree {
    /** par[n] is the parent of node n; the root (node 0) has parent -1. */
    private final int[] par;
    /** pow2Ends[n][p] is the 2^p-th ancestor of node n, or -1 if it doesn't exist. */
    private final int[][] pow2Ends;
    /** depth[n] is the number of edges on the path from the root to node n. */
    private final int[] depth;
    /** Largest jump exponent stored in pow2Ends: the smallest p with 2^p >= N. */
    private final int log2Dist;

    /**
     * Constructs a tree based on a given parent array.
     * @param parents
     * The parent array. The root (assumed to be node 0) is not included,
     * so parents[i] is the parent of node i + 1. Every entry is used as an
     * index into per-node tables, so it must lie in [0, parents.length].
     */
    public LCATree(int[] parents) {
        par = new int[parents.length + 1];
        par[0] = -1;
        System.arraycopy(parents, 0, par, 1, parents.length);

        // ceil(log2(par.length)) computed exactly with integer bit operations
        // instead of dividing two floating-point logarithms.
        log2Dist = 32 - Integer.numberOfLeadingZeros(par.length - 1);

        // Level 0 of the lifting table is the parent itself; level p is
        // obtained by two consecutive level p-1 jumps.
        pow2Ends = new int[par.length][log2Dist + 1];
        for (int n = 0; n < par.length; n++) {
            pow2Ends[n][0] = par[n];
        }
        for (int p = 1; p <= log2Dist; p++) {
            for (int n = 0; n < par.length; n++) {
                int halfway = pow2Ends[n][p - 1];
                pow2Ends[n][p] = halfway == -1 ? -1 : pow2Ends[halfway][p - 1];
            }
        }

        List<List<Integer>> children = new ArrayList<>(par.length);
        for (int n = 0; n < par.length; n++) {
            children.add(new ArrayList<>());
        }
        for (int n = 1; n < par.length; n++) {
            children.get(par[n]).add(n);
        }

        // Breadth-first walk from the root assigns every node its depth.
        depth = new int[par.length];
        Deque<Integer> frontier = new ArrayDeque<>();
        frontier.add(0);
        while (!frontier.isEmpty()) {
            int curr = frontier.poll();
            for (int c : children.get(curr)) {
                depth[c] = depth[curr] + 1;
                frontier.add(c);
            }
        }
    }

    /**
     * @return the kth parent of node n (or -1 if it doesn't exist).
     * k == 0 returns n itself; a negative k also yields -1.
     */
    public int kthParent(int n, int k) {
        if (k < 0 || k > par.length) {
            return -1;
        }
        int at = n;
        // One jump per set bit of k; stop as soon as we've walked past the root.
        for (int pow = 0; pow <= log2Dist && at != -1; pow++) {
            if ((k & (1 << pow)) != 0) {
                at = pow2Ends[at][pow];
            }
        }
        return at;
    }

    /** @return the lowest common ancestor of n1 and n2. */
    public int LCA(int n1, int n2) {
        if (depth[n1] < depth[n2]) {
            return LCA(n2, n1);
        }
        // Lift the deeper node so both sit at the same depth.
        n1 = kthParent(n1, depth[n1] - depth[n2]);
        if (n1 == n2) {
            return n1;
        }
        // Jump both nodes upward as far as possible while they stay distinct;
        // afterwards they are distinct children of the LCA.
        for (int i = log2Dist; i >= 0; i--) {
            if (pow2Ends[n1][i] != pow2Ends[n2][i]) {
                n1 = pow2Ends[n1][i];
                n2 = pow2Ends[n2][i];
            }
        }
        // n1 is a proper descendant of the LCA here, so it is never the root.
        return pow2Ends[n1][0];
    }

    /** @return the distance (number of edges) between n1 and n2. */
    public int dist(int n1, int n2) {
        return depth[n1] + depth[n2] - 2 * depth[LCA(n1, n2)];
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment