Skip to content

Instantly share code, notes, and snippets.

@jtribble
Created August 1, 2020 19:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jtribble/c9512980c5ac827c8108099ba771c88e to your computer and use it in GitHub Desktop.
Save jtribble/c9512980c5ac827c8108099ba771c88e to your computer and use it in GitHub Desktop.
A simple graph representation in Python along with some common methods, like breadth-first traversal, depth-first traversal, and shortest path.
#!/usr/bin/env python3
from collections import deque
from dataclasses import dataclass, field
from typing import Set, Optional, List, Generator, Dict, Deque
from unittest import TestCase, main
@dataclass
class GraphNode:
    """A labelled vertex that stores its outgoing edges as a set of other nodes.

    A node's identity is its ``label``: both hashing and equality look at the
    label only. ``neighbors`` and ``color`` are mutable state and are excluded
    from comparison, which keeps ``__eq__`` consistent with ``__hash__`` and
    prevents equality checks from recursing through cyclic neighbor sets.
    """

    label: str
    # Nodes this node has an edge to. Excluded from equality: comparing the sets
    # would compare their members, which loops forever on mutually linked nodes.
    neighbors: Set['GraphNode'] = field(default_factory=set, compare=False)
    # Color currently assigned to the node; None until one is assigned.
    color: Optional[str] = field(default=None, compare=False)

    def __hash__(self) -> int:
        """Hash by label only, so the node stays usable in sets while its edges and color change."""
        return hash(self.label)

    def __repr__(self) -> str:
        """Describe the node, listing neighbors by label only so cycles cannot cause infinite recursion."""
        neighbor_labels = ", ".join(n.label for n in self.neighbors)
        return f'GraphNode(label={self.label}, color={self.color}, neighbors=[{neighbor_labels}])'
@dataclass
class Graph:
    """A set of nodes plus traversal, greedy-coloring, and shortest-path helpers.

    Edges are stored on the nodes themselves (each node's ``neighbors`` set);
    ``nodes`` only records which nodes belong to the graph.
    """

    nodes: Set['GraphNode'] = field(default_factory=set)

    def breadth_first_traversal(self, first_node: Optional['GraphNode'] = None) -> Generator['GraphNode', None, None]:
        """Yield every node reachable from ``first_node`` in breadth-first order.

        Args:
            first_node: Node to start from. When ``None``, an arbitrary member
                of ``self.nodes`` is used.

        Yields:
            Each reachable node exactly once. Nothing is yielded when no start
            node is given and the graph is empty.
        """
        if first_node is None:
            # Guard the empty graph: a bare next() would raise StopIteration,
            # which a generator turns into RuntimeError (PEP 479).
            if not self.nodes:
                return
            first_node = next(iter(self.nodes))
        visited = {first_node}
        queue = deque([first_node])
        while queue:
            node = queue.popleft()
            yield node
            for neighbor in node.neighbors:
                if neighbor not in visited:
                    # Mark on enqueue so a node is never queued twice.
                    visited.add(neighbor)
                    queue.append(neighbor)

    def depth_first_traversal(self, first_node: Optional['GraphNode'] = None) -> Generator['GraphNode', None, None]:
        """Yield every node reachable from ``first_node`` using an explicit stack.

        Args:
            first_node: Node to start from. When ``None``, an arbitrary member
                of ``self.nodes`` is used.

        Yields:
            Each reachable node exactly once. Nothing is yielded when no start
            node is given and the graph is empty.
        """
        if first_node is None:
            # Same empty-graph guard as the breadth-first traversal.
            if not self.nodes:
                return
            first_node = next(iter(self.nodes))
        visited = {first_node}
        stack = [first_node]
        while stack:
            node = stack.pop()
            yield node
            for neighbor in node.neighbors:
                if neighbor not in visited:
                    # Mark on push so a node is never stacked twice.
                    visited.add(neighbor)
                    stack.append(neighbor)

    def apply_colors(self, colors: Set[str]) -> None:
        """Greedily assign each node a color that none of its neighbors currently has.

        Nodes are processed in ``self.nodes`` iteration order and existing
        colors are not cleared first, so colors already present on neighbors
        constrain the choice.

        Args:
            colors: Palette to choose from.

        Raises:
            ValueError: If, for some node, every color in ``colors`` is already
                used by one of its neighbors. Nodes colored before the failure
                keep their new colors.
        """
        for node in self.nodes:
            illegal_colors = {n.color for n in node.neighbors if n.color is not None}
            for color in colors:
                if color not in illegal_colors:
                    node.color = color
                    break
            else:
                # The inner loop never hit `break`: no color was free.
                raise ValueError(f"Can't apply legal color to {node}")

    @staticmethod
    def shortest_path(source: 'GraphNode', destination: 'GraphNode') -> List['GraphNode']:
        """Return a path with the fewest edges from ``source`` to ``destination``.

        Runs a breadth-first search that records, for every discovered node, the
        node it was first reached from, then walks those links back from the
        destination. Each node is expanded at most once.

        Args:
            source: Node the path starts at.
            destination: Node the path ends at.

        Returns:
            The nodes along the path, starting with ``source`` and ending with
            ``destination``; ``[source]`` when the two are equal.

        Raises:
            ValueError: If ``destination`` cannot be reached from ``source``.
        """
        if source == destination:
            return [source]
        # Predecessor map; membership in it doubles as the "visited" check.
        came_from: Dict['GraphNode', Optional['GraphNode']] = {source: None}
        queue = deque([source])
        while queue:
            node = queue.popleft()
            for neighbor in node.neighbors:
                if neighbor in came_from:
                    continue
                came_from[neighbor] = node
                if neighbor == destination:
                    # Walk the predecessor links back to the source, then flip.
                    path = [destination]
                    step = node
                    while step is not None:
                        path.append(step)
                        step = came_from[step]
                    path.reverse()
                    return path
                queue.append(neighbor)
        raise ValueError(f'Could not find path from {source} to {destination}')
# Mostly smoke tests: each method prints the state it observed for you to inspect,
# and asserts only the basic invariants that must hold for this fixture.
class UnitTests(TestCase):
    """Exercise Graph on a fixed 12-node graph."""

    def setUp(self) -> None:
        """Build a fresh graph before every test so color state never carries over."""
        # Node number -> numbers of the nodes it lists as neighbors.
        # NOTE(review): some edges are listed in one direction only
        # (e.g. 4 lists 2, but 2 does not list 4) — confirm this is intended.
        adjacency = {
            1: (2, 3, 5),
            2: (1, 5, 7),
            3: (1, 5, 6),
            4: (2, 6, 7),
            5: (1, 3, 10),
            6: (3, 4, 5),
            7: (2, 4, 11),
            8: (6, 9, 11),
            9: (8, 10, 12),
            10: (5, 9, 12),
            11: (7, 8, 12),
            12: (9, 10, 11),
        }
        nodes = {n: GraphNode(str(n)) for n in adjacency}  # type: Dict[int, GraphNode]
        for n, neighbor_ids in adjacency.items():
            nodes[n].neighbors.update(nodes[m] for m in neighbor_ids)
        self.graph = Graph(set(nodes.values()))
        self.nodes = nodes

    @staticmethod
    def _print_numbered(nodes) -> None:
        """Print each node on its own line, prefixed with a right-aligned 1-based index."""
        for i, node in enumerate(nodes):
            print(f'{str(i + 1).rjust(2)}. {node}')

    def test_breadth_first_traversal(self) -> None:
        """Every node is reachable from node 1, so the traversal must visit all of them once."""
        print('Breadth-first traversal:')
        visited = list(self.graph.breadth_first_traversal(self.nodes[1]))
        self._print_numbered(visited)
        self.assertIs(visited[0], self.nodes[1])
        self.assertEqual(len(visited), len(self.nodes))
        self.assertEqual(set(visited), self.graph.nodes)

    def test_depth_first_traversal(self) -> None:
        """Every node is reachable from node 1, so the traversal must visit all of them once."""
        print('Depth-first traversal:')
        visited = list(self.graph.depth_first_traversal(self.nodes[1]))
        self._print_numbered(visited)
        self.assertIs(visited[0], self.nodes[1])
        self.assertEqual(len(visited), len(self.nodes))
        self.assertEqual(set(visited), self.graph.nodes)

    def test_legal_graph_coloring(self) -> None:
        """Four colors always suffice here because every node lists only three neighbors."""
        colors = {'red', 'blue', 'green', 'yellow'}
        print(f'Coloring graph with {len(colors)} colors: {colors}')
        self.graph.apply_colors(colors)
        print('Breadth-first traversal:')
        self._print_numbered(self.graph.breadth_first_traversal(self.nodes[1]))
        for node in self.graph.nodes:
            self.assertIn(node.color, colors)

    def test_illegal_graph_coloring(self) -> None:
        """Two colors cannot work: nodes 1, 3 and 5 all list each other as neighbors."""
        colors = {'blue', 'black'}
        print(f'Coloring graph with {len(colors)} colors: {colors}')
        with self.assertRaises(ValueError):
            self.graph.apply_colors(colors)

    def test_shortest_path(self) -> None:
        """Check the unique 5 -> 3 -> 6 route and the length of a 2 => 12 route."""
        print('Shortest path from 5 => 6')
        path = self.graph.shortest_path(self.nodes[5], self.nodes[6])
        for i, node in enumerate(path):
            print(f'{i}. {node}')
        self.assertEqual([n.label for n in path], ['5', '3', '6'])

        print('\nShortest path from 2 => 12')
        path = self.graph.shortest_path(self.nodes[2], self.nodes[12])
        for i, node in enumerate(path):
            print(f'{i}. {node}')
        # Several 3-edge routes exist (e.g. via 7, 11 or via 5, 10), so pin only
        # the endpoints and the length.
        self.assertEqual(len(path), 4)
        self.assertIs(path[0], self.nodes[2])
        self.assertIs(path[-1], self.nodes[12])
# Run the unittest runner only when this file is executed directly, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment