Skip to content

Instantly share code, notes, and snippets.

@jvlmdr
Last active April 17, 2023 00:00
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jvlmdr/cfdef3de27653bed0aec93bbfb41481c to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import collections\n",
"\n",
"import lap\n",
"import lapjv\n",
"import lapsolver\n",
"import munkres\n",
"import numpy as np\n",
"import ortools.graph.pywrapgraph\n",
"import scipy.optimize"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Each Problem pairs a cost matrix with the expected minimum total cost of a\n",
"# maximum-cardinality matching; X (= inf) marks a forbidden (row, col) pair.\n",
"Problem = collections.namedtuple('Problem', ['costs', 'min_cost'])\n",
"\n",
"X = np.inf\n",
"problems = collections.OrderedDict()\n",
"\n",
"# Square; a perfect matching exists: (0, 2) + (1, 1) + (2, 0).\n",
"problems['balanced_perfect'] = Problem(\n",
"    costs=np.array([\n",
"        [X, 1, 2],\n",
"        [3, 5, X],\n",
"        [4, X, X],\n",
"    ]),\n",
"    min_cost=2 + 5 + 4,\n",
")\n",
"# Square; rows 1 and 2 both need column 0, so only two pairs fit: (0, 1) + (1, 0).\n",
"problems['balanced_imperfect'] = Problem(\n",
"    costs=np.array([\n",
"        [X, 1, 2],\n",
"        [3, X, X],\n",
"        [4, X, X],\n",
"    ]),\n",
"    min_cost=1 + 3,\n",
")\n",
"# More columns than rows; every row can be matched: (0, 2) + (1, 0).\n",
"problems['unbalanced_onesided'] = Problem(\n",
"    costs=np.array([\n",
"        [X, 2, 1],\n",
"        [-1, X, X],\n",
"    ]),\n",
"    min_cost=-1 + 1,\n",
")\n",
"# More columns than rows, and rows 1 and 2 both need column 0: (0, 2) + (1, 0).\n",
"problems['unbalanced_imperfect'] = Problem(\n",
"    costs=np.array([\n",
"        [X, 2, 1, 3],\n",
"        [-1, X, X, X],\n",
"        [0, X, X, X],\n",
"    ]),\n",
"    min_cost=-1 + 1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def solve_lap(costs):\n",
"    \"\"\"Solve the assignment problem with `lap.lapjv`.\n",
"\n",
"    Non-square matrices are handled via `extend_cost`. Returns\n",
"    (min_cost, (rows, cols)); rows mapped to a negative column are dropped.\n",
"    \"\"\"\n",
"    is_square = (costs.shape[0] == costs.shape[1])\n",
"    min_cost, row_to_col, _ = lap.lapjv(costs, extend_cost=(not is_square))\n",
"    # Filter after enumerating so every pair keeps its true row index;\n",
"    # enumerating the filtered sequence would renumber rows after a gap.\n",
"    pairs = [(i, j) for i, j in enumerate(row_to_col) if j >= 0]\n",
"    return min_cost, make_matching(pairs)\n",
"\n",
"\n",
"def solve_lapmod(costs):\n",
"    \"\"\"Placeholder for a lapmod-based solver; always raises NotImplementedError.\"\"\"\n",
"    raise NotImplementedError\n",
"\n",
"\n",
"def solve_lapjv(costs):\n",
"    \"\"\"Solve the assignment problem with `lapjv.lapjv` (square matrices only).\n",
"\n",
"    Returns (min_cost, (rows, cols)); rows mapped to a negative column are dropped.\n",
"    \"\"\"\n",
"    # Entries that are not < inf (inf and NaN) become inf, in a new array.\n",
"    costs = np.where(costs < np.inf, costs, np.inf)\n",
"    row_to_col, _, (min_cost, _, _) = lapjv.lapjv(costs)\n",
"    # Filter after enumerating so every pair keeps its true row index.\n",
"    pairs = [(i, j) for i, j in enumerate(row_to_col) if j >= 0]\n",
"    return min_cost, make_matching(pairs)\n",
"\n",
"\n",
"def solve_lapsolver(costs):\n",
"    \"\"\"Solve with `lapsolver.solve_dense`; returns (None, (rows, cols)).\n",
"\n",
"    No total cost is reported, so None is returned in its place.\n",
"    \"\"\"\n",
"    matched_rows, matched_cols = lapsolver.solve_dense(costs)\n",
"    matching = make_matching(zip(matched_rows, matched_cols))\n",
"    return None, matching\n",
"\n",
"\n",
"def solve_munkres(costs):\n",
"    \"\"\"Solve with `munkres.Munkres`, marking entries not below inf as DISALLOWED.\n",
"\n",
"    Returns (None, (rows, cols)): this wrapper reports no total cost, and the\n",
"    benchmark loop unpacks every solver result as (cost, (rows, cols)).\n",
"\n",
"    Raises:\n",
"        ValueError: if munkres raises UnsolvableMatrix.\n",
"    \"\"\"\n",
"    matrix = [\n",
"        [(x if x < np.inf else munkres.DISALLOWED) for x in row]\n",
"        for row in costs\n",
"    ]\n",
"    try:\n",
"        pairs = munkres.Munkres().compute(matrix)\n",
"    except munkres.UnsolvableMatrix as ex:\n",
"        raise ValueError(ex) from ex\n",
"    return None, make_matching(pairs)\n",
"\n",
"\n",
"def solve_ortools(costs):\n",
"    \"\"\"Solve with OR-Tools' `LinearSumAssignment`.\n",
"\n",
"    Only entries below inf become arcs, and their costs must be integral.\n",
"    Returns (min_cost, (rows, cols)).\n",
"\n",
"    Raises:\n",
"        ValueError: if an arc cost is not an integer, or the solver status is\n",
"            INFEASIBLE or otherwise not OPTIMAL.\n",
"    \"\"\"\n",
"    assignment = ortools.graph.pywrapgraph.LinearSumAssignment()\n",
"    for (i, j), cost in np.ndenumerate(costs):\n",
"        if cost < np.inf:\n",
"            # Explicit check instead of `assert` (stripped by `python -O`);\n",
"            # is_integer() also rejects -inf instead of overflowing in int().\n",
"            if not float(cost).is_integer():\n",
"                raise ValueError('non-integer cost', cost)\n",
"            assignment.AddArcWithCost(i, j, int(cost))\n",
"    solve_status = assignment.Solve()\n",
"    if solve_status == assignment.INFEASIBLE:\n",
"        raise ValueError('infeasible')\n",
"    if solve_status != assignment.OPTIMAL:\n",
"        raise ValueError('not optimal', solve_status)\n",
"    min_cost = assignment.OptimalCost()\n",
"    n = assignment.NumNodes()\n",
"    pairs = [(i, assignment.RightMate(i)) for i in range(n)]\n",
"    return min_cost, make_matching(pairs)\n",
"\n",
"\n",
"def solve_scipy(costs):\n",
"    \"\"\"Solve with `scipy.optimize.linear_sum_assignment`.\n",
"\n",
"    Entries that are not < inf (inf and NaN) are replaced with inf first.\n",
"    No total cost is reported, so None is returned in its place.\n",
"    \"\"\"\n",
"    sanitized = np.where(costs < np.inf, costs, np.inf)\n",
"    row_ind, col_ind = scipy.optimize.linear_sum_assignment(sanitized)\n",
"    return None, make_matching(zip(row_ind, col_ind))\n",
"\n",
"\n",
"def make_matching(pairs):\n",
"    \"\"\"Sort (row, col) pairs and split them into parallel row and column lists.\"\"\"\n",
"    ordered = sorted(pairs)\n",
"    rows = [r for r, _ in ordered]\n",
"    cols = [c for _, c in ordered]\n",
"    return rows, cols"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Solvers run by the benchmark loop, keyed by display name.\n",
"solve_funcs = dict(\n",
"    lap=solve_lap,\n",
"    # lapmod=solve_lapmod,  # disabled: stub that raises NotImplementedError\n",
"    lapjv=solve_lapjv,\n",
"    lapsolver=solve_lapsolver,\n",
"    # munkres=solve_munkres,  # disabled\n",
"    ortools=solve_ortools,\n",
"    scipy=solve_scipy,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"----------------------------------------\n",
"balanced_perfect\n",
"----------------------------------------\n",
"[[inf 1. 2.]\n",
" [ 3. 5. inf]\n",
" [ 4. inf inf]]\n",
"\n",
"lap\n",
"correct! size 3\n",
"\n",
"lapjv\n",
"correct! size 3\n",
"\n",
"lapsolver\n",
"correct! size 3\n",
"\n",
"ortools\n",
"correct! size 3\n",
"\n",
"scipy\n",
"correct! size 3\n",
"\n",
"----------------------------------------\n",
"balanced_imperfect\n",
"----------------------------------------\n",
"[[inf 1. 2.]\n",
" [ 3. inf inf]\n",
" [ 4. inf inf]]\n",
"\n",
"lap\n",
"exception: ('non-finite cost output', inf)\n",
"\n",
"lapjv\n",
"exception: ('non-finite cost output', inf)\n",
"\n",
"lapsolver\n",
"correct! size 2\n",
"\n",
"ortools\n",
"exception: infeasible\n",
"\n",
"scipy\n",
"exception: cost matrix is infeasible\n",
"\n",
"----------------------------------------\n",
"unbalanced_onesided\n",
"----------------------------------------\n",
"[[inf 2. 1.]\n",
" [-1. inf inf]]\n",
"\n",
"lap\n",
"correct! size 2\n",
"\n",
"lapjv\n",
"exception: \"cost_matrix\" must be a square 2D numpy array\n",
"\n",
"lapsolver\n",
"correct! size 2\n",
"\n",
"ortools\n",
"exception: infeasible\n",
"\n",
"scipy\n",
"correct! size 2\n",
"\n",
"----------------------------------------\n",
"unbalanced_imperfect\n",
"----------------------------------------\n",
"[[inf 2. 1. 3.]\n",
" [-1. inf inf inf]\n",
" [ 0. inf inf inf]]\n",
"\n",
"lap\n",
"exception: ('non-finite cost output', inf)\n",
"\n",
"lapjv\n",
"exception: \"cost_matrix\" must be a square 2D numpy array\n",
"\n",
"lapsolver\n",
"correct! size 2\n",
"\n",
"ortools\n",
"exception: infeasible\n",
"\n",
"scipy\n",
"exception: cost matrix is infeasible\n"
]
}
],
"source": [
"# Run every enabled solver on every problem and compare the cost of the\n",
"# returned matching against the expected minimum.\n",
"for problem_name, problem in problems.items():\n",
"    print()\n",
"    print('----------------------------------------')\n",
"    print(problem_name)\n",
"    print('----------------------------------------')\n",
"    print(problem.costs)\n",
"    for solver_name, solver in sorted(solve_funcs.items()):\n",
"        print()\n",
"        print(solver_name)\n",
"        try:\n",
"            cost_output, (rows, cols) = solver(problem.costs)\n",
"            # Treat a non-finite reported total as a failure.\n",
"            if cost_output is not None and not np.isfinite(cost_output):\n",
"                raise ValueError('non-finite cost output', cost_output)\n",
"            # Recompute the cost from the matching instead of trusting the solver.\n",
"            cost = np.sum(problem.costs[rows, cols])\n",
"        except (ValueError, AssertionError) as ex:\n",
"            print('exception:', ex)\n",
"        else:\n",
"            # NOTE(review): only the total cost is compared, so a smaller matching\n",
"            # with the same total would also pass; the printed size is the only\n",
"            # cardinality signal.\n",
"            if cost != problem.min_cost:\n",
"                # `{}` rather than `{:s}`: a list rejects the 's' format spec (TypeError).\n",
"                print('wrong: actual {:g}, desired {:g}, edges {}'.format(\n",
"                    cost, problem.min_cost, list(zip(rows, cols))))\n",
"            else:\n",
"                print('correct! size', len(rows))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
wheel
numpy
scipy
ortools
lap
lapjv
lapsolver
munkres
munkres3
@matangover
Copy link

Thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment