Last active
August 14, 2023 15:31
-
-
Save fgshun/615ff736c2d11341316a4ca0ca7bb0ce to your computer and use it in GitHub Desktop.
LCA 最近共通祖先
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
from io import BytesIO | |
class LCA: | |
"""LCA 最近共通祖先 | |
参考 | |
アルゴリズムロジック - ダブリングによる木の最近共通祖先(LCA:Lowest Common Ancestor)を求めるアルゴリズム | |
https://algo-logic.info/lca/ | |
""" | |
def __init__(self, tree, root=0): | |
self.tree = tree | |
self.root = root | |
V = len(tree) | |
K = 1 | |
while (1 << K) < V: | |
K += 1 | |
# ありうる最大の深さ V - 1 を表現するに足るビット数 | |
self.K = K | |
# parent[k][u]: u の 2**k 先の祖先 | |
self.parent = parent = [[-1] * V for _ in range(K)] | |
# root からの距離 | |
self.dist = [-1] * V | |
# dist と parent[0] を求める | |
self._dfs(root, -1, 0) | |
# 2**k 先の祖先を求める | |
for k in range(K - 1): | |
for v in range(V): | |
if parent[k][v] < 0: | |
# parent[k + 1][v] = -1 | |
pass | |
else: | |
parent[k + 1][v] = parent[k][parent[k][v]] | |
def _dfs(self, cur, parent, distance): | |
""" root からの距離 dist と 各頂点の直接の祖先 parent[0] を求める""" | |
self.parent[0][cur] = parent | |
self.dist[cur] = distance | |
for e in self.tree[cur]: | |
if e != parent: # 逆流防止 | |
self._dfs(e, cur, distance + 1) | |
def get_lca(self, u, v): | |
"""u と v の LCA 最近共通祖先を求める""" | |
dist = self.dist | |
parent = self.parent | |
if dist[u] < dist[v]: | |
u, v = v, u # u の方が深くなるよう取り換えておく | |
# LCA までの距離を同じにする | |
dist_v = dist[v] | |
for k, parent_k in enumerate(parent): | |
if dist[u] - dist_v >> k & 1: | |
u = parent_k[u] # parent[k][u] | |
# 二分探索で LCA を求める | |
if u == v: | |
return u | |
# for k in range(self.K - 1, -1, -1): | |
# if parent[k][u] != parent[k][v]: | |
# u = parent[k][u] | |
# v = parent[k][u] | |
for parent_k in reversed(parent): | |
parent_k_u = parent_k[u] | |
parent_k_v = parent_k[v] | |
if parent_k_u != parent_k_v: | |
u = parent_k_u | |
v = parent_k_v | |
return parent[0][u] | |
def get_dist(self, u, v): | |
"""u と v の距離を求める""" | |
return self.dist[u] + self.dist[v] - 2 * self.dist[self.get_lca(u, v)] | |
def main(): | |
S = b"""7 | |
0 1 | |
0 2 | |
0 3 | |
1 4 | |
1 5 | |
5 6 | |
""" | |
in_ = BytesIO(S) | |
N = int(next(in_)) | |
AB = tuple(tuple(map(int, line.split())) for line in in_) | |
G = [[] for _ in range(N)] | |
for A, B in AB: | |
G[A].append(B) | |
G[B].append(A) | |
lca_solver = LCA(G, 0) | |
print(lca_solver.get_lca(0, 1), lca_solver.get_dist(0, 1)) | |
print(lca_solver.get_lca(1, 2), lca_solver.get_dist(1, 2)) | |
print(lca_solver.get_lca(5, 6), lca_solver.get_dist(5, 6)) | |
print(lca_solver.get_lca(1, 6), lca_solver.get_dist(1, 6)) | |
print(lca_solver.get_lca(4, 6), lca_solver.get_dist(4, 6)) | |
print(lca_solver.get_lca(3, 6), lca_solver.get_dist(3, 6)) | |
if __name__ == '__main__': | |
sys.setrecursionlimit(1_000_000) | |
main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment