Created March 21, 2022 16:12
Save jhelgert/9d6859ec3ae5926001276b917e2a4434 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from gurobipy import Model, GRB, quicksum as qsum | |
from scipy.optimize import linear_sum_assignment | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import time | |
def build_LP(c):
    """Build a Gurobi model of the linear assignment problem for cost matrix *c*.

    One binary variable ``x[i, j]`` is created per (worker, task) pair, and the
    objective ``sum(c[i, j] * x[i, j])`` is minimised.  Every member of the
    smaller side of a (possibly rectangular) cost matrix is assigned exactly
    once and every member of the larger side at most once -- the same problem
    ``scipy.optimize.linear_sum_assignment`` solves.  For a square matrix both
    constraint families are equalities, i.e. a perfect assignment.

    Note: despite the name, the variables are binary, so Gurobi treats this as
    a MIP rather than a pure LP.

    Args:
        c: 2-D cost array of shape ``(n_workers, n_tasks)``; ``c[i, j]`` is the
            cost of giving task ``j`` to worker ``i``.

    Returns:
        The (not yet optimised) ``gurobipy.Model``.
    """
    mdl = Model()
    # parameters
    n_workers, n_tasks = c.shape
    # Variables: x[i, j] == 1 iff worker i performs task j
    x = mdl.addVars(n_workers, n_tasks, vtype="B")
    # Objective
    mdl.setObjective(
        qsum(c[i, j] * x[i, j] for i in range(n_workers) for j in range(n_tasks)),
        sense=GRB.MINIMIZE,
    )
    # With a rectangular matrix only the smaller side can be fully covered;
    # demanding "== 1" on both sides would make the model infeasible.
    workers_exact = n_workers <= n_tasks
    tasks_exact = n_tasks <= n_workers
    # each worker gets exactly one task (at most one if workers outnumber tasks)
    for w in range(n_workers):
        assigned = qsum(x[w, j] for j in range(n_tasks))
        mdl.addConstr((assigned == 1) if workers_exact else (assigned <= 1))
    # each task is assigned exactly once (at most once if tasks outnumber workers)
    for task in range(n_tasks):
        assigned = qsum(x[w, task] for w in range(n_workers))
        mdl.addConstr((assigned == 1) if tasks_exact else (assigned <= 1))
    return mdl
def benchmark(ns, repeats=10):
    """Time scipy's Hungarian solver against the Gurobi model on random instances.

    For every size ``n`` in *ns* a random ``n x n`` integer cost matrix is drawn
    once and then solved *repeats* times by each method.  For Gurobi only
    ``mdl.optimize()`` is timed; building the model is excluded.

    Args:
        ns: Sequence of problem sizes (anything ``np.asarray`` accepts).
        repeats: Number of timed runs per size; must be at least 1.

    Returns:
        Array of shape ``(2, len(ns))`` holding mean elapsed seconds: row 0 for
        ``linear_sum_assignment``, row 1 for the Gurobi model.

    Raises:
        ValueError: If *repeats* is smaller than 1 (the mean would be 0/0).
    """
    if repeats < 1:
        raise ValueError(f"repeats must be at least 1, got {repeats}")
    ns = np.asarray(ns)
    timings = np.zeros((2, ns.size))
    for i, n in enumerate(ns):
        # randint's upper bound is exclusive: entries lie in 500..998.
        cost = np.random.randint(500, 999, size=(n, n))
        for _ in range(repeats):
            # time linear_sum_assignment (its result is not needed here)
            # perf_counter is monotonic and high-resolution, unlike time.time.
            start = time.perf_counter()
            linear_sum_assignment(cost)
            timings[0, i] += time.perf_counter() - start
            # time solving the model (ignore the time needed to build it)
            mdl = build_LP(cost)
            start = time.perf_counter()
            mdl.optimize()
            timings[1, i] += time.perf_counter() - start
    return timings / repeats
if __name__ == '__main__':
    # Benchmark a handful of square problem sizes, averaging 10 runs per size.
    sizes = np.array([100, 200, 300, 400, 500])
    mean_times = benchmark(sizes, repeats=10)

    # One curve per solver; the row order of mean_times matches these labels.
    _, axis = plt.subplots()
    solver_labels = ("Hungarian", "LP (gurobipy)")
    for solver_times, solver_label in zip(mean_times, solver_labels):
        axis.plot(sizes, solver_times, label=solver_label)
    axis.set_xlabel("n")
    axis.set_ylabel("time [s]")
    axis.legend()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.