# GitHub gist by @drewbanin, created June 6, 2022:
# https://gist.github.com/drewbanin/60f3b6a5d19d8af694dc6830c8c986b3
import networkx as nx
def ancestors(graph, node, max_depth):
    """Return the nodes that can reach ``node`` in at most ``max_depth`` edges.

    Runs the same breadth-first shortest-path search as ``descendants``, but
    over a reversed *view* of ``graph`` so that predecessors are followed
    instead of successors. The caller's graph is never modified.

    Args:
        graph: A networkx graph. An undirected graph is searched as-is, since
            its edges have no direction to reverse.
        node: The node whose ancestors are wanted.
        max_depth: Maximum number of edges to traverse, or ``None`` for no
            limit.

    Returns:
        A ``set`` of ancestor nodes, excluding ``node`` itself.
    """
    # DiGraph.reverse(copy=False) returns a lightweight view with every edge
    # flipped and leaves ``graph`` untouched. The old nx.utils.reversed()
    # context manager swapped the graph's internal adjacency dicts in place.
    # It was deprecated in networkx 2.6 and removed in 3.0. Undirected graphs
    # have no reverse() method; they are their own reverse, which is also how
    # the old helper treated them.
    search_graph = graph.reverse(copy=False) if graph.is_directed() else graph
    lengths = nx.single_source_shortest_path_length(search_graph, node, cutoff=max_depth)
    # The source is always reported at distance 0; it is not its own ancestor.
    return set(lengths) - {node}
def descendants(graph, node, max_depth):
    """Collect every node reachable from ``node`` within ``max_depth`` edges.

    Args:
        graph: A networkx graph; outgoing edges are followed.
        node: The node to start the search from.
        max_depth: Maximum number of edges to traverse, or ``None`` for an
            unbounded search.

    Returns:
        A ``set`` of reachable nodes, never including ``node`` itself.
    """
    # The shortest-path mapping is keyed by every node reached (including the
    # start node at distance 0); only the keys matter here.
    reached = set(nx.single_source_shortest_path_length(graph, node, cutoff=max_depth))
    reached.discard(node)
    return reached
# Fixture graph used by the checks below (edges point left-to-right):
#
#   1 --> 2 ---> 3 --+
#         |          +---> 5
#         +----> 4 --+
G = nx.DiGraph()
G.add_edges_from([(1, 2), (2, 3), (2, 4), (3, 5), (4, 5)])

# Expected results with no depth limit (max_depth=None).
anc = {
    1: set(),
    2: {1},
    3: {1, 2},
    4: {1, 2},
    5: {1, 2, 3, 4},
}
desc = {
    1: {2, 3, 4, 5},
    2: {3, 4, 5},
    3: {5},
    4: {5},
    5: set(),
}

# Expected results when the search stops after ``max_depth`` edges. These
# exercise the cutoff itself, which the unbounded tables above never reach.
# Each entry is (node, max_depth, expected ancestors, expected descendants).
depth_limited = [
    (5, 0, set(), set()),
    (5, 1, {3, 4}, set()),
    (5, 2, {2, 3, 4}, set()),
    (2, 1, {1}, {3, 4}),
    (1, 1, set(), {2}),
    (1, 2, set(), {2, 3, 4}),
]


def _check(label, got, expected):
    """Announce a check and fail loudly if ``got`` differs from ``expected``.

    Raises AssertionError explicitly rather than using an ``assert``
    statement. Running under ``python -O`` strips ``assert`` statements,
    which would silently turn every check into a no-op.
    """
    print(f"Checking {label}")
    if got != expected:
        raise AssertionError(f"Got {got} expected {expected}")


for node_id, expected in anc.items():
    _check(f"ancestors of node {node_id}", ancestors(G, node_id, None), expected)
for node_id, expected in desc.items():
    _check(f"descendants of node {node_id}", descendants(G, node_id, None), expected)
for node_id, depth, expected_anc, expected_desc in depth_limited:
    _check(
        f"ancestors of node {node_id} (max_depth={depth})",
        ancestors(G, node_id, depth),
        expected_anc,
    )
    _check(
        f"descendants of node {node_id} (max_depth={depth})",
        descendants(G, node_id, depth),
        expected_desc,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment