Skip to content

Instantly share code, notes, and snippets.

@MasanoriYamada
Created August 2, 2023 03:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save MasanoriYamada/72405515264749df02ba392f16810e12 to your computer and use it in GitHub Desktop.
Save MasanoriYamada/72405515264749df02ba392f16810e12 to your computer and use it in GitHub Desktop.
linear_sum_assignment in PyTorch
import torch
import random
from scipy.optimize import linear_sum_assignment as linear_sum_assignment_scipy
import time
def augmenting_path(cost, u, v, path, row4col, i):
    """Search for the cheapest augmenting path that starts at row ``i``.

    Starting from row ``i``, columns are scanned using the reduced cost
    ``cost[row, col] - u[row] - v[col]`` until a column with no assigned row
    (``row4col[col] == -1``) is reached.

    The scan over the not-yet-visited columns is done with whole-tensor
    operations instead of a per-element Python loop, so each step costs a
    handful of tensor ops rather than one host round-trip per column.

    Args:
        cost: 2-D tensor indexed as ``cost[row, col]``.
        u: 1-D tensor with one offset per row (read only).
        v: 1-D tensor with one offset per column (read only).
        path: 1-D integer tensor with one entry per column. Modified in
            place: ``path[col]`` is set to the row from which ``col`` was
            reached most cheaply during this search.
        row4col: 1-D integer tensor; ``row4col[col]`` is the row currently
            assigned to ``col`` or ``-1`` when the column is free (read only).
        i: Row to start from (Python int or 0-dim integer tensor).

    Returns:
        A tuple ``(sink, minVal, remaining, SR, SC, shortestPathCosts, path)``:
        ``sink`` is a 0-dim integer tensor holding the free column that was
        reached, or the int ``-1`` when no column is reachable at finite cost
        (``minVal`` is then ``float('inf')``); ``minVal`` is the cost of the
        path to ``sink``; ``remaining`` is the working list of unvisited
        columns; ``SR`` / ``SC`` are boolean masks of the rows / columns
        visited; ``shortestPathCosts`` holds the cheapest known cost per
        column; ``path`` is the (mutated) input tensor.
    """
    device = cost.device
    num_rows, num_cols = cost.shape[0], cost.shape[1]
    inf = float('inf')

    # Unvisited columns live in remaining[:num_remaining]. The list is filled
    # in descending order, exactly as in the original implementation, so the
    # tie-breaking below visits columns in the same order.
    remaining = torch.arange(num_cols - 1, -1, -1, device=device)
    num_remaining = num_cols
    SR = torch.full((num_rows,), False, device=device)
    SC = torch.full((num_cols,), False, device=device)
    shortestPathCosts = torch.full((num_cols,), inf, device=device)

    # Work with plain ints for indices so nothing aliases tensor storage
    # that is modified later in the loop.
    row = int(i)
    minVal = 0
    sink = -1
    while True:
        SR[row] = True
        if num_remaining == 0:
            # Nothing left to scan: treated like the infeasible case.
            minVal = inf
            break

        cols = remaining[:num_remaining]
        reduced = minVal + cost[row, cols] - u[row] - v[cols]
        improved = reduced < shortestPathCosts[cols]
        improved_cols = cols[improved]
        path[improved_cols] = row
        shortestPathCosts[improved_cols] = reduced[improved].to(shortestPathCosts.dtype)

        candidates = shortestPathCosts[cols]
        lowest = candidates.min()
        if lowest == inf:  # infeasible cost matrix
            minVal = inf
            break

        # Same choice the sequential scan makes: among the columns tied at
        # the minimum, the last free one wins; if none is free, the first
        # column at the minimum is taken.
        at_min = candidates == lowest
        free_at_min = torch.nonzero(at_min & (row4col[cols] == -1)).flatten()
        reached_free = free_at_min.numel() > 0
        if reached_free:
            index = int(free_at_min[-1])
        else:
            index = int(torch.nonzero(at_min).flatten()[0])

        minVal = lowest
        col = int(remaining[index])
        SC[col] = True
        # Drop the chosen column by overwriting it with the last live entry.
        num_remaining -= 1
        remaining[index] = remaining[num_remaining]

        if reached_free:
            sink = torch.tensor(col, device=device)
            break
        row = int(row4col[col])

    return sink, minVal, remaining, SR, SC, shortestPathCosts, path
def linear_sum_assignment(cost, maximize=False):
    """Solve the linear sum assignment problem for a 2-D cost tensor.

    Each row is matched to at most one column (and vice versa) so that the
    summed cost of the selected entries is minimal, or maximal when
    ``maximize`` is true. When there are more rows than columns the problem
    is solved on the transposed matrix and the result is mapped back.

    Args:
        cost: 2-D tensor of costs, indexed ``cost[row, col]``.
        maximize: If true, the negated matrix is solved, which maximises the
            total instead of minimising it.

    Returns:
        ``(row_ind, col_ind)``: two 1-D integer tensors of equal length on
        ``cost.device``; entry ``k`` pairs row ``row_ind[k]`` with column
        ``col_ind[k]``. ``row_ind`` is in ascending order.

    Raises:
        ValueError: If ``cost`` is not 2-D, contains NaN or (after the
            optional negation) ``-inf``, or admits no assignment of finite
            total cost.
    """
    if cost.dim() != 2:
        raise ValueError(
            f'expected a matrix (2-D tensor), got a tensor of shape {tuple(cost.shape)}')

    with torch.no_grad():
        if maximize:
            cost = -cost
        if cost.is_floating_point():
            invalid = torch.isnan(cost) | (cost == float('-inf'))
            if bool(invalid.any()):
                raise ValueError('matrix contains invalid numeric entries')

        device = cost.device
        # Always solve with rows <= columns; undo the transpose at the end.
        transpose = cost.shape[1] < cost.shape[0]
        if transpose:
            cost = cost.T
        num_rows, num_cols = cost.shape[0], cost.shape[1]

        u = torch.full((num_rows,), 0., device=device)
        v = torch.full((num_cols,), 0., device=device)
        path = torch.full((num_cols,), -1, device=device)
        col4row = torch.full((num_rows,), -1, device=device)
        row4col = torch.full((num_cols,), -1, device=device)
        row_ids = torch.arange(num_rows, device=device)

        for curRow in range(num_rows):
            # Original author's note here (translated from Japanese): "j is off".
            # NOTE(review): presumably about the returned column index aliasing
            # storage that changes later; it is converted to a plain int below,
            # which rules that out -- TODO confirm the original intent.
            sink, minVal, _, SR, SC, shortestPathCosts, path = augmenting_path(
                cost, u, v, path, row4col, curRow)
            if int(sink) < 0:
                # Without this check the loop below would index with -1.
                raise ValueError('cost matrix is infeasible')

            # Update the row / column offsets with the path cost just found.
            # col4row may still hold -1 for unassigned rows; those entries are
            # gathered but then discarded by the mask.
            u[curRow] += minVal
            mask = SR & (row_ids != curRow)
            u[mask] += minVal - shortestPathCosts[col4row][mask]
            v[SC] += shortestPathCosts[SC] - minVal

            # Walk the path backwards from the sink, re-pairing rows and
            # columns. Plain ints are used so that reading the old partner of
            # row i cannot be affected by the write to col4row[i].
            j = int(sink)
            while True:
                i = int(path[j])
                row4col[j] = i
                col4row[i], j = j, int(col4row[i])
                if i == curRow:
                    break

        if transpose:
            order = torch.argsort(col4row)
            return col4row[order], order
        return row_ids, col4row
def main():
    """Compare the PyTorch solver with SciPy on five random cost matrices.

    For each trial a random matrix with 10-20 rows and 10-20 columns is
    drawn, both solvers are timed, and a line is printed telling whether the
    two returned index pairs are identical.
    """
    torch.manual_seed(0)
    random.seed(0)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    for t in range(5):
        print('=====')
        shape = (random.randint(10, 20), random.randint(10, 20))
        cost = torch.rand(shape, device=device)

        # Copy to host before starting the clock so that the SciPy figure
        # measures the solver only, not the device-to-host transfer.
        cost_np = cost.cpu().numpy()

        # perf_counter is the monotonic, high-resolution clock meant for
        # measuring short intervals.
        time_sta = time.perf_counter()
        row_ind_1, col_ind_1 = linear_sum_assignment_scipy(cost_np)  # scipy
        time_sta1 = time.perf_counter()
        print(f'scipy: {time_sta1 - time_sta}')

        row_ind_2, col_ind_2 = linear_sum_assignment(cost)  # pytorch
        if device.type == 'cuda':
            # Make sure queued GPU work has finished before reading the clock.
            torch.cuda.synchronize()
        time_sta2 = time.perf_counter()
        print(f'torch: {time_sta2 - time_sta1}')

        same_rows = row_ind_1.tolist() == row_ind_2.tolist()
        same_cols = col_ind_1.tolist() == col_ind_2.tolist()
        print('{:5} {}'.format(t, same_rows and same_cols))
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
@MasanoriYamada
Copy link
Author

The results match SciPy's, but the PyTorch implementation is far too slow.

=====
scipy: 0.00018858909606933594
torch: 0.49675846099853516
    0 True
=====
scipy: 0.000171661376953125
torch: 0.09160661697387695
    1 True
=====
scipy: 0.00014448165893554688
torch: 0.22105908393859863
    2 True
=====
scipy: 0.00015854835510253906
torch: 0.1297318935394287
    3 True
=====
scipy: 0.00012874603271484375
torch: 0.22979378700256348
    4 True

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment