Skip to content

Instantly share code, notes, and snippets.

@gidgid
Created February 5, 2021 16:32
Show Gist options
  • Save gidgid/f2512972f708ea002b2aa0daa54f0399 to your computer and use it in GitHub Desktop.
Save gidgid/f2512972f708ea002b2aa0daa54f0399 to your computer and use it in GitHub Desktop.
Shows how to use `one` from more_itertools to extract the single root of a directed graph.
import pytest
from more_itertools import one
import networkx as nx
class GraphGenerationError(Exception):
    """Raised when a graph does not have exactly one root node."""
def find_root(graph) -> str:
    """Return the single root of ``graph``: its only node with no incoming edges.

    Args:
        graph: A directed graph exposing ``nodes`` and ``in_degree(node)``
            (the tests in this file pass a ``networkx.DiGraph``).

    Returns:
        The one node whose in-degree is zero.

    Raises:
        GraphGenerationError: If no node has in-degree zero, or if more than
            one node does.
    """
    # Every node without incoming edges is a candidate root.
    possible_roots = [node for node in graph.nodes if graph.in_degree(node) == 0]
    # ``one`` returns the sole element, or raises the supplied exception when
    # the list is empty (too_short) or has several items (too_long).
    return one(
        possible_roots,
        too_short=GraphGenerationError("No root found"),
        # f-string so the message reports the actual number of roots found.
        too_long=GraphGenerationError(
            f"Found {len(possible_roots)} roots. Graph is corrupted"
        ),
    )
def test_finds_a_single_root():
    """A graph in which only ``node1`` has no incoming edges yields ``node1``."""
    graph = nx.DiGraph()
    graph.add_edges_from(
        [
            ("node1", "node2"),
            ("node2", "node3"),
            ("node2", "node4"),
            ("node3", "node4"),
        ]
    )

    assert find_root(graph) == "node1"
def test_raises_an_error_when_no_roots_found():
    """A two-node cycle has no node with in-degree zero, so no root exists."""
    graph = nx.DiGraph()
    edges = [
        ("node1", "node2"),
        ("node2", "node1"),
    ]
    graph.add_edges_from(edges)

    # Match on the message so the test fails if the "too many roots" branch
    # fires instead; both branches raise the same exception type.
    with pytest.raises(GraphGenerationError, match="No root found"):
        find_root(graph)
def test_raises_an_error_when_graph_has_more_than_one_root():
    """Three disconnected edges give three nodes with in-degree zero."""
    graph = nx.DiGraph()
    edges = [
        ("node1", "node2"),
        ("node3", "node4"),
        ("node5", "node6"),
    ]
    graph.add_edges_from(edges)

    # Match on the message so the test fails if the "no root" branch fires
    # instead; both branches raise the same exception type.
    with pytest.raises(GraphGenerationError, match="Graph is corrupted"):
        find_root(graph)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment