Last active
April 17, 2023 00:00
-
-
Save jvlmdr/cfdef3de27653bed0aec93bbfb41481c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import collections\n", | |
"\n", | |
"import lap\n", | |
"import lapjv\n", | |
"import lapsolver\n", | |
"import munkres\n", | |
"import numpy as np\n", | |
"import ortools.graph.pywrapgraph\n", | |
"import scipy.optimize" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"Problem = collections.namedtuple('Problem', ['costs', 'min_cost'])\n", | |
"\n", | |
"X = np.inf\n", | |
"problems = collections.OrderedDict()\n", | |
"problems['balanced_perfect'] = Problem(\n", | |
" np.array([[X, 1, 2],\n", | |
" [3, 5, X],\n", | |
" [4, X, X]]),\n", | |
" 2 + 5 + 4,\n", | |
")\n", | |
"problems['balanced_imperfect'] = Problem(\n", | |
" np.array([[X, 1, 2],\n", | |
" [3, X, X],\n", | |
" [4, X, X]]),\n", | |
" 1 + 3,\n", | |
")\n", | |
"problems['unbalanced_onesided'] = Problem(\n", | |
" np.array([[X, 2, 1],\n", | |
" [-1, X, X]]),\n", | |
" -1 + 1,\n", | |
")\n", | |
"problems['unbalanced_imperfect'] = Problem(\n", | |
" np.array([[X, 2, 1, 3],\n", | |
" [-1, X, X, X],\n", | |
" [0, X, X, X]]),\n", | |
" -1 + 1,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def solve_lap(costs):\n", | |
" is_square = (costs.shape[0] == costs.shape[1])\n", | |
" min_cost, row_to_col, col_to_row = lap.lapjv(costs, extend_cost=(not is_square))\n", | |
" return min_cost, make_matching(enumerate(c for c in row_to_col if c >= 0))\n", | |
"\n", | |
"\n", | |
"def solve_lapmod(costs):\n", | |
" raise NotImplementedError()\n", | |
"\n", | |
"\n", | |
"def solve_lapjv(costs):\n", | |
" costs = np.where(costs < np.inf, costs, np.inf)\n", | |
" row_to_col, col_to_row, (min_cost, _, _) = lapjv.lapjv(costs)\n", | |
" return min_cost, make_matching(enumerate(c for c in row_to_col if c >= 0))\n", | |
"\n", | |
"\n", | |
"def solve_lapsolver(costs):\n", | |
" rows, cols = lapsolver.solve_dense(costs)\n", | |
" return None, make_matching(zip(rows, cols))\n", | |
"\n", | |
"\n", | |
"def solve_munkres(costs):\n", | |
" costs = [\n", | |
" [(x if x < np.inf else munkres.DISALLOWED) for x in row]\n", | |
" for row in costs\n", | |
" ]\n", | |
" m = munkres.Munkres()\n", | |
" try:\n", | |
" pairs = m.compute(costs)\n", | |
" except munkres.UnsolvableMatrix as ex:\n", | |
" raise ValueError(ex)\n", | |
" return make_matching(pairs)\n", | |
"\n", | |
"\n", | |
"def solve_ortools(costs):\n", | |
" assignment = ortools.graph.pywrapgraph.LinearSumAssignment()\n", | |
" for (i, j), cost in np.ndenumerate(costs):\n", | |
" if cost < np.inf:\n", | |
" assert cost == int(cost)\n", | |
" assignment.AddArcWithCost(i, j, int(cost))\n", | |
" solve_status = assignment.Solve()\n", | |
" if solve_status == assignment.INFEASIBLE:\n", | |
" raise ValueError('infeasible')\n", | |
" if solve_status != assignment.OPTIMAL:\n", | |
" raise ValueError('not optimal', solve_status)\n", | |
" n = assignment.NumNodes()\n", | |
" min_cost = assignment.OptimalCost()\n", | |
" pairs = [(i, assignment.RightMate(i)) for i in range(n)]\n", | |
" return min_cost, make_matching(pairs)\n", | |
"\n", | |
"\n", | |
"def solve_scipy(costs):\n", | |
" costs = np.where(costs < np.inf, costs, np.inf)\n", | |
" rows, cols = scipy.optimize.linear_sum_assignment(costs)\n", | |
" return None, make_matching(zip(rows, cols))\n", | |
"\n", | |
"\n", | |
"def make_matching(pairs):\n", | |
" pairs_sorted = sorted(pairs)\n", | |
" rows = [i for i, j in pairs_sorted]\n", | |
" cols = [j for i, j in pairs_sorted]\n", | |
" return rows, cols" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"solve_funcs = {\n", | |
" 'lap': solve_lap,\n", | |
" # 'lapmod': solve_lapmod,\n", | |
" 'lapjv': solve_lapjv,\n", | |
" 'lapsolver': solve_lapsolver,\n", | |
" # 'munkres': solve_munkres,\n", | |
" 'ortools': solve_ortools,\n", | |
" 'scipy': solve_scipy,\n", | |
"}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"----------------------------------------\n", | |
"balanced_perfect\n", | |
"----------------------------------------\n", | |
"[[inf 1. 2.]\n", | |
" [ 3. 5. inf]\n", | |
" [ 4. inf inf]]\n", | |
"\n", | |
"lap\n", | |
"correct! size 3\n", | |
"\n", | |
"lapjv\n", | |
"correct! size 3\n", | |
"\n", | |
"lapsolver\n", | |
"correct! size 3\n", | |
"\n", | |
"ortools\n", | |
"correct! size 3\n", | |
"\n", | |
"scipy\n", | |
"correct! size 3\n", | |
"\n", | |
"----------------------------------------\n", | |
"balanced_imperfect\n", | |
"----------------------------------------\n", | |
"[[inf 1. 2.]\n", | |
" [ 3. inf inf]\n", | |
" [ 4. inf inf]]\n", | |
"\n", | |
"lap\n", | |
"exception: ('non-finite cost output', inf)\n", | |
"\n", | |
"lapjv\n", | |
"exception: ('non-finite cost output', inf)\n", | |
"\n", | |
"lapsolver\n", | |
"correct! size 2\n", | |
"\n", | |
"ortools\n", | |
"exception: infeasible\n", | |
"\n", | |
"scipy\n", | |
"exception: cost matrix is infeasible\n", | |
"\n", | |
"----------------------------------------\n", | |
"unbalanced_onesided\n", | |
"----------------------------------------\n", | |
"[[inf 2. 1.]\n", | |
" [-1. inf inf]]\n", | |
"\n", | |
"lap\n", | |
"correct! size 2\n", | |
"\n", | |
"lapjv\n", | |
"exception: \"cost_matrix\" must be a square 2D numpy array\n", | |
"\n", | |
"lapsolver\n", | |
"correct! size 2\n", | |
"\n", | |
"ortools\n", | |
"exception: infeasible\n", | |
"\n", | |
"scipy\n", | |
"correct! size 2\n", | |
"\n", | |
"----------------------------------------\n", | |
"unbalanced_imperfect\n", | |
"----------------------------------------\n", | |
"[[inf 2. 1. 3.]\n", | |
" [-1. inf inf inf]\n", | |
" [ 0. inf inf inf]]\n", | |
"\n", | |
"lap\n", | |
"exception: ('non-finite cost output', inf)\n", | |
"\n", | |
"lapjv\n", | |
"exception: \"cost_matrix\" must be a square 2D numpy array\n", | |
"\n", | |
"lapsolver\n", | |
"correct! size 2\n", | |
"\n", | |
"ortools\n", | |
"exception: infeasible\n", | |
"\n", | |
"scipy\n", | |
"exception: cost matrix is infeasible\n" | |
] | |
} | |
], | |
"source": [ | |
"for problem_name, problem in problems.items():\n", | |
" print()\n", | |
" print('----------------------------------------')\n", | |
" print(problem_name)\n", | |
" print('----------------------------------------')\n", | |
" print(problem.costs)\n", | |
" for solver_name, solver in sorted(solve_funcs.items()):\n", | |
" print()\n", | |
" print(solver_name)\n", | |
" try:\n", | |
" cost_output, (rows, cols) = solver(problem.costs)\n", | |
" if cost_output is not None and not np.isfinite(cost_output):\n", | |
" raise ValueError('non-finite cost output', cost_output)\n", | |
" cost = np.sum(problem.costs[rows, cols])\n", | |
" except (ValueError, AssertionError) as ex:\n", | |
" print('exception:', ex)\n", | |
" else:\n", | |
" if cost != problem.min_cost:\n", | |
" print('wrong: actual {:g}, desired {:g}, edges {:s}'.format(\n", | |
" cost, problem.min_cost, list(zip(rows, cols))))\n", | |
" else:\n", | |
" print('correct! size', len(rows))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
wheel | |
numpy | |
scipy | |
ortools | |
lap | |
lapjv | |
lapsolver | |
munkres | |
munkres3 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thank you!