Instantly share code, notes, and snippets.

atomkirk/hungarian.ex

Last active January 8, 2023 20:25
Show Gist options
• Save atomkirk/a4ac4c3d6ef964eaab4b7f55ef045f83 to your computer and use it in GitHub Desktop.
Hungarian/Munkres algorithm in Elixir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 defmodule Hungarian do @moduledoc """ Written by Adam Kirk – Jan 18, 2020 Most helpful resources used: https://www.youtube.com/watch?v=dQDZNHwuuOY https://www.youtube.com/watch?v=cQ5MsiGaDY8 https://www.geeksforgeeks.org/hungarian-algorithm-assignment-problem-set-1-introduction/ """ # takes an nxn matrix of costs and returns a list of {row, column} # tuples of assigments that minimizes total cost def compute([row1 | _] = matrix) do matrix |> IO.inspect(label: "hungarian_input_matrix", limit: :infinity, printable_limit: :infinity, pretty: true) # add "zero" rows if its not a square matrix |> pad() # perform the calculation |> step() # remove any assignments that are in the padded matrix |> Enum.filter(fn {r, c} -> r < length(matrix) and c < length(row1) end) end defp step(matrix, step \\ 1, assignments \\ nil, count \\ 0) # match on done defp step(matrix, _step, assignments, _count) when length(assignments) == length(matrix), do: assignments # For each row of the matrix, find the smallest element and # subtract it from every element in its row. If no assignments, go step 2 defp step(matrix, 1, _assignments, _count) do transformed = rows_to_zero(matrix) assigned = assignments(transformed) step(transformed, 2, assigned) end # For each column of the matrix, find the smallest element and # subtract it from every element in its column. If no assignments, go step 3 defp step(matrix, 2, _assignments, _count) do transformed = matrix |> transpose() |> rows_to_zero() |> transpose() assigned = assignments(transformed) step(transformed, 3, assigned) end defp step(matrix, 3, _assignments, count) do {covered_rows, covered_cols} = min_lines(matrix) IO.inspect("#{Enum.join(covered_rows, ",")} x #{Enum.join(covered_cols, ",")}") min_uncovered = matrix |> transform(fn {r, c}, val -> if c not in covered_cols and r not in covered_rows do val end end) |> List.flatten() |> Enum.filter(&(!is_nil(&1))) |> Enum.min() |> IO.inspect(label: "min_uncovered") transformed = matrix |> transform(fn {r, c}, val -> case {r in covered_rows, c in covered_cols} do # if uncovered, subtract the min {false, false} -> Float.round(val - min_uncovered, 3) # if covered by a vertical and horizontal line, add min_uncovered {true, true} -> Float.round(val + min_uncovered, 3) # otherwise, leave it alone _ -> val end end) |> print_matrix() assigned = assignments(transformed) if count < 50 do step(transformed, 3, assigned, count + 1) else raise "There must be a bug in this code that can't handle the input matrix." end end defp assignments(matrix) do matrix |> reduce([], fn {r, c} = coord, val, acc -> if val == 0 do h_zeros = row(matrix, r) |> Enum.count(&(&1 == 0)) v_zeros = column(matrix, c) |> Enum.count(&(&1 == 0)) [{coord, h_zeros + v_zeros} | acc] else acc end end) |> Enum.sort_by(fn {_, zero_count} -> zero_count end) |> Enum.reduce([], fn {{r, c} = coord, _}, acc -> {assigned_rows, assigned_cols} = Enum.unzip(acc) if r not in assigned_rows && c not in assigned_cols do [coord | acc] else acc end end) |> IO.inspect() end # https://stackoverflow.com/questions/23379660/hungarian-algorithm-finding-minimum-number-of-lines-to-cover-zeroes defp min_lines(matrix) do matrix # Calculate the max number of zeros vertically vs horizontally for each xy position in the input matrix # and store the result in a separate array called m2. # While calculating, if horizontal zeros > vertical zeroes, then the calculated number is converted # to negative. (just to distinguish which direction we chose for later use) |> transform(fn {r, c}, val -> h_zeros = row(matrix, r) |> Enum.count(&(&1 == 0)) v_zeros = column(matrix, c) |> Enum.count(&(&1 == 0)) cond do val != 0 -> 0 h_zeros > v_zeros -> -h_zeros true -> v_zeros end end) # Loop through all elements in the m2 array. If the value is positive, draw a vertical line in array m3, # if value is negative, draw an horizontal line in m3 |> reduce({[], []}, fn {_, c}, val, {rows, cols} when val > 0 -> {rows, [c | cols] |> Enum.uniq()} {r, _}, val, {rows, cols} when val < 0 -> {[r | rows] |> Enum.uniq(), cols} _, _, acc -> acc end) end defp rows_to_zero(matrix) do Enum.map(matrix, fn row -> min = Enum.min(row) Enum.map(row, fn column -> Float.round(column - min, 3) end) end) end defp transpose(matrix) do transform(matrix, fn {r, c}, _ -> matrix |> Enum.at(c) |> Enum.at(r) end) end defp transform(matrix, func) do matrix |> Enum.with_index() |> Enum.map(fn {row, r} -> row |> Enum.with_index() |> Enum.map(fn {_column, c} -> func.({r, c}, matrix |> Enum.at(r) |> Enum.at(c)) end) end) end def reduce(matrix, init, func) do matrix |> Enum.with_index() |> Enum.reduce(init, fn {row, r}, acc -> row |> Enum.with_index() |> Enum.reduce(acc, fn {_column, c}, acc2 -> func.({r, c}, matrix |> Enum.at(r) |> Enum.at(c), acc2) end) end) end defp print_matrix(matrix, opts \\ []) do IO.puts("#{opts[:label]}------------") for row <- matrix do row |> Enum.map(fn v -> truncate(v) end) |> Enum.join("\t") |> IO.puts() end matrix end defp row(matrix, index), do: Enum.at(matrix, index) defp column(matrix, index), do: Enum.map(matrix, &Enum.at(&1, index)) defp pad([first | _] = matrix) do case length(matrix) - length(first) do # use the matrix only if it has the same number of columns and rows 0 -> matrix # more rows than columns, add zero columns to each row diff when diff > 0 -> Enum.map(matrix, fn row -> row ++ Enum.map(1..abs(diff), fn _ -> 0 end) end) # more columns than rows, add a row of zeros diff when diff < 0 -> matrix ++ [Enum.map(1..length(matrix), fn _ -> 0 end)] end end defp pad(matrix), do: matrix defp truncate(float) do trunc(float * 1000) / 1000 end end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 defmodule HungarianTest do use ExUnit.Case, async: true # some taken from https://github.com/addaleax/munkres-js/blob/master/test/test.js test "handles singleton matrix" do assert [{0, 0}] = Hungarian.compute([[5]]) end test "handles negative singleton matrix" do assert [{0, 0}] = Hungarian.compute([[-5]]) end test "handles 2-by-2 matrix" do assert [{1, 0}, {0, 1}] = Hungarian.compute([[5, 3], [2, 4]]) end test "handles 2-by-2 negative matrix" do assert [{1, 1}, {0, 0}] = Hungarian.compute([[-5, -3], [-2, -4]]) end test "3-by-3 that is solved by step 1" do data = [ [401, 150, 405], [402, 450, 600], [305, 300, 225] ] assert [{2, 2}, {1, 0}, {0, 1}] = Hungarian.compute(data) end test "3-by-3 that is solved by step 2" do data = [ [400, 150, 405], [402, 450, 600], [305, 225, 300] ] assert [{2, 2}, {1, 0}, {0, 1}] = Hungarian.compute(data) end test "3-by-3 that is solved by step 3" do data = [ [5, 3, -1], [2, 4, -6], [9, 9, -9] ] assert [{2, 2}, {1, 0}, {0, 1}] = Hungarian.compute(data) end test "handles 3-by-3 matrix" do data = [ [5, 3, 1], [2, 4, 6], [9, 9, 9] ] assert [{2, 1}, {1, 0}, {0, 2}] = Hungarian.compute(data) end test "handles another 3-by-3 matrix" do data = [ [400, 150, 400], [400, 450, 600], [300, 225, 300] ] assert [{2, 2}, {1, 0}, {0, 1}] = Hungarian.compute(data) end test "handles all-zero 3-by-3 matrix" do data = [ [0, 0, 0], [0, 0, 0], [0, 0, 0] ] assert [{2, 2}, {1, 1}, {0, 0}] = Hungarian.compute(data) end test "handles rectangular 3-by-4 matrix" do data = [ [400, 150, 400, 1], [400, 450, 600, 2], [300, 225, 300, 3] ] assert [{2, 0}, {1, 3}, {0, 1}] = Hungarian.compute(data) end test "handles rectangular 3-by-5 matrix" do data = [ [400, 150, 400], [400, 450, 600], [300, 225, 300], [1, 2, 3], [4, 5, 6] ] assert [{4, 2}, {3, 0}, {0, 1}] = Hungarian.compute(data) end test "4-by-4 matrix" do data = [ [80, 40, 50, 46], [40, 70, 20, 25], [30, 10, 20, 30], [35, 20, 25, 30] ] assert [{3, 0}, {2, 1}, {1, 2}, {0, 3}] = Hungarian.compute(data) end test "10-by-10 matrix" do data = [ [0.212245, 3.97026, 4.35294, 4.68036, 3.9318, 5.39075, 4.04685, 4.57709, 4.71826, 4.61443], [4.10203, 3.62481, 4.32934, 2.88815, 3.6587, 6.37339, 4.85513, 5.43085, 5.64526, 5.54708], [3.99922, 2.8268, 0.967688, 3.86954, 3.36146, 6.27778, 5.03768, 5.31432, 5.53277, 5.45011], [4.6018, 4.49099, 4.11193, 0.673058, 3.59617, 6.61862, 5.5957, 5.73652, 5.99593, 5.89837], [3.82977, 3.72414, 3.61071, 3.55919, 0.498291, 5.98384, 4.79049, 4.99767, 5.21171, 5.10687], [5.39791, 6.44015, 6.6127, 6.70207, 6.09563, 0.0215472, 6.14398, 6.10082, 6.27091, 6.16451], [4.04463, 4.78672, 5.35553, 5.66122, 4.88248, 6.11385, 0.45157, 5.10391, 5.37808, 5.27629], [4.07994, 4.9917, 5.16107, 5.32683, 4.60849, 5.58742, 4.63669, 0.560706, 4.48667, 4.39025], [4.72601, 5.73108, 5.88376, 6.08028, 5.33076, 6.27062, 5.40051, 4.9789, 0.0220178, 3.88307], [4.62512, 5.61403, 5.80195, 5.98378, 5.23184, 6.16688, 5.30455, 4.89557, 3.88567, 0.0236218] ] assert [ {9, 9}, {8, 8}, {7, 7}, {6, 6}, {5, 5}, {4, 4}, {3, 3}, {2, 2}, {1, 1}, {0, 0} ] = Hungarian.compute(data) end test "test rounds floats during computation" do data = [ [5.73652, 4.99767, 6.10082, 5.10391, 0.560706, 4.9789, 4.89557, 4.57709, 5.43085, 5.31432], [5.21171, 6.27091, 5.37808, 4.48667, 0.0220178, 3.88567, 4.71826, 5.64526, 5.53277, 5.99593], [4.39025, 3.88307, 0.0236218, 4.61443, 5.54708, 5.45011, 5.89837, 5.10687, 6.16451, 5.27629], [5.39791, 4.04463, 4.07994, 4.72601, 4.62512, 0.212245, 4.10203, 3.99922, 4.6018, 3.82977], [3.97026, 3.62481, 2.8268, 4.49099, 3.72414, 6.44015, 4.78672, 4.9917, 5.73108, 5.61403], [5.16107, 5.88376, 5.80195, 4.35294, 4.32934, 0.967688, 4.11193, 3.61071, 6.6127, 5.35553], [2.88815, 3.86954, 0.673058, 3.55919, 6.70207, 5.66122, 5.32683, 6.08028, 5.98378, 4.68036], [6.09563, 4.88248, 4.60849, 5.33076, 5.23184, 3.9318, 3.6587, 3.36146, 3.59617, 0.498291], [5.39075, 6.37339, 6.16688, 6.27778, 6.61862, 5.98384, 0.0215472, 6.11385, 5.58742, 6.27062], [4.04685, 4.85513, 5.30455, 5.03768, 5.5957, 4.79049, 6.14398, 0.45157, 4.63669, 5.40051] ] assert [ {0, 8}, {1, 4}, {2, 2}, {3, 5}, {4, 1}, {5, 3}, {6, 0}, {7, 9}, {8, 6}, {9, 7} ] = Hungarian.compute(data) end end