Skip to content

Instantly share code, notes, and snippets.

@KJTsanaktsidis
Created May 11, 2013 04:40
Show Gist options
  • Save KJTsanaktsidis/5558922 to your computer and use it in GitHub Desktop.
Save KJTsanaktsidis/5558922 to your computer and use it in GitHub Desktop.
A graph that can run Dijkstra's algorithm, plus a little extra
import csv
import functools
import sys
import heapq
class AdjacencyGraph():
    """
    A directed, weighted graph stored as an adjacency list.

    adjacency_list maps each node name to a list of (destname, cost) tuples.
    The search methods assume link costs are non-negative.
    """

    @functools.total_ordering
    class SearchNode():
        """
        Per-vertex bookkeeping for single_min_cost_search.

        Holds the best known distance from the source (dist), the predecessor
        vertex name on that route (path) and whether dist is final (known).
        Instances order by dist.
        """

        def __init__(self, name, path=None, dist=sys.maxsize, known=False):
            self.name = name
            self.dist = dist
            self.path = path
            self.known = known

        def __eq__(self, other):
            if not isinstance(other, type(self)):
                return NotImplemented
            return self.dist == other.dist

        def __lt__(self, other):
            if not isinstance(other, type(self)):
                return NotImplemented
            return self.dist < other.dist

    @functools.total_ordering
    class MultiSearchNode():
        """
        Per-vertex bookkeeping for multi_min_cost_search.

        dist maps a hop count (number of links) to the cheapest known cost of
        reaching this vertex with exactly that many links. path maps the same
        hop count to the predecessor vertex name on that route.
        Instances order by their cheapest cost over all hop counts. A vertex
        with no entries sorts after every vertex that has one.
        """

        def __init__(self, name, known=False):
            self.name = name
            self.dist = {}
            self.path = {}
            # kept so existing callers can still set/read it
            self.known = known

        def _best_dist(self):
            # None stands for "unreached", i.e. an infinite distance
            return min(self.dist.values()) if self.dist else None

        def __eq__(self, other):
            if not isinstance(other, type(self)):
                return NotImplemented
            return self._best_dist() == other._best_dist()

        def __lt__(self, other):
            if not isinstance(other, type(self)):
                return NotImplemented
            mine, theirs = self._best_dist(), other._best_dist()
            if mine is None:
                return False
            if theirs is None:
                return True
            return mine < theirs

    def __init__(self):
        self.adjacency_list = dict()

    def _require_nodes(self, *names):
        """Raise a KeyError naming the first of names that is not in the graph."""
        for name in names:
            if name not in self.adjacency_list:
                raise KeyError('{} not in the graph'.format(name))

    def insert_node(self, name):
        """
        Add a node called name to the adjacency list.
        If name is already present, raise a ValueError.
        """
        if name in self.adjacency_list:
            raise ValueError('{} is already present in the graph'.format(name))
        self.adjacency_list[name] = []

    def insert_link(self, srcname, destname, cost):
        """
        Add a directed link of the given cost from srcname to destname.
        If either node is not in the graph, raise a KeyError.
        """
        self._require_nodes(srcname, destname)
        self.adjacency_list[srcname].append((destname, cost))

    def single_min_cost_search(self, srcname, destname):
        """
        Find a min cost path from srcname to destname with Dijkstra's algorithm.

        Returns (list of node names from srcname to destname, total cost).
        srcname == destname gives ([srcname], 0).
        Raise a KeyError if either node is not in the graph, or if destname
        cannot be reached from srcname.
        """
        # without this check a missing node would look like an unconnected one
        self._require_nodes(srcname, destname)
        all_verts = {v: self.SearchNode(v) for v in self.adjacency_list}
        all_verts[srcname].dist = 0
        # Heap of (tentative dist, push number, name). The push number breaks
        # ties so names never have to be comparable. Instead of re-heapifying
        # after every relaxation, an improved vertex is pushed again and its
        # stale, costlier entries are skipped when they are popped.
        frontier = [(0, 0, srcname)]
        pushed = 1
        while frontier:
            _, _, name = heapq.heappop(frontier)
            cur_vert = all_verts[name]
            if cur_vert.known:
                continue  # stale entry: this vertex was already finalised
            cur_vert.known = True
            if name == destname:
                break  # dest's distance is final as soon as it is popped
            for vname, cost in self.adjacency_list[name]:
                v = all_verts[vname]
                if v.known:
                    continue
                if cur_vert.dist + cost < v.dist:
                    v.dist = cur_vert.dist + cost
                    v.path = name
                    heapq.heappush(frontier, (v.dist, pushed, vname))
                    pushed += 1
        dest = all_verts[destname]
        if not dest.known:
            raise KeyError('{} cannot be reached from {}'.format(destname, srcname))
        # follow predecessor links back to the source, then flip the order
        name_list = [destname]
        while name_list[-1] != srcname:
            name_list.append(all_verts[name_list[-1]].path)
        name_list.reverse()
        return name_list, dest.dist

    def multi_min_cost_search(self, srcname, destname):
        """
        Search for min cost paths from srcname to destname, one per link count.

        Returns a dict whose keys are numbers of links (hops) and whose values
        are (list of node names from srcname to destname, total cost).
        A hop count is included only if its cheapest route is strictly
        cheaper than every route with fewer links, so costs strictly decrease
        as the hop count grows.
        srcname == destname gives {0: ([srcname], 0)}. An unreachable
        destname gives {}.
        Raise a KeyError if either node is not in the graph.
        Costs are assumed to be non-negative.
        """
        # without this check a missing node would look like an unconnected one
        self._require_nodes(srcname, destname)
        all_verts = {v: self.MultiSearchNode(v) for v in self.adjacency_list}
        all_verts[srcname].dist[0] = 0
        all_verts[srcname].path[0] = None
        # cheapest cost per vertex using fewer links than the current layer
        best = {srcname: 0}
        # vertices that received an entry for (hops - 1) links
        layer = [srcname]
        # Layered relaxation: layer k extends layer k - 1's entries by one
        # link. An entry is kept only if it beats everything the vertex already
        # has with fewer links; a dominated entry can never extend into a
        # non-dominated route. A simple path has at most len(nodes) - 1 links,
        # which bounds the number of layers.
        for hops in range(1, len(self.adjacency_list)):
            improved = {}  # insertion-ordered set of vertex names
            for uname in layer:
                base = all_verts[uname].dist[hops - 1]
                for vname, cost in self.adjacency_list[uname]:
                    candidate = base + cost
                    v = all_verts[vname]
                    if vname in best and candidate >= best[vname]:
                        continue
                    if hops in v.dist and candidate >= v.dist[hops]:
                        continue
                    v.dist[hops] = candidate
                    v.path[hops] = uname
                    improved[vname] = None
            if not improved:
                break
            for vname in improved:
                best[vname] = all_verts[vname].dist[hops]
            layer = list(improved)
        rlist = {}
        for hops, total_cost in all_verts[destname].dist.items():
            # walk predecessor links back, one hop count at a time
            name_list = [destname]
            for i in range(hops, 0, -1):
                name_list.append(all_verts[name_list[-1]].path[i])
            name_list.reverse()
            rlist[hops] = (name_list, total_cost)
        return rlist
def graph_from_csv(data_source):
    """
    Generates an AdjacencyGraph from a csv stream.

    Each row is src,dest,cost. Every row adds a link in both directions with
    the same cost. Whitespace around node names is stripped, and cost is
    parsed with int().
    Rows that are empty or contain only whitespace (e.g. a trailing blank
    line) are skipped.
    data_source can be anything csv.reader accepts, such as an open file or a
    list of strings.
    """
    reader = csv.reader(data_source)
    graph = AdjacencyGraph()
    for row in reader:
        # csv.reader yields [] for a blank line; skip it instead of failing on row[0]
        if not any(field.strip() for field in row):
            continue
        src = row[0].strip()
        dest = row[1].strip()
        cost = int(row[2])
        # a node appears in many rows; only add it the first time it is seen
        for name in (src, dest):
            if name not in graph.adjacency_list:
                graph.insert_node(name)
        graph.insert_link(src, dest, cost)
        graph.insert_link(dest, src, cost)
    return graph
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment