Skip to content

Instantly share code, notes, and snippets.

@fgshun
Last active August 14, 2023 15:31
Show Gist options
  • Save fgshun/615ff736c2d11341316a4ca0ca7bb0ce to your computer and use it in GitHub Desktop.
Save fgshun/615ff736c2d11341316a4ca0ca7bb0ce to your computer and use it in GitHub Desktop.
LCA 最近共通祖先
import sys
from io import BytesIO
class LCA:
"""LCA 最近共通祖先
参考
アルゴリズムロジック - ダブリングによる木の最近共通祖先(LCA:Lowest Common Ancestor)を求めるアルゴリズム
https://algo-logic.info/lca/
"""
def __init__(self, tree, root=0):
self.tree = tree
self.root = root
V = len(tree)
K = 1
while (1 << K) < V:
K += 1
# ありうる最大の深さ V - 1 を表現するに足るビット数
self.K = K
# parent[k][u]: u の 2**k 先の祖先
self.parent = parent = [[-1] * V for _ in range(K)]
# root からの距離
self.dist = [-1] * V
# dist と parent[0] を求める
self._dfs(root, -1, 0)
# 2**k 先の祖先を求める
for k in range(K - 1):
for v in range(V):
if parent[k][v] < 0:
# parent[k + 1][v] = -1
pass
else:
parent[k + 1][v] = parent[k][parent[k][v]]
def _dfs(self, cur, parent, distance):
""" root からの距離 dist と 各頂点の直接の祖先 parent[0] を求める"""
self.parent[0][cur] = parent
self.dist[cur] = distance
for e in self.tree[cur]:
if e != parent: # 逆流防止
self._dfs(e, cur, distance + 1)
def get_lca(self, u, v):
"""u と v の LCA 最近共通祖先を求める"""
dist = self.dist
parent = self.parent
if dist[u] < dist[v]:
u, v = v, u # u の方が深くなるよう取り換えておく
# LCA までの距離を同じにする
dist_v = dist[v]
for k, parent_k in enumerate(parent):
if dist[u] - dist_v >> k & 1:
u = parent_k[u] # parent[k][u]
# 二分探索で LCA を求める
if u == v:
return u
# for k in range(self.K - 1, -1, -1):
# if parent[k][u] != parent[k][v]:
# u = parent[k][u]
# v = parent[k][u]
for parent_k in reversed(parent):
parent_k_u = parent_k[u]
parent_k_v = parent_k[v]
if parent_k_u != parent_k_v:
u = parent_k_u
v = parent_k_v
return parent[0][u]
def get_dist(self, u, v):
"""u と v の距離を求める"""
return self.dist[u] + self.dist[v] - 2 * self.dist[self.get_lca(u, v)]
def main():
S = b"""7
0 1
0 2
0 3
1 4
1 5
5 6
"""
in_ = BytesIO(S)
N = int(next(in_))
AB = tuple(tuple(map(int, line.split())) for line in in_)
G = [[] for _ in range(N)]
for A, B in AB:
G[A].append(B)
G[B].append(A)
lca_solver = LCA(G, 0)
print(lca_solver.get_lca(0, 1), lca_solver.get_dist(0, 1))
print(lca_solver.get_lca(1, 2), lca_solver.get_dist(1, 2))
print(lca_solver.get_lca(5, 6), lca_solver.get_dist(5, 6))
print(lca_solver.get_lca(1, 6), lca_solver.get_dist(1, 6))
print(lca_solver.get_lca(4, 6), lca_solver.get_dist(4, 6))
print(lca_solver.get_lca(3, 6), lca_solver.get_dist(3, 6))
if __name__ == '__main__':
sys.setrecursionlimit(1_000_000)
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment