Skip to content

Instantly share code, notes, and snippets.

@f0lie
Last active August 18, 2021 09:11
Show Gist options
  • Save f0lie/b9a57be922f02671dd95a18acc71f0ad to your computer and use it in GitHub Desktop.
Save f0lie/b9a57be922f02671dd95a18acc71f0ad to your computer and use it in GitHub Desktop.
Python 3: Clean implementations of heapq-based Dijkstra, Bellman-Ford, and SPFA
from collections import defaultdict, deque
import heapq
from typing import OrderedDict
def create_graph(matrix):
    """Build an adjacency list from a weighted adjacency matrix.

    Args:
        matrix: list of rows; ``matrix[row][col] > 0`` is the weight of the
            edge ``row -> col``. Zero (or negative) entries mean "no edge", so
            this input format cannot express non-positive edge weights.

    Returns:
        ``defaultdict(list)`` mapping every row index to a list of
        ``[col, weight]`` pairs. Rows with no outgoing edges map to an empty
        list, so ``len(graph)`` equals the number of vertices.
    """
    graph = defaultdict(list)
    for row, weights in enumerate(matrix):
        # Create the entry even when the row has no edges: a sink vertex must
        # still count toward len(graph), which is used as the vertex count.
        edges = graph[row]
        for col, weight in enumerate(weights):
            if weight > 0:
                edges.append([col, weight])
    return graph
def get_path(path, source, end):
    """Reconstruct the node sequence from ``source`` to ``end``.

    Args:
        path: predecessor list as returned by the shortest-path functions;
            ``path[v]`` is the node before ``v`` on the best known path, or -1
            if ``v`` has no predecessor.
        source: node the predecessor list was computed from.
        end: destination node.

    Returns:
        List of nodes ``[source, ..., end]`` (just ``[source]`` when
        ``end == source``).

    Raises:
        ValueError: if ``end`` is not reachable from ``source``, or the
            predecessor chain loops without ever reaching ``source``.
    """
    current = end
    found_path = [current]
    while current != source:
        current = path[current]
        # -1 is a valid Python index, so it has to be checked explicitly;
        # otherwise path[-1] would silently read the last entry. A simple path
        # has at most len(path) nodes, so a longer walk means a cycle.
        if current == -1 or len(found_path) >= len(path):
            raise ValueError(f"no path from {source} to {end}")
        found_path.append(current)
    return found_path[::-1]
def dijsktra(graph, source):
    """Single-source shortest paths with a binary heap (Dijkstra's algorithm).

    https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm
    Uses lazy deletion instead of decrease-key: a node may sit in the heap
    several times, and stale entries are skipped when popped.
    https://cs.stackexchange.com/questions/118388/dijkstra-without-decrease-key

    Args:
        graph: ``graph[v]`` is an iterable of ``(neighbor, weight)`` pairs for
            every vertex ``v`` in ``range(len(graph))`` (as built by
            ``create_graph``). Weights must be non-negative.
        source: start node.

    Returns:
        ``(distance, path)``: ``distance[i]`` is the shortest distance from
        ``source`` to ``i`` (``inf`` if unreachable); ``path[i]`` is the
        predecessor of ``i`` on that path, or -1.
    """
    # min path to i-th node from source
    distance = [float("inf")] * len(graph)
    distance[source] = 0
    # contains previous node pointing to i-th node
    path = [-1] * len(graph)
    # (distance, node), heap[0] contains shortest distance so far
    heap = [(0, source)]
    while heap:
        dist_from_source, node = heapq.heappop(heap)
        # Stale entry: a shorter distance to node was already found and processed.
        if dist_from_source > distance[node]:
            continue
        for neighbor, weight in graph[node]:
            candidate = dist_from_source + weight
            # Update dist if a shorter path was found than stored currently
            if candidate < distance[neighbor]:
                distance[neighbor] = candidate
                path[neighbor] = node
                heapq.heappush(heap, (candidate, neighbor))
    return distance, path
def bellman_ford(graph, source):
    """Single-source shortest paths by repeated edge relaxation (Bellman-Ford).

    Bellman-Ford can be thought of as brute force: it relaxes every edge
    repeatedly, once per pass, for up to ``len(graph) - 1`` passes. Unlike
    Dijkstra it tolerates negative edge weights.

    Negative-weight cycles are NOT detected. If one is reachable from
    ``source``, the returned values are not shortest distances.

    Args:
        graph: mapping with an entry for every vertex in
            ``range(len(graph))``; each value is an iterable of
            ``(neighbor, weight)`` pairs (as built by ``create_graph``).
        source: start node.

    Returns:
        ``(distance, path)``: ``distance[i]`` is the shortest distance from
        ``source`` to ``i`` (``inf`` if unreachable); ``path[i]`` is the
        predecessor of ``i`` on that path, or -1.
    """
    distance = [float("inf")] * len(graph)
    distance[source] = 0
    path = [-1] * len(graph)
    # After pass i, distance[v] is at most the length of the shortest path to v
    # that uses at most i edges. Simple paths have at most len(graph) - 1 edges.
    for _ in range(len(graph) - 1):
        changed = False
        for frm, neighbors in graph.items():
            for to, weight in neighbors:
                if distance[frm] + weight < distance[to]:
                    distance[to] = distance[frm] + weight
                    path[to] = frm
                    changed = True
        # A pass that relaxes nothing leaves the state unchanged, so every
        # later pass would be a no-op as well.
        if not changed:
            break
    return distance, path
def spfa(graph, source):
    """Shortest Path Faster Algorithm: queue-driven Bellman-Ford.

    https://en.wikipedia.org/wiki/Shortest_Path_Faster_Algorithm
    Only nodes whose distance just improved are re-examined. They are
    processed in FIFO order, and negative edge weights are allowed.

    Args:
        graph: ``graph[v]`` is an iterable of ``(neighbor, weight)`` pairs for
            every vertex ``v`` in ``range(len(graph))`` (as built by
            ``create_graph``).
        source: start node.

    Returns:
        ``(distance, path)``: ``distance[i]`` is the shortest distance from
        ``source`` to ``i`` (``inf`` if unreachable); ``path[i]`` is the
        predecessor of ``i`` on that path, or -1.

    Raises:
        ValueError: if a negative-weight cycle is reachable from ``source``.
            Without this check the relaxation loop would never terminate.
    """
    num_nodes = len(graph)
    distance = [float("inf")] * num_nodes
    distance[source] = 0
    path = [-1] * num_nodes
    # Number of edges on the walk that produced distance[v]. A walk of
    # num_nodes or more edges repeats a vertex, and a strictly improving
    # repeat can only come from a negative-weight cycle.
    edge_count = [0] * num_nodes
    # OrderedDict doubles as a FIFO queue with O(1) append and O(1) membership
    # test; values are ignored. popitem(last=False) pops the OLDEST entry; the
    # default last=True would pop the newest and turn the queue into a stack.
    queue = OrderedDict()
    queue[source] = None
    while queue:
        current, _ = queue.popitem(last=False)
        for neighbor, weight in graph[current]:
            candidate = distance[current] + weight
            if candidate < distance[neighbor]:
                distance[neighbor] = candidate
                path[neighbor] = current
                edge_count[neighbor] = edge_count[current] + 1
                if edge_count[neighbor] >= num_nodes:
                    raise ValueError("graph contains a negative-weight cycle reachable from source")
                # A node already waiting will use its improved distance when popped.
                if neighbor not in queue:
                    queue[neighbor] = None
    return distance, path
def spfa_2(graph, source):
    """Variant of ``spfa`` using a plain dict for distances and a deque queue.

    https://en.wikipedia.org/wiki/Shortest_Path_Faster_Algorithm
    Avoids OrderedDict: a deque gives O(1) FIFO operations, and a companion
    set gives O(1) "already queued?" checks.

    Args:
        graph: ``graph[v]`` is an iterable of ``(neighbor, weight)`` pairs for
            every vertex ``v`` in ``range(len(graph))`` (as built by
            ``create_graph``).
        source: start node.

    Returns:
        ``(distance, path)``: ``distance`` is a dict with an entry only for
        nodes reachable from ``source``; ``path[i]`` is the predecessor of
        ``i`` on its shortest path, or -1.

    Raises:
        ValueError: if a negative-weight cycle is reachable from ``source``.
            Without this check the relaxation loop would never terminate.
    """
    num_nodes = len(graph)
    distance = {source: 0}
    path = [-1] * num_nodes
    # Number of edges on the walk that produced distance[v]. A walk of
    # num_nodes or more edges repeats a vertex, and a strictly improving
    # repeat can only come from a negative-weight cycle.
    edge_count = {source: 0}
    queue = deque([source])
    # Mirrors the deque's contents so a node is never queued twice; a node
    # already waiting will use its improved distance when it is popped.
    queued = {source}
    while queue:
        current = queue.popleft()
        queued.discard(current)
        for neighbor, weight in graph[current]:
            candidate = distance[current] + weight
            if neighbor not in distance or candidate < distance[neighbor]:
                distance[neighbor] = candidate
                path[neighbor] = current
                edge_count[neighbor] = edge_count[current] + 1
                if edge_count[neighbor] >= num_nodes:
                    raise ValueError("graph contains a negative-weight cycle reachable from source")
                if neighbor not in queued:
                    queued.add(neighbor)
                    queue.append(neighbor)
    return distance, path
if __name__ == "__main__":
    # Sample weighted undirected graph (adjacency matrix, 0 = no edge) from
    # https://www.geeksforgeeks.org/dijkstras-shortest-path-algorithm-greedy-algo-7/
    input_graph = [
        [0, 4, 0, 0, 0, 0, 0, 8, 0],
        [4, 0, 8, 0, 0, 0, 0, 11, 0],
        [0, 8, 0, 7, 0, 4, 0, 0, 2],
        [0, 0, 7, 0, 9, 14, 0, 0, 0],
        [0, 0, 0, 9, 0, 10, 0, 0, 0],
        [0, 0, 4, 14, 10, 0, 2, 0, 0],
        [0, 0, 0, 0, 0, 2, 0, 1, 6],
        [8, 11, 0, 0, 0, 0, 1, 0, 7],
        [0, 0, 2, 0, 0, 0, 6, 7, 0],
    ]
    graph = create_graph(input_graph)
    # Every solver shares the (graph, source) -> (distance, path) signature,
    # so run them all against the same graph and report the 0 -> 8 route.
    for solver in (dijsktra, bellman_ford, spfa, spfa_2):
        distance, path = solver(graph, 0)
        print("Path from 0 to 8", get_path(path, 0, 8))
        print("Distance from 0 to 8:", distance[8])
    # Expected output (the same two lines once per solver, four times total):
    # Path from 0 to 8 [0, 1, 2, 8]
    # Distance from 0 to 8: 14
@f0lie
Copy link
Author

f0lie commented Aug 18, 2021

I wrote these implementations because I felt that many of the implementations out there weren't clear and clean. I used some of these algorithms to solve LeetCode questions such as Network Delay Time, and they passed, so I know they work.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment