Skip to content

Instantly share code, notes, and snippets.

@avinashselvam
Last active July 20, 2020 09:45
Show Gist options
  • Save avinashselvam/c9704c36301d257e7b835ac7b29794cd to your computer and use it in GitHub Desktop.
Save avinashselvam/c9704c36301d257e7b835ac7b29794cd to your computer and use it in GitHub Desktop.
Hungarian Maximum Matching Algorithm
class Node:
    """
    A vertex belonging to one of the two sides of a bipartite graph.

    Attributes
    ----------
    id : int
        Identifier of the node, unique within its own side
    set : int
        Side of the graph the node lives on: 0 = left set, 1 = right set
    label : number
        Potential assigned to the node by the matching algorithm (starts at 0)
    match : Node or None
        Node of the opposite side this node is currently matched to
    visited : bool
        Traversal flag used during depth-first search

    Methods
    -------
    is_left()
        True if the node is in the left set
    is_right()
        True if the node is in the right set
    is_matched()
        True if the node currently has a partner
    set_label(label)
        Overwrite the node's label
    visit()
        Flag the node as visited
    unvisit()
        Clear the visited flag
    """

    def __init__(self, uid, left_or_right):
        self.set = left_or_right
        self.id = uid
        self.match = None
        self.visited = False
        self.label = 0

    def __hash__(self):
        # Identity within the whole graph is (side, id).
        key = (self.set, self.id)
        return hash(key)

    def __repr__(self):
        return f"<l/r:{self.set}, id:{self.id}, label:{self.label}>"

    def is_left(self):
        on_left = self.set == 0
        return on_left

    def is_right(self):
        on_right = self.set == 1
        return on_right

    def is_matched(self):
        return not (self.match is None)

    def set_label(self, label):
        self.label = label

    def visit(self):
        self.visited = True

    def unvisit(self):
        self.visited = False
class Edge:
    """
    Weighted edge connecting a left node and a right node of a bipartite graph.

    Attributes
    ----------
    left : Node
        Endpoint in the left set
    right : Node
        Endpoint in the right set
    weight : int
        Weight of the edge in graph theory terms

    Methods
    -------
    is_tight()
        Check if the edge belongs to the equality graph,
        i.e. label of left node + label of right node == weight
    """

    def __init__(self, left_node, right_node, weight):
        self.left = left_node
        self.right = right_node
        self.weight = weight

    def __hash__(self):
        # __hash__ must return an int; returning the bare tuple (as before)
        # made hash(edge) raise TypeError. Hash the (left id, right id) pair.
        return hash((self.left.id, self.right.id))

    def __repr__(self):
        return "<l_id:{}, r_id:{}, w:{}>".format(self.left.id, self.right.id, self.weight)

    def is_tight(self):
        # Exact comparison: intended for integer weights/labels. With float
        # weights rounding can make a mathematically tight edge look loose.
        return self.left.label + self.right.label == self.weight
class Hungarian:
    """
    Implements the Hungarian (Kuhn-Munkres) maximum-weight perfect matching
    algorithm on a complete bipartite graph.

    Attributes
    ----------
    cost : [[int]]
        Square cost matrix; cost[i][j] is the weight of the edge X[i] - Y[j]
    N : int
        Number of nodes in each of the two sets of the bipartite graph
    X : [Node]
        Left set of the bipartite graph
    Y : [Node]
        Right set of the bipartite graph
    E : [[Edge]]
        Adjacency matrix of the bipartite graph; E[i][j] joins X[i] and Y[j]

    Note : method names starting with _ should only be called on self.
    Tightness is tested with exact equality, so integer costs are assumed.

    Methods
    -------
    _add_edges(N, X, Y, cost)
        Constructs E from N, X, Y and cost
    _init_labels()
        Assigns a feasible initial labelling (row maxima on X, zero on Y)
    _reset_visit_status()
        Sets every node's visited flag to False before the next DFS
    _alternating_dfs(root, path, augmenting_path, candidate_path=None)
        DFS over tight edges along alternating paths; stops at the first
        augmenting path found
    _augment(path)
        Augments the existing matching along a newly found augmenting path
    _find_augmenting_path()
        Starts an alternating DFS from the first free left node and returns
        the augmenting path (or None) plus the alternating tree sets S and T
    _perfect_match_not_found()
        Checks whether some left node is still unmatched
    _update_node_labels(S, T)
        Shifts labels by the minimum slack so a new tight edge leaves S
    match()
        Main function that runs the algorithm
    """

    def __init__(self, cost):
        # Check every row (not only the first) so ragged matrices are rejected
        # up front; an empty matrix is accepted and yields an empty matching.
        # An explicit raise survives `python -O`; AssertionError is kept for
        # compatibility with callers that relied on the original assert.
        if any(len(row) != len(cost) for row in cost):
            raise AssertionError("Only square cost matrix is supported")
        self.cost = cost
        self.N = len(cost)
        self.X = [Node(i, 0) for i in range(self.N)]
        self.Y = [Node(i, 1) for i in range(self.N)]
        self._add_edges(self.N, self.X, self.Y, self.cost)
        self._init_labels()

    def _add_edges(self, N, X, Y, cost):
        self.E = [[Edge(X[i], Y[j], cost[i][j]) for j in range(N)]
                  for i in range(N)]

    def _init_labels(self):
        # Feasible starting labelling: l(x) = max weight on x's row and
        # l(y) = 0, so l(x) + l(y) >= w(x, y) holds for every edge.
        for x_node, y_node, row in zip(self.X, self.Y, self.E):
            x_node.set_label(max(edge.weight for edge in row))
            y_node.set_label(0)

    def _reset_visit_status(self):
        for node in self.X:
            node.unvisit()
        for node in self.Y:
            node.unvisit()

    def _alternating_dfs(self, root, path, augmenting_path, candidate_path=None):
        """
        Depth-first search along alternating paths in the equality graph.

        From a left node it follows tight edges. From a matched right node it
        follows the matching edge back to the left set. Every node reached is
        marked visited. If the search fails, the visited nodes are therefore
        exactly the alternating tree rooted at the start node.

        Parameters
        ----------
        root : Node
            Node currently being expanded
        path : [Node]
            Alternating path from the search start up to and including root
        augmenting_path : [None or [Node]]
            One-element list; receives the first augmenting path found
        candidate_path : [None or [Node]], optional
            One-element list. If given, it receives the most recent
            alternating path ending in a left node. It is kept for backward
            compatibility and is not needed by match().
        """
        if root.visited:
            return
        root.visit()
        if root.is_left():
            for edge in self.E[root.id]:
                if not edge.is_tight():
                    continue
                if edge.right.is_matched():
                    self._alternating_dfs(edge.right, path + [edge.right],
                                          augmenting_path, candidate_path)
                else:
                    # Free right node reached over a tight edge: augmenting path.
                    augmenting_path[0] = path + [edge.right]
                # Stop exploring as soon as any augmenting path is known.
                if augmenting_path[0] is not None:
                    return
        elif root.is_right():
            if candidate_path is not None:
                candidate_path[0] = path + [root.match]
            self._alternating_dfs(root.match, path + [root.match],
                                  augmenting_path, candidate_path)

    def _augment(self, path):
        print("augmenting with: ", path)
        # The path alternates left, right, left, right, ... (even length).
        # Re-pairing each consecutive (left, right) couple flips every edge
        # on the path and grows the matching by one.
        for left_node, right_node in zip(path[0::2], path[1::2]):
            left_node.match = right_node
            right_node.match = left_node

    def _find_augmenting_path(self):
        """
        Search for an augmenting path from the first unmatched left node.

        Returns
        -------
        (augmenting_path, S, T)
            augmenting_path is a list of nodes alternating left/right, from a
            free left node to a free right node, or None if none exists.
            S and T are the sets of left and right nodes reached by the search
            (the alternating tree), used for the label update when no
            augmenting path exists.

        Raises
        ------
        StopIteration
            If every left node is already matched.
        """
        root = next(node for node in self.X if not node.is_matched())
        augmenting_path = [None]  # pass by reference array trick
        self._alternating_dfs(root, [root], augmenting_path)
        # Collect the whole alternating tree, not just one path from it: every
        # tight neighbour of S must be in T, or the label update gets delta 0
        # and the algorithm never makes progress.
        S = {node for node in self.X if node.visited}
        T = {node for node in self.Y if node.visited}
        self._reset_visit_status()
        return augmenting_path[0], S, T

    def _perfect_match_not_found(self):
        return not all(node.is_matched() for node in self.X)

    def _update_node_labels(self, S, T):
        """
        Shift labels by the minimum slack between S and Y \\ T.

        Lowering S labels and raising T labels by delta keeps every label
        feasible and every tight edge inside the tree tight. It also makes at
        least one edge from S to Y \\ T tight, so the next search reaches
        further. With no augmenting path, |T| = |S| - 1, so Y \\ T is never
        empty.
        """
        not_T = [node for node in self.Y if node not in T]
        delta = min(
            left_node.label + right_node.label
            - self.E[left_node.id][right_node.id].weight
            for left_node in S
            for right_node in not_T
        )
        print("updating labels of: ", S, T, "with: ", delta)
        for node in S:
            node.set_label(node.label - delta)
        for node in T:
            node.set_label(node.label + delta)

    def match(self):
        """
        Run the algorithm until every left node is matched.

        Returns
        -------
        [Node]
            self.X. Each node's `match` attribute holds its right-set partner,
            and the matched edges maximise the total weight.
        """
        while self._perfect_match_not_found():
            augmenting_path, S, T = self._find_augmenting_path()
            if augmenting_path:
                self._augment(augmenting_path)
            else:
                self._update_node_labels(S, T)
        return self.X
# TESTING
if __name__ == "__main__":
    # Demo: maximum-weight assignment for a 4x4 cost matrix. Guarded so that
    # importing this module does not run the demo or print anything.
    cost = [
        [2, 3, 4, 5],
        [6, 5, 4, 8],
        [5, 9, 2, 8],
        [4, 6, 3, 1],
    ]
    h = Hungarian(cost)
    X = h.match()
    # (left id, matched right id) pairs
    print([(node.id, node.match.id) for node in X])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment