# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    """
    This code reuses the recursive function from Lowest Common Ancestor of a Binary Tree (#236).

    Step 1: Find the LCA of p and q, since it is guaranteed to be on the path between them.
    Since we're unable to move "up" in a binary tree, we can simply move down in both
    directions from the LCA to form the path.
    Step 2: Return the distance from the LCA to p plus the distance from the LCA to q.

    Time complexity is O(n), since in the worst case we have to look at every node.
    Space complexity is O(h) for the call stack, where h is the height of the tree:
    O(log n) for a balanced tree, but O(n) for a skewed one.

    Also, this question shouldn't be a medium considering that #236 is a medium itself lol.
    """

    def findDistance(self, root: TreeNode, p: int, q: int) -> int:
        def _lca(r, p, q):
            """Find the lowest common ancestor of the nodes with values p and q."""
            if not r:
                return None
            if r.val == p or r.val == q:
                return r
            left = _lca(r.left, p, q)
            right = _lca(r.right, p, q)
            if left and right:
                # p and q were found in different subtrees, so r is the LCA
                return r
            # Otherwise pass up whichever side found something (or None)
            return left or right

        def dfs(r, target):
            # Search for target in both subtrees of r to calculate the distance.
            # A subtree without target yields -inf, so max() picks the branch that has it.
            if not r:
                return float("-inf")
            if r.val == target:
                return 0
            return 1 + max(dfs(r.left, target), dfs(r.right, target))

        lca = _lca(root, p, q)
        d1, d2 = dfs(lca, p), dfs(lca, q)
        return d1 + d2
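To check the two steps (LCA, then depth from the LCA to each target) on a concrete tree, here is a minimal self-contained sketch of the same approach. It assumes unique node values and that both p and q are present; the function name `find_distance` and the sample tree are illustrative choices, not part of the original solution.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_distance(root, p, q):
    """Hypothetical standalone version: number of edges between values p and q."""

    def lca(node):
        # Return the LCA of p and q within this subtree, or whichever target was found.
        if node is None:
            return None
        if node.val == p or node.val == q:
            return node
        left, right = lca(node.left), lca(node.right)
        if left and right:
            return node
        return left or right

    def depth(node, target):
        # Edges from node down to target; -inf marks a subtree without target.
        if node is None:
            return float("-inf")
        if node.val == target:
            return 0
        return 1 + max(depth(node.left, target), depth(node.right, target))

    anchor = lca(root)
    return depth(anchor, p) + depth(anchor, q)


# Sample tree (assumed for illustration):
#         3
#       /   \
#      5     1
#     / \   / \
#    6   2 0   8
#       / \
#      7   4
root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, TreeNode(0), TreeNode(8)))

print(find_distance(root, 5, 0))  # 3 (5 -> 3 -> 1 -> 0)
print(find_distance(root, 5, 7))  # 2 (LCA is 5 itself)
```

When one target is an ancestor of the other (5 and 7 here), the LCA is that target, so one of the two depths is 0.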
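The docstring notes that we can't move "up" in a binary tree. A different technique (swapped in here for comparison, not the author's method) removes that restriction by first recording each node's parent, then running a BFS from p over the tree as an undirected graph until q is reached. This minimal sketch assumes unique values and that both targets exist; `find_distance_bfs` is a hypothetical name. Because it uses an explicit stack and a queue instead of recursion, a very deep, skewed tree won't hit CPython's default recursion limit (1000).

```python
from collections import deque
from typing import Dict, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_distance_bfs(root: TreeNode, p: int, q: int) -> int:
    """Hypothetical alternative: parent map + BFS from p to q."""
    # Pass 1: iterative DFS that records parents and locates the start node p.
    parent: Dict[TreeNode, Optional[TreeNode]] = {root: None}
    start = None
    stack = [root]
    while stack:
        node = stack.pop()
        if node.val == p:
            start = node
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node
                stack.append(child)
    if start is None:
        return -1  # p not in the tree (excluded by the stated assumption)

    # Pass 2: BFS where each node's neighbours are its children and its parent.
    queue = deque([(start, 0)])
    seen = {start}
    while queue:
        node, dist = queue.popleft()
        if node.val == q:
            return dist
        for nxt in (node.left, node.right, parent[node]):
            if nxt is not None and nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, dist + 1))
    return -1  # q not reachable (excluded by the stated assumption)


root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, TreeNode(0), TreeNode(8)))
print(find_distance_bfs(root, 7, 8))  # 5 (7 -> 2 -> 5 -> 3 -> 1 -> 8)
```

This variant still takes O(n) time, but it also uses O(n) extra space for the parent map, whereas the recursive LCA approach only needs O(h) stack space.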