# Hungarian Maximum Matching Algorithm
#
# Source: GitHub gist avinashselvam/c9704c36301d257e7b835ac7b29794cd
# (last active July 20, 2020 09:45).
class Node:
    """Vertex / node in a bipartite graph.

    Attributes
    ----------
    id : int
        Unique identifier within its set.
    set : int
        0 means the node belongs to the left set of the bipartite graph,
        1 means it belongs to the right set.
    label : int
        Label value (starts at 0). An edge is "tight" when
        left.label + right.label == edge weight.
    match : Node or None
        Node in the other set this node is currently matched to, or None.
    visited : bool
        Tracks whether the node was reached during a depth-first search.
    """

    def __init__(self, uid, left_or_right):
        self.id = uid
        self.label = 0
        self.set = left_or_right
        self.match = None
        self.visited = False

    def __hash__(self):
        # Include the side so left node i and right node i hash differently.
        return hash((self.set, self.id))

    def __repr__(self):
        return "<l/r:{}, id:{}, label:{}>".format(self.set, self.id, self.label)

    def is_left(self):
        """Return True if the node belongs to the left set."""
        return self.set == 0

    def is_right(self):
        """Return True if the node belongs to the right set."""
        return self.set == 1

    def is_matched(self):
        """Return True if the node is currently matched."""
        return self.match is not None

    def set_label(self, label):
        """Convenience method to write the label."""
        self.label = label

    def visit(self):
        """Mark the node as visited during graph traversal."""
        self.visited = True

    def unvisit(self):
        """Mark the node as unvisited to prepare for the next traversal."""
        self.visited = False
class Edge:
    """Edge connecting two nodes in a bipartite graph.

    Attributes
    ----------
    left : Node
        Node in the left set.
    right : Node
        Node in the right set.
    weight : int
        Weight of the edge.
    """

    def __init__(self, left_node, right_node, weight):
        self.left = left_node
        self.right = right_node
        self.weight = weight

    def __hash__(self):
        # __hash__ must return an int; hash the (left id, right id) pair
        # rather than returning the tuple itself (which raises TypeError).
        return hash((self.left.id, self.right.id))

    def __repr__(self):
        return "<l_id:{}, r_id:{}, w:{}>".format(self.left.id, self.right.id, self.weight)

    def is_tight(self):
        """Return True if the edge belongs to the equality graph.

        That is, label of left node + label of right node == weight.
        """
        return self.left.label + self.right.label == self.weight
class Hungarian:
    """Hungarian algorithm computing a maximum-weight perfect matching.

    Attributes
    ----------
    cost : [[int]]
        Square matrix; cost[i][j] is the weight of the edge between left
        node i and right node j.
    N : int
        Number of nodes in either set of the bipartite graph.
    X : [Node]
        Left set of the bipartite graph.
    Y : [Node]
        Right set of the bipartite graph.
    E : [[Edge]]
        Adjacency matrix; E[i][j] connects X[i] and Y[j].

    Note: methods whose names start with ``_`` are internal helpers and
    should only be called on self.
    """

    def __init__(self, cost):
        # NOTE(review): assert is stripped under `python -O`; kept so the
        # exception type seen by callers is unchanged.
        assert all(len(row) == len(cost) for row in cost), "Only square cost matrix is supported"
        self.cost = cost
        self.N = len(cost)
        self.X = [Node(i, 0) for i in range(self.N)]
        self.Y = [Node(i, 1) for i in range(self.N)]
        self._add_edges(self.N, self.X, self.Y, self.cost)
        # The search only follows tight edges, so start from a feasible
        # labelling; with all labels 0 no positive-weight edge is tight.
        self._init_labels()

    def _add_edges(self, N, X, Y, cost):
        """Build E so that E[i][j] is the edge X[i]-Y[j] with weight cost[i][j]."""
        self.E = [[Edge(X[i], Y[j], cost[i][j]) for j in range(N)] for i in range(N)]

    def _init_labels(self):
        """Assign an initial feasible labelling.

        Each left label is its heaviest incident edge and each right label
        is 0, so left.label + right.label >= weight holds for every edge.
        """
        for i, left in enumerate(self.X):
            left.set_label(max(edge.weight for edge in self.E[i]))
        for right in self.Y:
            right.set_label(0)

    def _reset_visit_status(self):
        """Mark all nodes unvisited to prepare for the next DFS traversal."""
        for node in self.X:
            node.unvisit()
        for node in self.Y:
            node.unvisit()

    def _alternating_dfs(self, root, path, augmenting_path, candidate_path):
        """DFS over the alternating tree of the equality graph.

        From a left node, follow tight edges; from a right node, follow its
        matching edge. The first path reaching a free right node is stored in
        augmenting_path[0], after which the search stops. The latest path
        that ends at a left node is stored in candidate_path[0]. Every node
        reached is marked visited.

        NOTE(review): recursive; depth grows with path length (up to ~2N
        frames), so very large N may hit Python's recursion limit.
        """
        if root.visited or augmenting_path[0] is not None:
            return
        root.visit()
        if root.is_left():
            for edge in self.E[root.id]:
                if not edge.is_tight():
                    continue
                right = edge.right
                if right.is_matched():
                    self._alternating_dfs(right, path + [right], augmenting_path, candidate_path)
                else:
                    augmenting_path[0] = path + [right]
                if augmenting_path[0] is not None:
                    return
        elif root.is_right():
            candidate_path[0] = path + [root.match]
            self._alternating_dfs(root.match, path + [root.match], augmenting_path, candidate_path)

    def _augment(self, path):
        """Augment the matching with an augmenting path [x0, y0, x1, y1, ...].

        Consecutive (left, right) pairs become matched. This replaces the old
        matches along the path.
        """
        print("augmenting with: ", path)
        for left, right in zip(path[0::2], path[1::2]):
            left.match = right
            right.match = left

    def _find_augmenting_path(self):
        """Run an alternating DFS from the first free left node.

        Returns (augmenting_path, candidate_path); either may be None.
        """
        self._reset_visit_status()
        root = next(node for node in self.X if not node.is_matched())
        augmenting_path = [None]  # pass-by-reference array trick
        candidate_path = [None]   # pass-by-reference array trick
        self._alternating_dfs(root, [root], augmenting_path, candidate_path)
        return (augmenting_path[0], candidate_path[0])

    def _perfect_match_not_found(self):
        """Return True while some left node is still unmatched."""
        return not all(node.is_matched() for node in self.X)

    def _update_node_labels(self, S, T):
        """Relabel so that at least one new tight edge leaves S.

        delta is the smallest slack over edges from S to Y \\ T. Labels in S
        drop by delta and labels in T rise by delta, which keeps the labelling
        feasible and keeps S-T edges tight.
        """
        not_t = [node for node in self.Y if node not in T]
        delta = min(
            left.label + right.label - self.E[left.id][right.id].weight
            for left in S
            for right in not_t
        )
        print("updating labels of: ", S, T, "with: ", delta)
        for node in S:
            node.set_label(node.label - delta)
        for node in T:
            node.set_label(node.label + delta)

    def match(self):
        """Run the algorithm until every left node is matched; return self.X."""
        while self._perfect_match_not_found():
            augmenting_path, _candidate_path = self._find_augmenting_path()
            if augmenting_path:
                self._augment(augmenting_path)
            else:
                # No augmenting path exists in the current equality graph.
                # Relabel using the WHOLE alternating tree (every visited node).
                # A single path could leave tight S->Y\T edges and give delta == 0.
                S = {node for node in self.X if node.visited}
                T = {node for node in self.Y if node.visited}
                self._update_node_labels(S, T)
        return self.X
# Demo: maximum-weight perfect matching on a 4x4 cost matrix.
cost = [
    [2, 3, 4, 5],
    [6, 5, 4, 8],
    [5, 9, 2, 8],
    [4, 6, 3, 1],
]
h = Hungarian(cost)
X = h.match()
# Print (left id, matched right id) pairs.
print([(node.id, node.match.id) for node in X])
# Comments on the original gist require a GitHub account.