@wisnunugroho21
Last active April 2, 2024 17:22
Pytorch implementation of Hungarian Algorithm
# Pytorch implementation of Hungarian Algorithm
# Inspired from here : https://python.plainenglish.io/hungarian-algorithm-introduction-python-implementation-93e7c0890e15
# Despite my effort to parallelize the code, there are still some sequential workflows in it

from typing import Tuple

import torch
from torch import Tensor

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def min_zero_row(zero_mat: Tensor) -> Tuple[Tensor, Tensor]:
    # Find the row with the fewest (but at least one) remaining zeros
    sum_zero_mat = zero_mat.sum(1)
    sum_zero_mat[sum_zero_mat == 0] = 9999
    zero_row = sum_zero_mat.min(0)[1]

    # Mark the first zero in that row, then clear its row and column
    zero_column = zero_mat[zero_row].nonzero()[0]
    zero_mat[zero_row, :] = False
    zero_mat[:, zero_column] = False

    mark_zero = torch.tensor([[zero_row, zero_column]], device=device)
    return zero_mat, mark_zero

def mark_matrix(mat: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    # Greedily mark one zero per row/column until no unmarked zeros remain
    zero_bool_mat = (mat == 0)
    zero_bool_mat_copy = zero_bool_mat.clone()

    marked_zero = torch.tensor([], device=device)
    while True in zero_bool_mat_copy:
        zero_bool_mat_copy, mark_zero = min_zero_row(zero_bool_mat_copy)
        marked_zero = torch.concat([marked_zero, mark_zero], dim=0)

    marked_zero_row = marked_zero[:, 0]
    marked_zero_col = marked_zero[:, 1]

    # Rows that contain no marked zero
    arange_index_row = torch.arange(mat.shape[0], dtype=torch.float, device=device).unsqueeze(1)
    repeated_marked_row = marked_zero_row.repeat(mat.shape[0], 1)
    bool_non_marked_row = torch.all(arange_index_row != repeated_marked_row, dim=1)
    non_marked_row = arange_index_row[bool_non_marked_row].reshape(-1)

    # Columns that have a zero in some non-marked row
    # (keep only the column coordinate of each zero)
    non_marked_mat = zero_bool_mat[non_marked_row.long(), :]
    marked_cols = non_marked_mat.nonzero()[:, 1].unique()

    # Move rows whose marked zero sits in a marked column into the
    # non-marked set, one at a time, until no such row remains
    is_need_add_row = True
    while is_need_add_row:
        repeated_non_marked_row = non_marked_row.repeat(marked_zero_row.shape[0], 1)
        repeated_marked_cols = marked_cols.repeat(marked_zero_col.shape[0], 1)

        first_bool = torch.all(marked_zero_row.unsqueeze(1) != repeated_non_marked_row, dim=1)
        second_bool = torch.any(marked_zero_col.unsqueeze(1) == repeated_marked_cols, dim=1)

        addit_non_marked_row = marked_zero_row[first_bool & second_bool]
        if addit_non_marked_row.shape[0] > 0:
            non_marked_row = torch.concat([non_marked_row.reshape(-1), addit_non_marked_row[0].reshape(-1)])
        else:
            is_need_add_row = False

    # The rows to cover are exactly those not in the non-marked set
    repeated_non_marked_row = non_marked_row.repeat(mat.shape[0], 1)
    bool_marked_row = torch.all(arange_index_row != repeated_non_marked_row, dim=1)
    marked_rows = arange_index_row[bool_marked_row].reshape(-1)

    return marked_zero, marked_rows, marked_cols

def adjust_matrix(mat: Tensor, cover_rows: Tensor, cover_cols: Tensor) -> Tensor:
    # Cover the marked rows and columns
    bool_cover = torch.zeros_like(mat)
    bool_cover[cover_rows.long()] = True
    bool_cover[:, cover_cols.long()] = True

    # Subtract the smallest uncovered value from every uncovered cell
    non_cover = mat[bool_cover != True]
    min_non_cover = non_cover.min()
    mat[bool_cover != True] = mat[bool_cover != True] - min_non_cover

    # Add it back at every intersection of a covered row and a covered column
    double_bool_cover = torch.zeros_like(mat)
    double_bool_cover[cover_rows.long().reshape(-1, 1), cover_cols.long().reshape(1, -1)] = True
    mat[double_bool_cover == True] = mat[double_bool_cover == True] + min_non_cover

    return mat

def hungarian_algorithm(mat: Tensor) -> Tensor:
    dim = mat.shape[0]
    cur_mat = mat

    # Step 1: subtract the row minimum from every row
    cur_mat = cur_mat - cur_mat.min(1, keepdim=True)[0]
    # Step 2: subtract the column minimum from every column
    cur_mat = cur_mat - cur_mat.min(0, keepdim=True)[0]

    zero_count = 0
    while zero_count < dim:
        # Step 3: cover all zeros with as few rows/columns as possible
        ans_pos, marked_rows, marked_cols = mark_matrix(cur_mat)
        zero_count = len(marked_rows) + len(marked_cols)

        # Step 4: if fewer than dim lines suffice, rebalance the matrix and retry
        if zero_count < dim:
            cur_mat = adjust_matrix(cur_mat, marked_rows, marked_cols)

    return ans_pos

# Example 1
mat = torch.tensor(
    [[7, 6, 2, 9, 2],
     [6, 2, 1, 3, 9],
     [5, 6, 8, 9, 5],
     [6, 8, 5, 8, 6],
     [9, 5, 6, 4, 7]], device=device)

ans_pos = hungarian_algorithm(mat)
print(ans_pos)

res = mat[ans_pos[:, 0].long(), ans_pos[:, 1].long()]
print(res)
print(res.sum())
print('==============')

# Example 2
mat = torch.tensor(
    [[108, 125, 150],
     [150, 135, 175],
     [122, 148, 250]], device=device)

ans_pos = hungarian_algorithm(mat)
print(ans_pos)

res = mat[ans_pos[:, 0].long(), ans_pos[:, 1].long()]
print(res)
print(res.sum())
print('==============')

# Example 3
mat = torch.tensor(
    [[1500, 4000, 4500],
     [2000, 6000, 3500],
     [2000, 4000, 2500]], device=device)

ans_pos = hungarian_algorithm(mat)
print(ans_pos)

res = mat[ans_pos[:, 0].long(), ans_pos[:, 1].long()]
print(res)
print(res.sum())
print('==============')

# Example 4
mat = torch.tensor(
    [[5, 9, 3, 6],
     [8, 7, 8, 2],
     [6, 10, 12, 7],
     [3, 10, 8, 6]], device=device)

ans_pos = hungarian_algorithm(mat)
print(ans_pos)

res = mat[ans_pos[:, 0].long(), ans_pos[:, 1].long()]
print(res)
print(res.sum())
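
# As a quick sanity check (a suggested addition; assumes SciPy is installed),
# the total above can be compared against SciPy's reference solver:
from scipy.optimize import linear_sum_assignment

ref_row, ref_col = linear_sum_assignment(mat.cpu().numpy())
print(mat.cpu().numpy()[ref_row, ref_col].sum())  # optimal total for Example 4
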
@1601110116

Dear author:
Thanks for your excellent programming, which has been very helpful to me. If it is helpful to you, I'd like to offer a suggestion.
I'm afraid the algorithm you adapted from https://python.plainenglish.io/hungarian-algorithm-introduction-python-implementation-93e7c0890e15 is incorrect. Specifically, the function adjust_matrix() is expected to modify the matrix by adding and subtracting min_non_cover so that zero_count can reach dim. However, the function mark_matrix() does not ensure that no zero is left uncovered, since non_marked_row may contain zeros that do not lie in marked_cols. In that case adjust_matrix() does nothing, so the loop in hungarian_algorithm never terminates.
The algorithm at https://brc2.com/the-algorithm-workshop/ should be correct, since I have not found any contradiction in it. I'd be very happy to see you update your code.
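For instance, a minimal check along these lines (the function name and shapes are assumed, not part of the gist) would expose the problem by returning False whenever mark_matrix leaves a zero uncovered:

def all_zeros_covered(mat, marked_rows, marked_cols):
    # True only if every zero of mat lies in a covered row or a covered column
    zero_pos = (mat == 0).nonzero()
    row_covered = (zero_pos[:, 0].unsqueeze(1) == marked_rows.long().reshape(1, -1)).any(1)
    col_covered = (zero_pos[:, 1].unsqueeze(1) == marked_cols.long().reshape(1, -1)).any(1)
    return bool((row_covered | col_covered).all())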
If you have any ideas, my email is pb12013004@outlook.com

@ivanstepanovftw

I extracted the Ultralytics and DETR code and passed 1000 lines to GPT-4 to refactor. So here is my PyTorch implementation of a Hungarian loss function, using SciPy's assignment-problem solver, in 12 lines of code.

@LivesayMe

@ivanstepanovftw scipy.optimize's implementation of linear sum assignment requires a NumPy array. Moving the IoU tensor to the CPU to get it as a NumPy array detaches it from the computation graph, meaning you can't calculate grads on it anymore.
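
A minimal illustration of this point (a toy example, not from the original thread):

import torch

x = torch.randn(4, 2, requires_grad=True)
cost = x.abs().sum(1)                  # still on the autograd graph
cost_np = cost.cpu().detach().numpy()  # NumPy copy carries no grad_fn
back = torch.from_numpy(cost_np)
print(back.requires_grad)              # False: no gradient flows through this copy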

@ivanstepanovftw

@LivesayMe, while this is true, linear_sum_assignment only returns indices, and the Hungarian loss uses those indices purely as a mask to permute the predicted and target tensors. Take a look at a minimal implementation of my Hungarian loss on GitHub Gist:

import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

def hungarian_loss(outputs, targets):
    # Pairwise L1 cost between every prediction and every target
    cost_matrix = torch.cdist(outputs, targets, p=1)
    # Solve the assignment on the CPU; only the indices are needed,
    # so detaching here does not cut the gradient path of the loss below
    row_ind, col_ind = linear_sum_assignment(cost_matrix.cpu().detach().numpy())
    matched_outputs = outputs[row_ind]
    matched_targets = targets[col_ind]
    loss = F.l1_loss(matched_outputs, matched_targets)
    return loss

Here you can see that matched_outputs lives on the same device as outputs and matched_targets on the same device as targets, even though both are permuted by row_ind and col_ind of type numpy.ndarray.
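
A quick usage sketch (toy shapes assumed) confirming that gradients still reach outputs even though the assignment itself was computed from a detached copy:

outputs = torch.randn(5, 4, requires_grad=True)
targets = torch.randn(5, 4)

loss = hungarian_loss(outputs, targets)
loss.backward()
print(outputs.grad is not None)  # True: the indexing step keeps the graph intact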

Note that this hungarian_loss implementation requires outputs and targets to be unbatched, i.e. the tensor dims are (set_length, set_features). For a batched implementation you need to adapt the code, e.g.:

def criterion(x, y, x_lengths, y_lengths):
    # Average the per-sample Hungarian losses over the batch,
    # slicing off padding via the per-sample set lengths
    hungarian = torch.tensor(0.0)
    for i in range(x.shape[0]):
        hungarian += hungarian_loss(x[i, :x_lengths[i]], y[i, :y_lengths[i]])
    hungarian /= x.shape[0]  # batchmean
    return hungarian
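
For example (shapes and lengths assumed), padded batches of ragged sets could be fed like this:

x = torch.randn(2, 6, 4, requires_grad=True)  # padded prediction sets
y = torch.randn(2, 6, 4)                      # padded target sets
x_lengths = torch.tensor([6, 3])              # valid set length per sample
y_lengths = torch.tensor([6, 3])

loss = criterion(x, y, x_lengths, y_lengths)
loss.backward()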

If you are still not sure it works in production, look at how the Ultralytics YOLO code works; it even uses alternative solvers, such as lap.lapjv from https://github.com/gatagat/lap. Search for repo:ultralytics/ultralytics linear_sum_assignment.
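
For reference, here is a sketch of dropping lap.lapjv into hungarian_loss in place of the SciPy call (treat the exact call as an assumption based on the gatagat/lap README):

import lap
import numpy as np

cost = cost_matrix.cpu().detach().numpy()
_, assign, _ = lap.lapjv(cost, extend_cost=True)  # assign[i] = column matched to row i
row_ind = np.arange(cost.shape[0])
col_ind = assign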

In case you want an alternative to the Hungarian loss, take a look at the Chamfer distance, which is implemented in both CPU and CUDA code.

from pytorch3d.loss import chamfer_distance

def criterion(x, y, x_lengths, y_lengths):
    # Bidirectional Chamfer distance between padded, variable-length point sets
    cham, cham_norm = chamfer_distance(x, y, x_lengths, y_lengths, point_reduction='mean', single_directional=False, abs_cosine=True)
    return cham
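
A usage sketch (shapes assumed; pytorch3d expects point clouds of shape (batch, points, dims)):

x = torch.randn(2, 8, 3)           # batch of padded point clouds
y = torch.randn(2, 10, 3)
x_lengths = torch.tensor([8, 5])   # valid points per cloud in x
y_lengths = torch.tensor([10, 7])  # valid points per cloud in y

print(criterion(x, y, x_lengths, y_lengths))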

If you have any questions, feel free to ask; I will be happy to answer.
