Skip to content

Instantly share code, notes, and snippets.

@spranesh
Created May 17, 2011 08:17
Show Gist options
  • Save spranesh/976142 to your computer and use it in GitHub Desktop.
Save spranesh/976142 to your computer and use it in GitHub Desktop.
Directed Graph class
""" Implementation of a simple directed graph with no weights.
Test by either running python on this file,
or by calling nosetests on this file.
"""
import unittest
import collections
class DirectedGraph:
    """Simple unweighted directed graph stored as adjacency lists.

    ``self.graph`` maps every known node to the list of nodes it has an
    outgoing edge to, in insertion order. Parallel edges are not allowed.
    A node becomes known as soon as it is used as either endpoint of an
    edge passed to ``AddEdge``.

    Preconditions are checked with ``assert`` and therefore raise
    ``AssertionError`` (and are skipped when Python runs with ``-O``).
    """

    def __init__(self):
        # node -> list of direct successors.
        self.graph = collections.defaultdict(list)

    def __repr__(self):
        return repr(self.graph)

    def __str__(self):
        return str(self.graph)

    def AddEdge(self, a, b):
        """Add the directed edge a -> b.

        Both ``a`` and ``b`` are registered as nodes of the graph.
        Raises AssertionError if the edge already exists.
        """
        successors = self.graph[a]
        assert b not in successors, "edge already present"  # No multigraphs
        successors.append(b)
        # Register the destination as a node too, so that sink nodes can be
        # used as a DFS start and show up in the reachability graph without
        # any traversal having to insert them on the fly.
        self.graph.setdefault(b, [])

    def HasEdge(self, a, b):
        """Return True if the edge a -> b exists.

        Uses ``.get`` so that a query never adds ``a`` to the graph.
        """
        return b in self.graph.get(a, ())

    def RemoveEdge(self, a, b):
        """Remove the edge a -> b; an O(neighbours(a)) operation.

        Raises AssertionError if the edge does not exist. Both endpoints
        remain nodes of the graph.
        """
        assert self.HasEdge(a, b), "edge not present"
        self.graph[a].remove(b)
        assert not self.HasEdge(a, b)

    def DFS(self, start):
        """Yield every node reachable from ``start``, ``start`` first.

        This is a generator driven by an explicit stack. Nodes are marked
        as visited when they are pushed, so each node is yielded exactly
        once; the order after ``start`` is stack order and should not be
        relied upon. The precondition is checked when iteration begins.

        Raises AssertionError if ``start`` is not a node of the graph.
        """
        # `in` works on both Python 2 and 3; dict.has_key is Python 2 only.
        assert start in self.graph, "start is not a node of the graph"
        visited = {start}
        stack = [start]
        while stack:
            current_node = stack.pop()
            # .get keeps the traversal read-only: indexing the defaultdict
            # would insert missing nodes while callers may be iterating it.
            for n in self.graph.get(current_node, ()):
                if n not in visited:
                    visited.add(n)
                    stack.append(n)
            yield current_node

    def GetReachabilityGraph(self):
        """Return a dict mapping each node to the nodes reachable from it.

        Runs one traversal per node. A node is never listed in its own
        entry, even when it lies on a cycle.
        """
        reachability = {}
        # Iterate over a snapshot of the keys so the result does not depend
        # on the dict staying unmodified during the traversals.
        for node in list(self.graph):
            # The first yielded element is the node itself; drop it.
            reachability[node] = list(self.DFS(node))[1:]
        return reachability
class TestDirectedGraph(unittest.TestCase):
    """Unit tests for DirectedGraph."""

    def setUp(self):
        """Build the graph ``self.g`` shared by most tests.

        Runs before each test, so every test starts from a fresh graph
        and is independent of the others.
        """
        self.g = DirectedGraph()
        self.g.AddEdge(1, 2)
        self.g.AddEdge(1, 3)
        self.g.AddEdge(2, 1)
        self.g.AddEdge(4, 5)

    def testHasEdge(self):
        self.assertTrue(self.g.HasEdge(1, 2))
        self.assertTrue(self.g.HasEdge(1, 3))
        self.assertTrue(self.g.HasEdge(2, 1))
        self.assertTrue(self.g.HasEdge(4, 5))
        self.assertFalse(self.g.HasEdge(5, 6))

    def testAddEdge(self):
        self.assertFalse(self.g.HasEdge(5, 6))
        self.g.AddEdge(5, 6)
        self.assertTrue(self.g.HasEdge(5, 6))

    def testRemoveEdge(self):
        self.assertTrue(self.g.HasEdge(1, 2))
        self.g.RemoveEdge(1, 2)
        self.assertFalse(self.g.HasEdge(1, 2))

    def testDFS(self):
        # The traversal order is unspecified, so compare sorted results.
        # Should return [1, 2, 3] in some order.
        self.assertEqual(sorted(self.g.DFS(1)), [1, 2, 3])

        # Should also return [1, 2, 3] in some order.
        self.assertEqual(sorted(self.g.DFS(2)), [1, 2, 3])

        # Should return [3]: node 3 has no outgoing edges.
        self.assertEqual(list(self.g.DFS(3)), [3])

        # Should return [4, 5] in some order.
        self.assertEqual(sorted(self.g.DFS(4)), [4, 5])

        # Should return [5]: node 5 has no outgoing edges.
        self.assertEqual(list(self.g.DFS(5)), [5])

    def testGetReachabilityGraph(self):
        # Try on a simpler graph.
        g = DirectedGraph()
        g.AddEdge('f', 'g')
        g.AddEdge('g', 'f')
        g.AddEdge('h', 'f')

        reachability_graph = g.GetReachabilityGraph()

        # These two expectations must be asserted, not assigned: assigning
        # would overwrite the computed result and check nothing.
        self.assertEqual(reachability_graph['f'], ['g'])
        self.assertEqual(reachability_graph['g'], ['f'])

        # 'h' reaches both 'f' and 'g', in an unspecified order.
        self.assertEqual(sorted(reachability_graph['h']), ['f', 'g'])
# Run the unit tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment