# Hungarian/Munkres algorithm in Elixir
# Gist by @atomkirk (last active January 8, 2023):
# https://gist.github.com/atomkirk/a4ac4c3d6ef964eaab4b7f55ef045f83
defmodule Hungarian do
  @moduledoc """
  Solves the assignment problem (matching rows to columns at minimum total
  cost) with a Hungarian/Munkres-style reduction.

  Written by Adam Kirk – Jan 18, 2020
  Most helpful resources used:
  https://www.youtube.com/watch?v=dQDZNHwuuOY
  https://www.youtube.com/watch?v=cQ5MsiGaDY8
  https://www.geeksforgeeks.org/hungarian-algorithm-assignment-problem-set-1-introduction/
  """

  # Decimal places kept after each float subtraction/addition, so accumulated
  # float error does not stop a cell from becoming exactly zero.
  @precision 3

  # Cap on step-3 passes; once reached the computation raises instead of
  # looping forever.
  @max_iterations 50

  @doc """
  Takes a matrix of costs (a list of rows) and returns a list of
  `{row, column}` tuples describing an assignment that minimises total cost.

  Non-square matrices are padded with zero-cost rows or columns first, and
  assignments that land in the padding are dropped from the result. The order
  of the returned tuples is an implementation detail.

  Raises if step 3 has not produced a complete assignment within the
  iteration cap.
  """
  @spec compute([[number()]]) :: [{non_neg_integer(), non_neg_integer()}]
  def compute([]), do: []

  def compute([first_row | _] = matrix) do
    row_count = length(matrix)
    col_count = length(first_row)

    matrix
    # add zero rows/columns if it's not a square matrix
    |> pad()
    # perform the calculation
    |> step()
    # remove any assignments that are in the padded region
    |> Enum.filter(fn {r, c} -> r < row_count and c < col_count end)
  end

  defp step(matrix, step \\ 1, assignments \\ nil, count \\ 0)

  # Done: the (square) matrix has one independent zero assigned per row.
  defp step(matrix, _step, assignments, _count)
       when is_list(assignments) and length(assignments) == length(matrix),
       do: assignments

  # Step 1: subtract each row's minimum from every element of that row.
  defp step(matrix, 1, _assignments, _count) do
    transformed = rows_to_zero(matrix)
    step(transformed, 2, assignments(transformed))
  end

  # Step 2: subtract each column's minimum from every element of that column.
  defp step(matrix, 2, _assignments, _count) do
    transformed =
      matrix
      |> transpose()
      |> rows_to_zero()
      |> transpose()

    step(transformed, 3, assignments(transformed))
  end

  # Step 3: cover the zeros with lines (see min_lines/1), subtract the smallest
  # uncovered value from every uncovered cell, add it to every cell covered by
  # both a row and a column line, then try to assign again.
  # NOTE(review): if the chosen lines cover every cell, the uncovered list is
  # empty and Enum.min/1 raises Enum.EmptyError.
  defp step(matrix, 3, _assignments, count) do
    {covered_rows, covered_cols} = min_lines(matrix)

    min_uncovered =
      matrix
      |> transform(fn {r, c}, val ->
        if c not in covered_cols and r not in covered_rows, do: val
      end)
      |> List.flatten()
      |> Enum.reject(&is_nil/1)
      |> Enum.min()

    transformed =
      transform(matrix, fn {r, c}, val ->
        case {r in covered_rows, c in covered_cols} do
          # uncovered: subtract the min
          {false, false} -> round_cost(val - min_uncovered)
          # covered by both a horizontal and a vertical line: add the min
          {true, true} -> round_cost(val + min_uncovered)
          # covered once: leave it alone
          _ -> val
        end
      end)

    assigned = assignments(transformed)

    if count < @max_iterations do
      step(transformed, 3, assigned, count + 1)
    else
      raise "There must be a bug in this code that can't handle the input matrix."
    end
  end

  # Greedily selects independent zero cells. Zeros whose row and column contain
  # the fewest zeros in total are tried first; a zero is skipped when its row or
  # column is already taken.
  # NOTE(review): a greedy choice may miss a complete assignment that exists,
  # which is why step 3 may iterate.
  defp assignments(matrix) do
    matrix
    |> reduce([], fn {r, c} = coord, val, acc ->
      if val == 0 do
        [{coord, zero_count(row(matrix, r)) + zero_count(column(matrix, c))} | acc]
      else
        acc
      end
    end)
    |> Enum.sort_by(fn {_, count} -> count end)
    |> Enum.reduce([], fn {{r, c} = coord, _}, acc ->
      {assigned_rows, assigned_cols} = Enum.unzip(acc)

      if r not in assigned_rows and c not in assigned_cols do
        [coord | acc]
      else
        acc
      end
    end)
  end

  # Heuristic line cover, returned as {covered_row_indexes, covered_col_indexes}.
  # https://stackoverflow.com/questions/23379660/hungarian-algorithm-finding-minimum-number-of-lines-to-cover-zeroes
  defp min_lines(matrix) do
    matrix
    # For each zero, count the zeros in its row (h) and its column (v).
    # Store v when v >= h (vertical line), or -h when h > v (horizontal line);
    # the sign only records which direction was chosen. Non-zero cells get 0.
    |> transform(fn {r, c}, val ->
      if val != 0 do
        0
      else
        h_zeros = zero_count(row(matrix, r))
        v_zeros = zero_count(column(matrix, c))
        if h_zeros > v_zeros, do: -h_zeros, else: v_zeros
      end
    end)
    # Positive value: draw a vertical line (column); negative: a horizontal line (row).
    |> reduce({[], []}, fn
      {_, c}, val, {rows, cols} when val > 0 -> {rows, Enum.uniq([c | cols])}
      {r, _}, val, {rows, cols} when val < 0 -> {Enum.uniq([r | rows]), cols}
      _, _, acc -> acc
    end)
  end

  # Subtracts each row's minimum from every element of that row.
  defp rows_to_zero(matrix) do
    Enum.map(matrix, fn row ->
      min = Enum.min(row)
      Enum.map(row, &round_cost(&1 - min))
    end)
  end

  # Float.round/2 only accepts floats; integer costs are already exact and are
  # returned unchanged (calling Float.round/2 on them would raise).
  defp round_cost(value) when is_integer(value), do: value
  defp round_cost(value), do: Float.round(value, @precision)

  defp transpose(matrix) do
    matrix
    |> Enum.zip()
    |> Enum.map(&Tuple.to_list/1)
  end

  # Maps func.({row_index, col_index}, value) over every cell, keeping the shape.
  defp transform(matrix, func) do
    matrix
    |> Enum.with_index()
    |> Enum.map(fn {row, r} ->
      row
      |> Enum.with_index()
      |> Enum.map(fn {val, c} -> func.({r, c}, val) end)
    end)
  end

  @doc """
  Folds over every cell of `matrix` in row-major order.

  `func` receives `{row_index, column_index}`, the cell value and the
  accumulator, and returns the new accumulator.
  """
  def reduce(matrix, init, func) do
    matrix
    |> Enum.with_index()
    |> Enum.reduce(init, fn {row, r}, acc ->
      row
      |> Enum.with_index()
      |> Enum.reduce(acc, fn {val, c}, acc2 -> func.({r, c}, val, acc2) end)
    end)
  end

  defp row(matrix, index), do: Enum.at(matrix, index)
  defp column(matrix, index), do: Enum.map(matrix, &Enum.at(&1, index))

  # Number of entries equal to zero (== also matches 0.0).
  defp zero_count(values), do: Enum.count(values, &(&1 == 0))

  # Squares the matrix with zero-cost padding, sized from the first row.
  defp pad([first | _] = matrix) do
    row_count = length(matrix)
    col_count = length(first)

    cond do
      # more rows than columns: append zero columns to every row
      row_count > col_count ->
        zeros = List.duplicate(0, row_count - col_count)
        Enum.map(matrix, &(&1 ++ zeros))

      # more columns than rows: append enough full-width zero rows
      row_count < col_count ->
        matrix ++ List.duplicate(List.duplicate(0, col_count), col_count - row_count)

      # already square
      true ->
        matrix
    end
  end

  defp pad(matrix), do: matrix
end
defmodule HungarianTest do
  use ExUnit.Case, async: true

  # Some cases taken from https://github.com/addaleax/munkres-js/blob/master/test/test.js
  #
  # Hungarian.compute/1 does not promise any particular order for the returned
  # {row, column} tuples, so results are compared after sorting.

  test "handles singleton matrix" do
    assert_assignment([[5]], [{0, 0}])
  end

  test "handles negative singleton matrix" do
    assert_assignment([[-5]], [{0, 0}])
  end

  test "handles 2-by-2 matrix" do
    assert_assignment([[5, 3], [2, 4]], [{1, 0}, {0, 1}])
  end

  test "handles 2-by-2 negative matrix" do
    assert_assignment([[-5, -3], [-2, -4]], [{1, 1}, {0, 0}])
  end

  test "3-by-3 that is solved by step 1" do
    data = [
      [401, 150, 405],
      [402, 450, 600],
      [305, 300, 225]
    ]

    assert_assignment(data, [{2, 2}, {1, 0}, {0, 1}])
  end

  test "3-by-3 that is solved by step 2" do
    data = [
      [400, 150, 405],
      [402, 450, 600],
      [305, 225, 300]
    ]

    assert_assignment(data, [{2, 2}, {1, 0}, {0, 1}])
  end

  test "3-by-3 that is solved by step 3" do
    data = [
      [5, 3, -1],
      [2, 4, -6],
      [9, 9, -9]
    ]

    assert_assignment(data, [{2, 2}, {1, 0}, {0, 1}])
  end

  test "handles 3-by-3 matrix" do
    data = [
      [5, 3, 1],
      [2, 4, 6],
      [9, 9, 9]
    ]

    assert_assignment(data, [{2, 1}, {1, 0}, {0, 2}])
  end

  test "handles another 3-by-3 matrix" do
    data = [
      [400, 150, 400],
      [400, 450, 600],
      [300, 225, 300]
    ]

    assert_assignment(data, [{2, 2}, {1, 0}, {0, 1}])
  end

  test "handles all-zero 3-by-3 matrix" do
    data = [
      [0, 0, 0],
      [0, 0, 0],
      [0, 0, 0]
    ]

    assert_assignment(data, [{2, 2}, {1, 1}, {0, 0}])
  end

  test "handles rectangular 3-by-4 matrix" do
    data = [
      [400, 150, 400, 1],
      [400, 450, 600, 2],
      [300, 225, 300, 3]
    ]

    assert_assignment(data, [{2, 0}, {1, 3}, {0, 1}])
  end

  test "handles rectangular 5-by-3 matrix" do
    data = [
      [400, 150, 400],
      [400, 450, 600],
      [300, 225, 300],
      [1, 2, 3],
      [4, 5, 6]
    ]

    result = Hungarian.compute(data)

    # Two assignments tie at the optimal cost of 157:
    # {0,1},{3,0},{4,2} and {0,1},{3,2},{4,0}. Either one is acceptable.
    assert length(result) == 3
    assert result |> Enum.map(&elem(&1, 0)) |> Enum.uniq() |> length() == 3
    assert result |> Enum.map(&elem(&1, 1)) |> Enum.uniq() |> length() == 3
    assert total_cost(data, result) == 157
  end

  test "4-by-4 matrix" do
    data = [
      [80, 40, 50, 46],
      [40, 70, 20, 25],
      [30, 10, 20, 30],
      [35, 20, 25, 30]
    ]

    assert_assignment(data, [{3, 0}, {2, 1}, {1, 2}, {0, 3}])
  end

  test "10-by-10 matrix" do
    data = [
      [0.212245, 3.97026, 4.35294, 4.68036, 3.9318, 5.39075, 4.04685, 4.57709, 4.71826, 4.61443],
      [4.10203, 3.62481, 4.32934, 2.88815, 3.6587, 6.37339, 4.85513, 5.43085, 5.64526, 5.54708],
      [3.99922, 2.8268, 0.967688, 3.86954, 3.36146, 6.27778, 5.03768, 5.31432, 5.53277, 5.45011],
      [4.6018, 4.49099, 4.11193, 0.673058, 3.59617, 6.61862, 5.5957, 5.73652, 5.99593, 5.89837],
      [3.82977, 3.72414, 3.61071, 3.55919, 0.498291, 5.98384, 4.79049, 4.99767, 5.21171, 5.10687],
      [5.39791, 6.44015, 6.6127, 6.70207, 6.09563, 0.0215472, 6.14398, 6.10082, 6.27091, 6.16451],
      [4.04463, 4.78672, 5.35553, 5.66122, 4.88248, 6.11385, 0.45157, 5.10391, 5.37808, 5.27629],
      [4.07994, 4.9917, 5.16107, 5.32683, 4.60849, 5.58742, 4.63669, 0.560706, 4.48667, 4.39025],
      [4.72601, 5.73108, 5.88376, 6.08028, 5.33076, 6.27062, 5.40051, 4.9789, 0.0220178, 3.88307],
      [4.62512, 5.61403, 5.80195, 5.98378, 5.23184, 6.16688, 5.30455, 4.89557, 3.88567, 0.0236218]
    ]

    assert_assignment(data, for(i <- 9..0//-1, do: {i, i}))
  end

  test "rounds floats during computation" do
    data = [
      [5.73652, 4.99767, 6.10082, 5.10391, 0.560706, 4.9789, 4.89557, 4.57709, 5.43085, 5.31432],
      [5.21171, 6.27091, 5.37808, 4.48667, 0.0220178, 3.88567, 4.71826, 5.64526, 5.53277, 5.99593],
      [4.39025, 3.88307, 0.0236218, 4.61443, 5.54708, 5.45011, 5.89837, 5.10687, 6.16451, 5.27629],
      [5.39791, 4.04463, 4.07994, 4.72601, 4.62512, 0.212245, 4.10203, 3.99922, 4.6018, 3.82977],
      [3.97026, 3.62481, 2.8268, 4.49099, 3.72414, 6.44015, 4.78672, 4.9917, 5.73108, 5.61403],
      [5.16107, 5.88376, 5.80195, 4.35294, 4.32934, 0.967688, 4.11193, 3.61071, 6.6127, 5.35553],
      [2.88815, 3.86954, 0.673058, 3.55919, 6.70207, 5.66122, 5.32683, 6.08028, 5.98378, 4.68036],
      [6.09563, 4.88248, 4.60849, 5.33076, 5.23184, 3.9318, 3.6587, 3.36146, 3.59617, 0.498291],
      [5.39075, 6.37339, 6.16688, 6.27778, 6.61862, 5.98384, 0.0215472, 6.11385, 5.58742, 6.27062],
      [4.04685, 4.85513, 5.30455, 5.03768, 5.5957, 4.79049, 6.14398, 0.45157, 4.63669, 5.40051]
    ]

    assert_assignment(data, [
      {0, 8},
      {1, 4},
      {2, 2},
      {3, 5},
      {4, 1},
      {5, 3},
      {6, 0},
      {7, 9},
      {8, 6},
      {9, 7}
    ])
  end

  # Order-insensitive comparison of the computed assignment with `expected`.
  defp assert_assignment(data, expected) do
    assert Enum.sort(Hungarian.compute(data)) == Enum.sort(expected)
  end

  # Sum of data[r][c] over the chosen {r, c} cells.
  defp total_cost(data, assignments) do
    assignments
    |> Enum.map(fn {r, c} -> data |> Enum.at(r) |> Enum.at(c) end)
    |> Enum.sum()
  end
end
# Discussion for this code lives on the gist:
# https://gist.github.com/atomkirk/a4ac4c3d6ef964eaab4b7f55ef045f83