@wisnunugroho21
Last active April 2, 2024 17:22
PyTorch implementation of the Hungarian Algorithm
# PyTorch implementation of the Hungarian Algorithm
# Inspired by: https://python.plainenglish.io/hungarian-algorithm-introduction-python-implementation-93e7c0890e15
# Despite my effort to parallelize the code, there are still some sequential workflows in this code
from typing import Tuple
import torch
from torch import Tensor
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def min_zero_row(zero_mat: Tensor) -> Tuple[Tensor, Tensor]:
    # Pick the row with the fewest (but at least one) remaining zeros
    sum_zero_mat = zero_mat.sum(1)
    sum_zero_mat[sum_zero_mat == 0] = 9999
    zero_row = sum_zero_mat.min(0)[1]

    # Mark the first zero in that row, then cover its row and column
    zero_column = zero_mat[zero_row].nonzero()[0]
    zero_mat[zero_row, :] = False
    zero_mat[:, zero_column] = False

    mark_zero = torch.tensor([[zero_row, zero_column]], device = device)
    return zero_mat, mark_zero
def mark_matrix(mat: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    # Locate all zeros in the reduced cost matrix
    zero_bool_mat = (mat == 0)
    zero_bool_mat_copy = zero_bool_mat.clone()

    # Greedily mark one independent zero at a time until all zeros are covered
    # (explicit (0, 2) shape avoids relying on legacy empty-tensor concat behavior)
    marked_zero = torch.empty((0, 2), dtype = torch.long, device = device)
    while zero_bool_mat_copy.any():
        zero_bool_mat_copy, mark_zero = min_zero_row(zero_bool_mat_copy)
        marked_zero = torch.concat([marked_zero, mark_zero], dim = 0)

    marked_zero_row = marked_zero[:, 0]
    marked_zero_col = marked_zero[:, 1]

    # Rows that did not receive a marked zero
    arange_index_row = torch.arange(mat.shape[0], dtype=torch.float, device = device).unsqueeze(1)
    repeated_marked_row = marked_zero_row.repeat(mat.shape[0], 1)
    bool_non_marked_row = torch.all(arange_index_row != repeated_marked_row, dim = 1)
    non_marked_row = arange_index_row[bool_non_marked_row].squeeze()

    # Columns that contain a zero in some non-marked row
    non_marked_mat = zero_bool_mat[non_marked_row.long(), :]
    marked_cols = non_marked_mat.nonzero().unique()

    # Keep adding rows whose marked zero lies in an already marked column
    is_need_add_row = True
    while is_need_add_row:
        repeated_non_marked_row = non_marked_row.repeat(marked_zero_row.shape[0], 1)
        repeated_marked_cols = marked_cols.repeat(marked_zero_col.shape[0], 1)

        first_bool = torch.all(marked_zero_row.unsqueeze(1) != repeated_non_marked_row, dim = 1)
        second_bool = torch.any(marked_zero_col.unsqueeze(1) == repeated_marked_cols, dim = 1)

        addit_non_marked_row = marked_zero_row[first_bool & second_bool]
        if addit_non_marked_row.shape[0] > 0:
            non_marked_row = torch.concat([non_marked_row.reshape(-1), addit_non_marked_row[0].reshape(-1)])
        else:
            is_need_add_row = False

    # Covered rows are the complement of the non-marked rows; the trailing
    # dim is kept so marked_rows broadcasts against marked_cols in adjust_matrix
    repeated_non_marked_row = non_marked_row.repeat(mat.shape[0], 1)
    bool_marked_row = torch.all(arange_index_row != repeated_non_marked_row, dim = 1)
    marked_rows = arange_index_row[bool_marked_row].squeeze(0)

    return marked_zero, marked_rows, marked_cols
def adjust_matrix(mat: Tensor, cover_rows: Tensor, cover_cols: Tensor) -> Tensor:
    # Cover the marked rows and columns
    bool_cover = torch.zeros_like(mat, dtype = torch.bool)
    bool_cover[cover_rows.long()] = True
    bool_cover[:, cover_cols.long()] = True

    # Subtract the smallest uncovered value from every uncovered element
    non_cover = mat[~bool_cover]
    min_non_cover = non_cover.min()
    mat[~bool_cover] = mat[~bool_cover] - min_non_cover

    # Add it back at row/column intersections, which are covered twice
    double_bool_cover = torch.zeros_like(mat, dtype = torch.bool)
    double_bool_cover[cover_rows.long(), cover_cols.long()] = True
    mat[double_bool_cover] = mat[double_bool_cover] + min_non_cover
    return mat
def hungarian_algorithm(mat: Tensor) -> Tensor:
    dim = mat.shape[0]
    cur_mat = mat

    # Step 1: reduce the matrix by its row minima, then its column minima
    cur_mat = cur_mat - cur_mat.min(1, keepdim = True)[0]
    cur_mat = cur_mat - cur_mat.min(0, keepdim = True)[0]

    # Repeat marking/adjusting until dim lines are needed to cover all zeros
    zero_count = 0
    while zero_count < dim:
        ans_pos, marked_rows, marked_cols = mark_matrix(cur_mat)
        zero_count = len(marked_rows) + len(marked_cols)

        if zero_count < dim:
            cur_mat = adjust_matrix(cur_mat, marked_rows, marked_cols)

    return ans_pos
# Example 1
mat = torch.tensor(
    [[7, 6, 2, 9, 2],
     [6, 2, 1, 3, 9],
     [5, 6, 8, 9, 5],
     [6, 8, 5, 8, 6],
     [9, 5, 6, 4, 7]], device = device)
ans_pos = hungarian_algorithm(mat)
print(ans_pos)
res = mat[ans_pos[:, 0].long(), ans_pos[:, 1].long()]
print(res)
print(res.sum())
print('==============')
# Example 2
mat = torch.tensor(
    [[108, 125, 150],
     [150, 135, 175],
     [122, 148, 250]], device = device)
ans_pos = hungarian_algorithm(mat)
print(ans_pos)
res = mat[ans_pos[:, 0].long(), ans_pos[:, 1].long()]
print(res)
print(res.sum())
print('==============')
# Example 3
mat = torch.tensor(
    [[1500, 4000, 4500],
     [2000, 6000, 3500],
     [2000, 4000, 2500]], device = device)
ans_pos = hungarian_algorithm(mat)
print(ans_pos)
res = mat[ans_pos[:, 0].long(), ans_pos[:, 1].long()]
print(res)
print(res.sum())
print('==============')
# Example 4
mat = torch.tensor(
    [[5, 9, 3, 6],
     [8, 7, 8, 2],
     [6, 10, 12, 7],
     [3, 10, 8, 6]], device = device)
ans_pos = hungarian_algorithm(mat)
print(ans_pos)
res = mat[ans_pos[:, 0].long(), ans_pos[:, 1].long()]
print(res)
print(res.sum())
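
# Optional sanity check (not part of the original gist; assumes SciPy is installed):
# compare the assignment cost against scipy.optimize.linear_sum_assignment.
# If both implementations find an optimal assignment, the two sums match.
from scipy.optimize import linear_sum_assignment

check_mat = torch.randint(1, 100, (6, 6), device = device)
ans_pos = hungarian_algorithm(check_mat)
row_ind, col_ind = linear_sum_assignment(check_mat.cpu().numpy())
print(check_mat[ans_pos[:, 0].long(), ans_pos[:, 1].long()].sum())  # cost found by this implementation
print(check_mat.cpu().numpy()[row_ind, col_ind].sum())              # reference optimal cost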
@ivanstepanovftw

@LivesayMe, while this is true, linear_sum_assignment only returns the indices required by the Hungarian loss; they are used to permute the predicted and target tensors into matching order. Take a look at a minimal implementation of my Hungarian loss in this GitHub Gist:

import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

def hungarian_loss(outputs, targets):
    # Pairwise L1 cost between every prediction and every target
    cost_matrix = torch.cdist(outputs, targets, p=1)
    # Solve the assignment on CPU; only the index arrays come back
    row_ind, col_ind = linear_sum_assignment(cost_matrix.cpu().detach().numpy())
    matched_outputs = outputs[row_ind]
    matched_targets = targets[col_ind]
    loss = F.l1_loss(matched_outputs, matched_targets)
    return loss

Here you can see that matched_outputs lives on the same device as outputs, and matched_targets on the same device as targets, even though they are permuted by row_ind and col_ind, which are of type numpy.ndarray.
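
For instance, here is a minimal sketch (shapes are hypothetical) showing that indexing a CUDA tensor with numpy indices keeps the result on the tensor's device:

import numpy as np
import torch

# Hypothetical set of 4 predictions with 8 features each
outputs = torch.randn(4, 8, device='cuda' if torch.cuda.is_available() else 'cpu')
row_ind = np.array([2, 0, 3, 1])  # numpy indices, like those from linear_sum_assignment
print(outputs[row_ind].device)    # prints the device of outputs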

Note that this hungarian_loss implementation requires both outputs and targets to be unbatched, i.e. tensor dims are (set_length, set_features). For a batched implementation you need to adapt the code, e.g.:

def criterion(x, y, x_lengths, y_lengths):
    # Accumulate on the input device so CUDA losses can be added in place
    hungarian = torch.tensor(0.0, device=x.device)
    for i in range(x.shape[0]):
        # Slice away the padding of each set before matching
        hungarian += hungarian_loss(x[i, :x_lengths[i]], y[i, :y_lengths[i]])
    hungarian /= x.shape[0]  # batchmean
    return hungarian
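
As a usage sketch (shapes are hypothetical), with padded batches whose padding is ignored thanks to the length tensors:

x = torch.randn(2, 5, 3)           # batch of 2 padded sets, up to 5 elements of 3 features
y = torch.randn(2, 5, 3)
x_lengths = torch.tensor([5, 3])   # actual set sizes before padding
y_lengths = torch.tensor([5, 3])
loss = criterion(x, y, x_lengths, y_lengths)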

If you are still not sure whether it works in production, look at how the Ultralytics YOLO code works; it even uses alternative solvers, such as lap.lapjv from https://github.com/gatagat/lap. Search for repo:ultralytics/ultralytics linear_sum_assignment.

If you want alternatives to the Hungarian loss, take a look at the Chamfer distance, which ships with both CPU and CUDA implementations.

from pytorch3d.loss import chamfer_distance

def criterion(x, y, x_lengths, y_lengths):
    # chamfer_distance returns (distance, normals_distance); the latter is None when no normals are given
    cham, cham_norm = chamfer_distance(x, y, x_lengths=x_lengths, y_lengths=y_lengths, point_reduction='mean', single_directional=False, abs_cosine=True)
    return cham

If you have any questions, feel free to ask; I will be happy to answer them.
