Created
September 3, 2017 05:36
-
-
Save mblondel/fde75e4adf65453aa14d98dac62bc873 to your computer and use it in GitHub Desktop.
Optimal transport dual LP
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Author: Mathieu Blondel
# License: BSD 3 clause
import numpy as np | |
from scipy.optimize import linprog | |
def dual_lp(a, b, C, verbose=0):
    """Solve the dual optimal transport problem as a linear program.

    The problem solved is::

        max_{alpha, beta}  <a, alpha> + <b, beta>
        s.t.               alpha_i + beta_j <= C_{i,j}   for all i, j
                           sum_i alpha_i == 0

    Why the extra equality constraint: the inequality constraints are
    invariant under the shift ``alpha + t, beta - t``. Fixing
    ``sum(alpha) == 0`` removes that degree of freedom. Without it the
    optimum is non-unique when ``sum(a) == sum(b)``, and the LP is unbounded
    otherwise.

    Parameters
    ----------
    a : array-like of shape (m,)
        Weights of the source points.
    b : array-like of shape (n,)
        Weights of the target points.
    C : array-like of shape (m, n)
        Cost matrix.
    verbose : int, default=0
        If nonzero, print the solver's success flag and status code.

    Returns
    -------
    alpha : ndarray of shape (m,)
        Dual potentials for the source points.
    beta : ndarray of shape (n,)
        Dual potentials for the target points.

    Raises
    ------
    ValueError
        If ``C`` does not have shape ``(len(a), len(b))``.
    RuntimeError
        If ``scipy.optimize.linprog`` does not report success
        (e.g. the problem is unbounded or infeasible).
    """
    m = len(a)
    n = len(b)
    C = np.asarray(C, dtype=float)
    if C.shape != (m, n):
        raise ValueError(
            f"C must have shape ({m}, {n}) to match a and b, got {C.shape}")

    # linprog minimizes, so negate the objective to maximize it.
    c = -np.concatenate((a, b)).astype(float)

    # Row i * n + j encodes alpha_i + beta_j <= C[i, j]:
    # column i of the left part and column m + j of the right part are 1.
    A_ub = np.hstack((np.repeat(np.eye(m), n, axis=0),
                      np.tile(np.eye(n), (m, 1))))
    b_ub = C.ravel()  # C-order ravel matches the row index i * n + j.

    # Equality constraint sum(alpha) == 0 (see docstring for why).
    A_eq = np.zeros((1, m + n))
    A_eq[0, :m] = 1
    b_eq = np.zeros(1)

    res = linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq,
                  bounds=(None, None))
    if verbose:
        print("success:", res.success)
        print("status:", res.status)
    if not res.success:
        # On failure res.x may be None; fail loudly instead of a TypeError.
        raise RuntimeError(
            f"linprog failed (status {res.status}): {res.message}")
    return res.x[:m], res.x[m:]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment