Skip to content

Instantly share code, notes, and snippets.

@UrosOgrizovic
Last active December 5, 2020 09:20
Show Gist options
  • Save UrosOgrizovic/d7c9a472c579673f3021f21db8f99580 to your computer and use it in GitHub Desktop.
Save UrosOgrizovic/d7c9a472c579673f3021f21db8f99580 to your computer and use it in GitHub Desktop.
Find the shortest path in a graph using Bellman-Ford.
"""
find shortest path in graph via Bellman-Ford
Uros Ogrizovic
"""
import numpy as np
import matplotlib.pyplot as plt
def get_neighbors_of_node(node, adjacency_matrix):
    """
    Return the indices of all nodes that ``node`` has an edge to.

    A non-zero entry adjacency_matrix[node][i] is an edge node -> i whose value
    is the edge weight (negative weights count as edges too); 0 means the two
    nodes are not connected.
    :param node: index of the node whose outgoing neighbors are wanted
    :param adjacency_matrix: the graph, represented as an adjacency matrix
    :return: neighbors, a list of node indices in increasing order
    """
    row = adjacency_matrix[node]
    return [column for column, weight in enumerate(row) if weight != 0]
def get_whose_neighbor(node, adjacency_matrix):
    """
    Return the indices of all nodes that have an edge to ``node``.

    Scans column ``node`` of the adjacency matrix; any non-zero entry
    adjacency_matrix[i][node] is an edge i -> node (0 means no edge).
    :param node: index of the node whose predecessors are wanted
    :param adjacency_matrix: the graph, represented as an adjacency matrix
    :return: neighbor_of, a list of node indices in increasing order
    """
    return [source for source, row in enumerate(adjacency_matrix) if row[node] != 0]
def get_node_values(adjacency_matrix, terminal_nodes):
    """
    Compute the value of every node: the negated price of the cheapest path
    from that node to any terminal node (so a higher value is better).

    Terminal nodes are worth 0, and a path ends as soon as it reaches one, so a
    terminal node's value is never changed. Nodes that cannot reach any terminal
    node keep the value -inf. Values are found Bellman-Ford style: every edge
    u -> v is relaxed with value[u] = max(value[u], value[v] - weight) for at
    most |V| - 1 rounds, stopping early once a round changes nothing. Because
    edges are followed backwards from their target, this works for directed
    graphs. If the graph contains a negative-weight cycle, the values after
    |V| - 1 rounds are returned as-is and are not meaningful.
    :param adjacency_matrix: the graph, represented as an adjacency matrix; a
        non-zero entry [u][v] is the weight of the edge u -> v, 0 means no edge
    :param terminal_nodes: indices of terminal nodes
    :return: node_values, a list with one value per node
    """
    num_nodes = len(adjacency_matrix)
    node_values = [-np.inf for _ in range(num_nodes)]
    for node in terminal_nodes:
        node_values[node] = 0
    terminal_set = set(terminal_nodes)

    # Collect every edge once; edges leaving a terminal node are skipped because
    # a path stops at its first terminal node.
    edges = [
        (u, v, adjacency_matrix[u][v])
        for u in range(num_nodes)
        if u not in terminal_set
        for v in range(num_nodes)
        if adjacency_matrix[u][v] != 0
    ]

    # A cheapest path uses at most |V| - 1 edges, so |V| - 1 rounds suffice.
    for _ in range(num_nodes - 1):
        changed = False
        for u, v, weight in edges:
            # Value u would get by taking edge u -> v and then v's best path.
            candidate = node_values[v] - weight
            # Accept it only if it actually improves u (a heavy edge must never
            # overwrite a better value found through another successor).
            if candidate > node_values[u]:
                node_values[u] = candidate
                changed = True
        if not changed:
            break
    return node_values
def check_if_terminal_nodes_in_list(lst, terminal_nodes):
    """
    Return the first terminal node (in terminal_nodes order) that occurs in lst.

    :param lst: a collection of node indices, e.g. a path
    :param terminal_nodes: indices of terminal nodes
    :return: the terminal node's graph index, or -1 if no terminal node is in lst
    """
    return next((terminal for terminal in terminal_nodes if terminal in lst), -1)
def get_path_price(path, adjacency_matrix):
    """
    Return the total weight of the edges along a path.

    :param path: sequence of node indices; each consecutive pair is one edge
    :param adjacency_matrix: the graph, represented as an adjacency matrix
    :return: path_sum, the sum of adjacency_matrix[a][b] over consecutive
        nodes (a, b); 0 for a path with fewer than two nodes
    """
    return sum(adjacency_matrix[src][dst] for src, dst in zip(path, path[1:]))
def stringify_path(path):
    """
    Join a path's node indices with dashes (e.g. [0, 1, 2] becomes "0-1-2").

    Anything that is not exactly a list (such as the single start node returned
    when it is already terminal, or the 'None' placeholder) is returned as-is.
    :param path: list of node indices, or a non-list value
    :return: str_path
    """
    if type(path) is not list:
        return path
    return "-".join(str(node) for node in path)
def _relax_edges(edges, distances, predecessors):
    """
    Run one Bellman-Ford relaxation pass over ``edges``, updating in place.

    :param edges: list of (u, v, weight) tuples, one per edge u -> v
    :param distances: best known price from the start node to each node
    :param predecessors: node preceding each node on its best known path
    :return: True if any distance was lowered during this pass, False otherwise
    """
    improved = False
    for u, v, weight in edges:
        candidate = distances[u] + weight
        if candidate < distances[v]:
            distances[v] = candidate
            predecessors[v] = u
            improved = True
    return improved


def find_shortest_path(start_node, adjacency_matrix, terminal_nodes):
    """
    Find the cheapest path from start_node to any terminal node (Bellman-Ford).

    A path ends as soon as it reaches a terminal node, so edges leaving a
    terminal node are never followed. Negative edge weights are supported.
    :param start_node: index of starting node
    :param adjacency_matrix: the graph, represented as an adjacency matrix; a
        non-zero entry [u][v] is the weight of the edge u -> v, 0 means no edge
    :param terminal_nodes: indices of terminal nodes
    :return: (path, price), where path is the list of node indices from
        start_node to the cheapest reachable terminal node and price is the sum
        of the edge weights along it; (start_node, 0) if start_node is itself a
        terminal node; ('None', 'inf') if no terminal node is reachable
    :raises ValueError: if start_node or any terminal node is negative, or if a
        negative-weight cycle is reachable from start_node
    """
    if start_node < 0:
        raise ValueError("Failed to find shortest path due to negative value for start node")
    if start_node in terminal_nodes:
        return start_node, 0
    for node in terminal_nodes:
        if node < 0:
            raise ValueError("Failed to find shortest path due to negative values for terminal nodes")

    num_nodes = len(adjacency_matrix)
    terminal_set = set(terminal_nodes)
    edges = [
        (u, v, adjacency_matrix[u][v])
        for u in range(num_nodes)
        if u not in terminal_set  # paths stop at the first terminal node
        for v in range(num_nodes)
        if adjacency_matrix[u][v] != 0
    ]

    distances = [np.inf] * num_nodes
    predecessors = [None] * num_nodes
    distances[start_node] = 0

    # A cheapest cycle-free path uses at most |V| - 1 edges, so |V| - 1 passes
    # are enough; stop early once a pass changes nothing.
    converged = False
    for _ in range(num_nodes - 1):
        if not _relax_edges(edges, distances, predecessors):
            converged = True
            break
    # If a price can still be lowered after |V| - 1 passes, a negative-weight
    # cycle is reachable and prices are unbounded.
    if not converged and _relax_edges(edges, distances, predecessors):
        raise ValueError("Failed to find shortest path due to a negative-weight cycle")

    reachable_terminals = [
        node for node in range(num_nodes)
        if node in terminal_set and distances[node] != np.inf
    ]
    if not reachable_terminals:
        return 'None', 'inf'
    best_terminal = min(reachable_terminals, key=lambda node: distances[node])

    # Follow predecessor links back to the start node. With no negative cycle
    # they form a tree rooted at start_node, so this loop terminates.
    path = [best_terminal]
    while path[-1] != start_node:
        path.append(predecessors[path[-1]])
    path.reverse()
    return path, distances[best_terminal]
if __name__ == "__main__":
    # Example graphs as adjacency matrices: entry [i][j] is the weight of the
    # edge i -> j and 0 means there is no such edge. The letters in the comment
    # above each matrix name its nodes in row/column order.

    # Nodes A, B, C, D, E -- see "undirected_graph.png"
    undirected_graph = [
        [0, 1, 1, 0, 0],
        [1, 0, 0, 1, 1],
        [1, 0, 0, 1, 0],
        [0, 1, 1, 0, 1],
        [0, 1, 0, 1, 0],
    ]
    # Nodes A, B, C, D, E -- see "directed_graph_weighted.png"
    directed_graph_weighted = [
        [0, 4, 2, 0, 0],
        [0, 0, 0, -2, 2],
        [0, 0, 0, 2, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0],
    ]
    # Nodes A, B, C, D, E -- see "directed_graph_1.png"
    directed_graph_1 = [
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0],
    ]
    # Nodes A, B, C, D, E, T -- see "directed_graph_2.png"
    directed_graph_2 = [
        [0, 1, 1, 0, 0, 1],
        [0, 0, 0, 1, 0, 0],
        [0, 1, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 1, 0, 0, 0, 0],
    ]
    # Nodes X, T, B, G, D, C, E, M -- see "directed_graph_3.png"
    directed_graph_3 = [
        [0, 1, 1, 0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1, 0, 0, 1],
        [0, 0, 0, 0, 0, 1, 0, 0],
    ]

    # Graph, start node and terminal nodes used for this run. adj_matrix is
    # kept as a module-level name on purpose: other code may read it.
    adj_matrix = directed_graph_1
    start_node = 2
    terminal_nodes = [1, len(adj_matrix) - 1]  # index 1 and the last node

    node_values = get_node_values(adj_matrix, terminal_nodes)
    print("Node values:", node_values)
    shortest_path, price = find_shortest_path(start_node, adj_matrix, terminal_nodes)
    print("Shortest path:", stringify_path(shortest_path), "| Price =", price)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment