Last active
January 8, 2023 20:25
-
-
Save atomkirk/a4ac4c3d6ef964eaab4b7f55ef045f83 to your computer and use it in GitHub Desktop.
Hungarian/Munkres algorithm in Elixir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
defmodule Hungarian do
  @moduledoc """
  Hungarian (Munkres) algorithm for the assignment problem.

  Written by Adam Kirk – Jan 18, 2020

  Most helpful resources used:

    * https://www.youtube.com/watch?v=dQDZNHwuuOY
    * https://www.youtube.com/watch?v=cQ5MsiGaDY8
    * https://www.geeksforgeeks.org/hungarian-algorithm-assignment-problem-set-1-introduction/

  Costs may be integers or floats; internally every cost is converted to a
  float and rounded to #{3} decimal places after each arithmetic step.
  """

  # Decimal places kept after every subtraction/addition, so that floating
  # point noise cannot hide a zero.
  @precision 3

  # Upper bound on repetitions of step 3 before giving up.
  @max_refinements 50

  @unsolvable "There must be a bug in this code that can't handle the input matrix."

  @doc """
  Takes a matrix of costs (a non-empty list of equally long rows) and returns
  a list of `{row, column}` tuples of assignments that minimizes total cost.

  A non-square matrix is padded with zero-cost rows or columns before solving;
  assignments that fall into the padding are dropped from the result, so the
  result holds at most `min(rows, columns)` tuples. The order of the tuples in
  the returned list is not specified.

  Raises a `RuntimeError` when no complete assignment is found within the
  internal iteration limit.

  ## Examples

      iex> Hungarian.compute([[5]])
      [{0, 0}]

      iex> Hungarian.compute([[3, 1, 2]])
      [{0, 1}]

  """
  @spec compute([[number()], ...]) :: [{non_neg_integer(), non_neg_integer()}]
  def compute([row1 | _] = matrix) do
    row_count = length(matrix)
    col_count = length(row1)

    matrix
    # add "zero" rows/columns if it's not a square matrix
    |> pad()
    # perform the calculation
    |> step()
    # remove any assignments that are in the padded part of the matrix
    |> Enum.filter(fn {r, c} -> r < row_count and c < col_count end)
  end

  defp step(matrix, step \\ 1, assignments \\ nil, count \\ 0)

  # Done: one assignment per row. `assignments` is nil on the very first call.
  defp step(matrix, _step, assignments, _count)
       when is_list(assignments) and length(assignments) == length(matrix),
       do: assignments

  # Step 1: for each row of the matrix, find the smallest element and subtract
  # it from every element in its row. If not fully assigned, go to step 2.
  defp step(matrix, 1, _assignments, _count) do
    reduced = rows_to_zero(matrix)
    step(reduced, 2, assignments(reduced))
  end

  # Step 2: for each column of the matrix, find the smallest element and
  # subtract it from every element in its column. If not fully assigned, go to
  # step 3.
  defp step(matrix, 2, _assignments, _count) do
    reduced =
      matrix
      |> transpose()
      |> rows_to_zero()
      |> transpose()

    step(reduced, 3, assignments(reduced))
  end

  # Step 3: cover all zeros with lines, then shift the smallest uncovered value
  # out of the uncovered cells. Repeats until fully assigned or the iteration
  # limit is reached.
  defp step(matrix, 3, _assignments, count) do
    {covered_rows, covered_cols} = min_lines(matrix)

    uncovered =
      for {row, r} <- Enum.with_index(matrix),
          r not in covered_rows,
          {val, c} <- Enum.with_index(row),
          c not in covered_cols,
          do: val

    # With every cell covered there is nothing left to adjust, so no further
    # progress is possible.
    if uncovered == [], do: raise(@unsolvable)

    min_uncovered = Enum.min(uncovered)

    adjusted =
      transform(matrix, fn {r, c}, val ->
        case {r in covered_rows, c in covered_cols} do
          # if uncovered, subtract the min
          {false, false} -> round_cost(val - min_uncovered)
          # if covered by a vertical and horizontal line, add min_uncovered
          {true, true} -> round_cost(val + min_uncovered)
          # otherwise, leave it alone
          _ -> val
        end
      end)

    assigned = assignments(adjusted)

    cond do
      # Check for completion before the iteration limit so that a solution
      # found on the last allowed iteration is returned instead of raising.
      length(assigned) == length(matrix) -> assigned
      count < @max_refinements -> step(adjusted, 3, assigned, count + 1)
      true -> raise @unsolvable
    end
  end

  # Greedily picks zero cells so that no row or column is used twice. Zeros
  # with the fewest other zeros in their row and column are tried first.
  defp assignments(matrix) do
    row_zeros = zero_counts(matrix)
    col_zeros = matrix |> transpose() |> zero_counts()

    matrix
    |> reduce([], fn {r, c} = coord, val, acc ->
      if val == 0 do
        [{coord, elem(row_zeros, r) + elem(col_zeros, c)} | acc]
      else
        acc
      end
    end)
    |> Enum.sort_by(fn {_coord, zero_count} -> zero_count end)
    |> Enum.reduce([], fn {{r, c} = coord, _zero_count}, acc ->
      taken? = Enum.any?(acc, fn {used_r, used_c} -> used_r == r or used_c == c end)
      if taken?, do: acc, else: [coord | acc]
    end)
  end

  # https://stackoverflow.com/questions/23379660/hungarian-algorithm-finding-minimum-number-of-lines-to-cover-zeroes
  #
  # Returns `{covered_rows, covered_cols}`. Every zero cell draws one line: a
  # horizontal one through its row when the row holds more zeros than the
  # column, otherwise a vertical one through its column.
  defp min_lines(matrix) do
    row_zeros = zero_counts(matrix)
    col_zeros = matrix |> transpose() |> zero_counts()

    reduce(matrix, {[], []}, fn {r, c}, val, {rows, cols} = acc ->
      cond do
        val != 0 -> acc
        elem(row_zeros, r) > elem(col_zeros, c) -> {Enum.uniq([r | rows]), cols}
        true -> {rows, Enum.uniq([c | cols])}
      end
    end)
  end

  # Tuple with the number of zero cells in each of the given rows.
  defp zero_counts(rows) do
    rows
    |> Enum.map(fn row -> Enum.count(row, &(&1 == 0)) end)
    |> List.to_tuple()
  end

  # Subtracts each row's minimum from every element of that row.
  defp rows_to_zero(matrix) do
    Enum.map(matrix, fn row ->
      smallest = Enum.min(row)
      Enum.map(row, &round_cost(&1 - smallest))
    end)
  end

  # `Float.round/2` only accepts floats, so integer costs are converted first.
  defp round_cost(value), do: Float.round(value * 1.0, @precision)

  defp transpose(matrix) do
    matrix
    |> Enum.zip()
    |> Enum.map(&Tuple.to_list/1)
  end

  # Maps `func.({row_index, col_index}, value)` over every cell, keeping the
  # matrix shape.
  defp transform(matrix, func) do
    matrix
    |> Enum.with_index()
    |> Enum.map(fn {row, r} ->
      row
      |> Enum.with_index()
      |> Enum.map(fn {val, c} -> func.({r, c}, val) end)
    end)
  end

  @doc """
  Folds over every cell of `matrix` in row-major order.

  `func` is called as `func.({row_index, col_index}, value, acc)` and must
  return the new accumulator; `init` is the initial accumulator.
  """
  def reduce(matrix, init, func) do
    matrix
    |> Enum.with_index()
    |> Enum.reduce(init, fn {row, r}, acc ->
      row
      |> Enum.with_index()
      |> Enum.reduce(acc, fn {val, c}, inner -> func.({r, c}, val, inner) end)
    end)
  end

  # Makes the matrix square by appending zero-cost columns (more rows than
  # columns) or zero-cost rows (more columns than rows).
  defp pad([first | _] = matrix) do
    rows = length(matrix)
    cols = length(first)

    cond do
      # use the matrix as-is if it has the same number of columns and rows
      rows == cols ->
        matrix

      # more rows than columns, add zero columns to each row
      rows > cols ->
        filler = List.duplicate(0, rows - cols)
        Enum.map(matrix, &(&1 ++ filler))

      # more columns than rows, add full-width rows of zeros
      true ->
        matrix ++ List.duplicate(List.duplicate(0, cols), cols - rows)
    end
  end

  defp pad(matrix), do: matrix
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
defmodule HungarianTest do
  use ExUnit.Case, async: true

  # some taken from https://github.com/addaleax/munkres-js/blob/master/test/test.js

  describe "compute/1" do
    test "handles singleton matrix" do
      assert_assignments([{0, 0}], [[5]])
    end

    test "handles negative singleton matrix" do
      assert_assignments([{0, 0}], [[-5]])
    end

    test "handles 2-by-2 matrix" do
      assert_assignments([{1, 0}, {0, 1}], [[5, 3], [2, 4]])
    end

    test "handles 2-by-2 negative matrix" do
      assert_assignments([{1, 1}, {0, 0}], [[-5, -3], [-2, -4]])
    end

    test "3-by-3 that is solved by step 1" do
      data = [
        [401, 150, 405],
        [402, 450, 600],
        [305, 300, 225]
      ]

      assert_assignments([{2, 2}, {1, 0}, {0, 1}], data)
    end

    test "3-by-3 that is solved by step 2" do
      data = [
        [400, 150, 405],
        [402, 450, 600],
        [305, 225, 300]
      ]

      assert_assignments([{2, 2}, {1, 0}, {0, 1}], data)
    end

    test "3-by-3 that is solved by step 3" do
      data = [
        [5, 3, -1],
        [2, 4, -6],
        [9, 9, -9]
      ]

      assert_assignments([{2, 2}, {1, 0}, {0, 1}], data)
    end

    test "handles 3-by-3 matrix" do
      data = [
        [5, 3, 1],
        [2, 4, 6],
        [9, 9, 9]
      ]

      assert_assignments([{2, 1}, {1, 0}, {0, 2}], data)
    end

    test "handles another 3-by-3 matrix" do
      data = [
        [400, 150, 400],
        [400, 450, 600],
        [300, 225, 300]
      ]

      assert_assignments([{2, 2}, {1, 0}, {0, 1}], data)
    end

    test "handles all-zero 3-by-3 matrix" do
      data = [
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]
      ]

      assert_assignments([{2, 2}, {1, 1}, {0, 0}], data)
    end

    test "handles rectangular 3-by-4 matrix" do
      data = [
        [400, 150, 400, 1],
        [400, 450, 600, 2],
        [300, 225, 300, 3]
      ]

      assert_assignments([{2, 0}, {1, 3}, {0, 1}], data)
    end

    # five rows, three columns
    test "handles rectangular 5-by-3 matrix" do
      data = [
        [400, 150, 400],
        [400, 450, 600],
        [300, 225, 300],
        [1, 2, 3],
        [4, 5, 6]
      ]

      assert_assignments([{4, 2}, {3, 0}, {0, 1}], data)
    end

    test "4-by-4 matrix" do
      data = [
        [80, 40, 50, 46],
        [40, 70, 20, 25],
        [30, 10, 20, 30],
        [35, 20, 25, 30]
      ]

      assert_assignments([{3, 0}, {2, 1}, {1, 2}, {0, 3}], data)
    end

    test "10-by-10 matrix" do
      data = [
        [0.212245, 3.97026, 4.35294, 4.68036, 3.9318, 5.39075, 4.04685, 4.57709, 4.71826, 4.61443],
        [4.10203, 3.62481, 4.32934, 2.88815, 3.6587, 6.37339, 4.85513, 5.43085, 5.64526, 5.54708],
        [3.99922, 2.8268, 0.967688, 3.86954, 3.36146, 6.27778, 5.03768, 5.31432, 5.53277, 5.45011],
        [4.6018, 4.49099, 4.11193, 0.673058, 3.59617, 6.61862, 5.5957, 5.73652, 5.99593, 5.89837],
        [3.82977, 3.72414, 3.61071, 3.55919, 0.498291, 5.98384, 4.79049, 4.99767, 5.21171, 5.10687],
        [5.39791, 6.44015, 6.6127, 6.70207, 6.09563, 0.0215472, 6.14398, 6.10082, 6.27091, 6.16451],
        [4.04463, 4.78672, 5.35553, 5.66122, 4.88248, 6.11385, 0.45157, 5.10391, 5.37808, 5.27629],
        [4.07994, 4.9917, 5.16107, 5.32683, 4.60849, 5.58742, 4.63669, 0.560706, 4.48667, 4.39025],
        [4.72601, 5.73108, 5.88376, 6.08028, 5.33076, 6.27062, 5.40051, 4.9789, 0.0220178, 3.88307],
        [4.62512, 5.61403, 5.80195, 5.98378, 5.23184, 6.16688, 5.30455, 4.89557, 3.88567, 0.0236218]
      ]

      expected = [
        {9, 9},
        {8, 8},
        {7, 7},
        {6, 6},
        {5, 5},
        {4, 4},
        {3, 3},
        {2, 2},
        {1, 1},
        {0, 0}
      ]

      assert_assignments(expected, data)
    end

    test "rounds floats during computation" do
      data = [
        [5.73652, 4.99767, 6.10082, 5.10391, 0.560706, 4.9789, 4.89557, 4.57709, 5.43085, 5.31432],
        [5.21171, 6.27091, 5.37808, 4.48667, 0.0220178, 3.88567, 4.71826, 5.64526, 5.53277, 5.99593],
        [4.39025, 3.88307, 0.0236218, 4.61443, 5.54708, 5.45011, 5.89837, 5.10687, 6.16451, 5.27629],
        [5.39791, 4.04463, 4.07994, 4.72601, 4.62512, 0.212245, 4.10203, 3.99922, 4.6018, 3.82977],
        [3.97026, 3.62481, 2.8268, 4.49099, 3.72414, 6.44015, 4.78672, 4.9917, 5.73108, 5.61403],
        [5.16107, 5.88376, 5.80195, 4.35294, 4.32934, 0.967688, 4.11193, 3.61071, 6.6127, 5.35553],
        [2.88815, 3.86954, 0.673058, 3.55919, 6.70207, 5.66122, 5.32683, 6.08028, 5.98378, 4.68036],
        [6.09563, 4.88248, 4.60849, 5.33076, 5.23184, 3.9318, 3.6587, 3.36146, 3.59617, 0.498291],
        [5.39075, 6.37339, 6.16688, 6.27778, 6.61862, 5.98384, 0.0215472, 6.11385, 5.58742, 6.27062],
        [4.04685, 4.85513, 5.30455, 5.03768, 5.5957, 4.79049, 6.14398, 0.45157, 4.63669, 5.40051]
      ]

      expected = [
        {0, 8},
        {1, 4},
        {2, 2},
        {3, 5},
        {4, 1},
        {5, 3},
        {6, 0},
        {7, 9},
        {8, 6},
        {9, 7}
      ]

      assert_assignments(expected, data)
    end
  end

  # Compares the computed assignments with `expected` as sets of {row, column}
  # pairs: the position of a pair inside the returned list carries no meaning,
  # so both sides are sorted before comparing.
  defp assert_assignments(expected, matrix) do
    assert Enum.sort(Hungarian.compute(matrix)) == Enum.sort(expected)
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment