Skip to content

Instantly share code, notes, and snippets.

@cgthayer
Created March 15, 2017 06:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cgthayer/f433ac8ad2cebaf7b3d6a90829068903 to your computer and use it in GitHub Desktop.
Save cgthayer/f433ac8ad2cebaf7b3d6a90829068903 to your computer and use it in GitHub Desktop.
Dijkstra's algorithm in Python: shortest paths from a source node to all other nodes, or from a source node to a single destination
import copy
import heapq
# Cost reported for "no edge" / "unreachable". A true infinity compares
# greater than every finite path cost; the old finite sentinel (1e6) made
# any real path costing more than 1e6 look worse than having no path at all.
Infinity = float("inf")
class GraphU:
    """Undirected graph with weighted edges between named nodes.

    Nodes may be any hashable, mutually orderable values (typically strings).
    Each edge is stored once, under the key ``(smaller, larger)`` of its two
    endpoints, so there is at most one edge between any pair of nodes and
    re-adding an edge overwrites its cost. Costs may be int or float and must
    be non-negative, which Dijkstra's algorithm requires.
    """

    def __init__(self, edges=()):
        """Create a graph from an iterable of ``(n1, n2, cost)`` triples."""
        self.nodes = set()
        self.edges = {}
        for n1, n2, cost in edges:
            self.add_edge(n1, n2, cost)

    @staticmethod
    def _key(n1, n2):
        """Return the canonical (ordered) ``self.edges`` key for edge n1--n2."""
        return (n1, n2) if n1 < n2 else (n2, n1)

    def add_edge(self, n1, n2, cost):
        """Add (or replace) the undirected edge n1--n2 with the given cost.

        Raises:
            ValueError: if ``cost`` is negative; Dijkstra's algorithm silently
                produces wrong answers on negative weights.
        """
        if cost < 0:
            raise ValueError(
                "edge cost must be non-negative, got {!r}".format(cost))
        self.nodes.add(n1)
        self.nodes.add(n2)
        self.edges[self._key(n1, n2)] = cost

    def get_cost(self, n1, n2):
        """Return the cost of edge n1--n2, or ``Infinity`` if not connected."""
        return self.edges.get(self._key(n1, n2), Infinity)

    def neighbors(self, node):
        """Return a list of the distinct nodes that share an edge with ``node``."""
        found = set()
        for n1, n2 in self.edges:
            if n1 == node:
                found.add(n2)
            if n2 == node:
                found.add(n1)
        return list(found)

    @classmethod
    def djikstra(cls, graph, start, dest=None, verbose=False):
        """Find lowest-cost paths from ``start`` using Dijkstra's algorithm.

        Uses a binary heap (``heapq``) with lazy deletion of stale entries and
        an adjacency map built once per call. This is O((V + E) log V),
        instead of rescanning a list and every edge for each node.

        Args:
            graph: the ``GraphU`` to search.
            start: node to search from.
            dest: if given, stop as soon as this node's cost is final and
                return only its result. If None, solve for every node.
            verbose: if true, print the edge map and the final predecessor
                and cost tables (debugging output).

        Returns:
            With ``dest``: a ``(path, cost)`` tuple, where ``path`` lists the
            nodes from ``start`` to ``dest`` inclusive. ``dest == start``
            gives ``([start], 0)``, and an unreachable ``dest`` gives
            ``([], Infinity)``.
            Without ``dest``: a dict mapping every node other than ``start``
            to its ``(path, cost)`` tuple, using the same conventions.

        Raises:
            KeyError: if ``dest`` is given and is not a node of ``graph``.
        """
        if verbose:
            print("graph::\n" + repr(graph.edges))
        if dest is not None:
            if dest == start:
                return [start], 0
            if dest not in graph.nodes:
                raise KeyError(dest)

        # Build adjacency once: O(E) total instead of an O(E) scan per node.
        adjacency = {}
        for (n1, n2), cost in graph.edges.items():
            adjacency.setdefault(n1, []).append((n2, cost))
            adjacency.setdefault(n2, []).append((n1, cost))

        # Only nodes actually reached appear in ``best``/``previous``. No
        # finite sentinel can therefore be confused with a large real cost.
        best = {start: 0}        # node -> lowest cost found so far
        previous = {}            # node -> predecessor on that cheapest path
        settled = set()          # nodes whose cost is final
        frontier = [(0, start)]  # heap of (cost, node); may hold stale entries
        while frontier:
            cost, node = heapq.heappop(frontier)
            if node in settled:
                continue  # stale entry, superseded by a cheaper push
            settled.add(node)
            if node == dest:
                break  # dest's cost is final; no need to explore further
            for neighbor, edge_cost in adjacency.get(node, ()):
                new_cost = cost + edge_cost
                if neighbor not in best or new_cost < best[neighbor]:
                    best[neighbor] = new_cost
                    previous[neighbor] = node
                    heapq.heappush(frontier, (new_cost, neighbor))

        if verbose:
            print("previous::\n" + repr(previous))
            print("hop_cost::\n" + repr(best))

        def path_to(node):
            # Follow predecessors back to ``start``; only valid for reached nodes.
            path = [node]
            while node != start:
                node = previous[node]
                path.append(node)
            path.reverse()
            return path

        def result_for(node):
            # (path, cost) for a node, or ([], Infinity) if it was never reached.
            if node not in best:
                return [], Infinity
            return path_to(node), best[node]

        if dest is not None:
            return result_for(dest)
        return {node: result_for(node) for node in graph.nodes if node != start}

    # Correctly spelled alias; ``djikstra`` is kept for existing callers.
    dijkstra = djikstra
# Example graph for the demo below. It is kept at module level so the name
# ``graph`` stays importable; building it performs no I/O.
graph = GraphU(
    edges=(
        ('a', 'b', 20),
        ('a', 'e', 5),
        ('b', 'd', 30),
        ('b', 'c', 10),
        ('c', 'd', 10),
        ('c', 'e', 20),
    )
)

if __name__ == "__main__":
    # Guarded so that importing this module does not print anything.
    # All-destinations search from 'a'.
    print(GraphU.djikstra(graph, 'a'))
    # Single-destination search; in this graph the cheapest route from
    # 'a' to 'd' is a-e-c-d with cost 5 + 20 + 10 = 35.
    print(GraphU.djikstra(graph, 'a', 'd'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment