Skip to content

Instantly share code, notes, and snippets.

@fanjin-z
Last active May 13, 2018 00:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fanjin-z/b0233efb09602559519e64c805b507f6 to your computer and use it in GitHub Desktop.
Save fanjin-z/b0233efb09602559519e64c805b507f6 to your computer and use it in GitHub Desktop.
A Python implementation of the Hungarian (Munkres) algorithm. This implementation is based on http://csclab.murraystate.edu/~bob.pilgrim/445/munkres.html
import numpy as np
class hungrary:
    """Solve a square linear assignment problem with the Hungarian algorithm.

    Implements the step-by-step Munkres procedure described at
    http://csclab.murraystate.edu/~bob.pilgrim/445/munkres.html and finds
    the one-to-one row/column assignment with minimum total weight.

    Usage::

        H = hungrary(cost)
        H.run_hungrary()
        total = np.sum(H.w * H.m)

    Attributes:
        n: Number of rows (== number of columns) of the cost matrix.
        w: Untouched copy of the input weights.
        c: Working copy of the weights, reduced in place by steps 1 and 6.
        m: Mask matrix: 1 marks a starred zero, 2 a primed zero, 0 otherwise.
            After ``run_hungrary`` the 1-entries form the optimal assignment
            (exactly one per row and per column).
        RowCover, ColCover: Boolean cover flags for rows / columns.
        path: Scratch buffer of (row, col) pairs for augmenting paths.
    """

    def __init__(self, weight):
        """Prepare solver state for ``weight``.

        Args:
            weight: Square 2-D array-like of costs. ``+inf`` entries are
                allowed (a forbidden pairing) as long as a finite assignment
                exists; NaN and ``-inf`` are not.

        Raises:
            ValueError: if ``weight`` is not a square 2-D matrix, or contains
                NaN / ``-inf`` (either would make the step loop never finish).
        """
        # np.array copies, so the caller's data is never modified; it also
        # accepts nested lists and other array-likes.
        weight = np.array(weight)
        if weight.ndim != 2 or weight.shape[0] != weight.shape[1]:
            raise ValueError(
                "weight must be a square 2-D matrix, got shape %r" % (weight.shape,))
        if weight.dtype.kind == "f" and (
                np.isnan(weight).any() or np.isneginf(weight).any()):
            raise ValueError("weight must not contain NaN or -inf")
        self.n = weight.shape[0]
        self.w = weight
        # cost matrix, reduced in place by the algorithm
        self.c = np.copy(weight)
        self.m = np.zeros((self.n, self.n), dtype=int)
        # record row and col covers
        self.RowCover = np.zeros(self.n, dtype=bool)
        self.ColCover = np.zeros(self.n, dtype=bool)
        # augmenting path: it alternates primes and stars, so it holds at
        # most 2n - 1 cells
        self.path = np.zeros((2 * self.n, 2), dtype=int)

    def run_hungrary(self):
        """Run the algorithm; afterwards ``self.m`` holds the optimal assignment.

        Raises:
            ValueError: if no finite-cost assignment exists (``+inf`` entries
                block every complete matching).
        """
        if self.n == 0:
            # Nothing to assign for an empty matrix.
            return
        steps = {1: self.step1, 2: self.step2, 3: self.step3,
                 4: self.step4, 5: self.step5, 6: self.step6}
        step = 1
        # Every step returns the number of the next step; 7 means "done".
        while step != 7:
            step = steps[step]()

    @staticmethod
    def _any_posinf(x):
        """Return True if float array/scalar ``x`` contains ``+inf``."""
        x = np.asarray(x)
        return x.dtype.kind == "f" and bool(np.isposinf(x).any())

    @staticmethod
    def _first(mask):
        """Return the index of the first True entry of 1-D ``mask``, or -1."""
        hits = np.flatnonzero(mask)
        return int(hits[0]) if hits.size else -1

    def step1(self):
        """Step 1: subtract each row's smallest element from that row."""
        row_min = np.min(self.c, axis=1, keepdims=True)
        if self._any_posinf(row_min):
            # A row of all +inf can never be assigned; inf - inf would be NaN.
            raise ValueError("cost matrix is infeasible")
        self.c -= row_min
        return 2

    def step2(self):
        """Step 2: star the first zero of each row whose column has no star yet."""
        for u in range(self.n):
            if self.RowCover[u]:
                continue
            # Covers are used here only as temporary "already starred" marks.
            v = self._first((self.c[u] == 0) & ~self.ColCover)
            if v != -1:
                self.m[u, v] = 1
                self.RowCover[u] = True
                self.ColCover[v] = True
        self.clear_covers()
        return 3

    def step3(self):
        """Step 3: cover columns holding a starred zero; 7 if all n are covered, else 4."""
        self.ColCover[np.any(self.m == 1, axis=0)] = True
        return 7 if np.count_nonzero(self.ColCover) >= self.n else 4

    def step4(self):
        """Step 4: prime uncovered zeros (marked as 2 in ``m``).

        Returns 5 when a primed zero has no starred zero in its row (its
        position is saved in ``path_row_0`` / ``path_col_0``), or 6 when no
        uncovered zero remains.
        """
        while True:
            row, col = self.find_a_zero()
            if row == -1:
                return 6
            self.m[row, col] = 2
            star_col = self.find_star_in_row(row)
            if star_col == -1:
                self.path_row_0 = row
                self.path_col_0 = col
                return 5
            # Cover this row and uncover the starred zero's column so the
            # zeros in that column become candidates again.
            self.RowCover[row] = True
            self.ColCover[star_col] = False

    def step5(self):
        """Step 5: grow the alternating prime/star path from step 4's prime and flip it.

        Starting from the saved primed zero, the path repeatedly takes the
        starred zero in the current column, then the primed zero in that
        star's row, until a column has no star. The path is then augmented,
        covers cleared and primes erased. Returns 3.
        """
        self.path_count = 1
        self.path[0] = (self.path_row_0, self.path_col_0)
        while True:
            col = self.path[self.path_count - 1, 1]
            row = self.find_star_in_col(col)
            if row == -1:
                break
            self.path[self.path_count] = (row, col)
            self.path_count += 1
            # A starred row on this path carries a prime (per the Munkres
            # construction in step 4).
            self.path[self.path_count] = (row, self.find_prime_in_row(row))
            self.path_count += 1
        self.augment_path()
        self.clear_covers()
        self.erase_prime()
        return 3

    def step6(self):
        """Step 6: shift costs by the smallest uncovered value, then return 4.

        Equivalent to adding it to every covered row and subtracting it
        from every uncovered column, but cells in a covered row and an
        uncovered column are left untouched instead of receiving ``+m - m``,
        which for floats could round a tiny positive cost to exactly zero.
        """
        minval = self.find_smallest()
        if self._any_posinf(minval):
            # Every remaining candidate is +inf: no finite assignment exists.
            raise ValueError("cost matrix is infeasible")
        rows = self.RowCover[:, None]
        cols = self.ColCover[None, :]
        self.c[rows & cols] += minval
        self.c[~rows & ~cols] -= minval
        return 4

    def find_a_zero(self):
        """Return (row, col) of the first uncovered zero in row-major order, or (-1, -1)."""
        mask = (self.c == 0) & ~self.RowCover[:, None] & ~self.ColCover[None, :]
        flat = self._first(mask.ravel())
        if flat == -1:
            return -1, -1
        return divmod(flat, self.n)

    def star_in_row(self, row):
        """Return True if ``row`` contains a starred zero."""
        return bool(np.any(self.m[row] == 1))

    def find_star_in_row(self, row):
        """Return the column of the starred zero in ``row``, or -1."""
        return self._first(self.m[row] == 1)

    def find_star_in_col(self, col):
        """Return the row of the starred zero in ``col``, or -1."""
        return self._first(self.m[:, col] == 1)

    def find_prime_in_row(self, row):
        """Return the column of the primed zero in ``row``, or -1."""
        return self._first(self.m[row] == 2)

    def augment_path(self):
        """Unstar every starred zero on the path and star every primed one."""
        for u, v in self.path[:self.path_count]:
            self.m[u, v] = 0 if self.m[u, v] == 1 else 1

    def clear_covers(self):
        """Uncover every row and column."""
        self.RowCover = np.zeros(self.n, dtype=bool)
        self.ColCover = np.zeros(self.n, dtype=bool)

    def erase_prime(self):
        """Remove all primes from the mask matrix."""
        self.m[self.m == 2] = 0

    def find_smallest(self):
        """Return the smallest cost lying in both an uncovered row and an uncovered column.

        Falls back to the overall maximum of ``c`` if nothing is uncovered.
        """
        uncovered = self.c[np.ix_(~self.RowCover, ~self.ColCover)]
        if uncovered.size == 0:
            return np.max(self.c)
        return uncovered.min()
if __name__ == "__main__":
    # SciPy is only needed for this self-check, so import it here rather
    # than making it a requirement for importing the module.
    from scipy.optimize import linear_sum_assignment

    # Check correctness against SciPy's reference solver on a random instance.
    distance_mat = np.random.rand(10, 10)
    H = hungrary(distance_mat)
    H.run_hungrary()
    # EMD (total matched cost) computed by this Hungarian implementation
    ours = np.sum(H.w * H.m)
    # EMD computed by scipy.optimize.linear_sum_assignment
    row_ind, col_ind = linear_sum_assignment(distance_mat)
    reference = distance_mat[row_ind, col_ind].sum()
    print("hungrary total cost:              %.10f" % ours)
    print("linear_sum_assignment total cost: %.10f" % reference)
    # Both solvers are optimal, so totals agree up to float rounding.
    assert np.isclose(ours, reference), "total costs differ"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment