Skip to content

Instantly share code, notes, and snippets.

@mblondel
Created September 9, 2022 07:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mblondel/43cacab94d7171e59f64a6603578cd97 to your computer and use it in GitHub Desktop.
Save mblondel/43cacab94d7171e59f64a6603578cd97 to your computer and use it in GitHub Desktop.
Reconstruct transportation plan from dual potentials
# Author: Mathieu Blondel
# License: BSD
import numpy as onp
import scipy
import jaxopt
import jax.numpy as jnp
def get_plan_from_alpha(a, b, C, alpha, eps=1e-10, method="pinv"):
    """Reconstruct a transportation plan from the row dual potential ``alpha``.

    The column potential ``beta`` is recovered as the c-transform of ``alpha``
    (``beta_j = min_i C_ij - alpha_i``). Entries whose reduced cost
    ``C_ij - alpha_i - beta_j`` is at most ``eps`` form the candidate support,
    and the plan values on that support are found by solving the marginal
    constraints ``sum_j T_ij = a_i`` and ``sum_i T_ij = b_j``.

    Parameters
    ----------
    a : array-like, shape (m,)
        Source marginal.
    b : array-like, shape (n,)
        Target marginal.
    C : array-like, shape (m, n)
        Cost matrix.
    alpha : array-like, shape (m,)
        Dual potential associated with the rows of ``C``.
    eps : float
        Tolerance on the reduced cost used to decide which entries are in
        the support.
    method : {"pinv", "nnls", "pg"}
        How the linear system on the support is solved:
        "pinv" -- minimum-norm least-squares solution via the pseudo-inverse
                  (does not enforce non-negativity);
        "nnls" -- non-negative least squares via SciPy;
        "pg"   -- projected gradient onto the probability simplex via jaxopt
                  (the simplex has unit mass, so this presumably expects
                  sum(a) == sum(b) == 1).

    Returns
    -------
    T : ndarray, shape (m, n)
        Reconstructed plan; zero outside the detected support.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported values.
    """
    a = onp.asarray(a, dtype=float)
    b = onp.asarray(b, dtype=float)
    C = onp.asarray(C, dtype=float)
    alpha = onp.asarray(alpha, dtype=float)
    m, n = len(a), len(b)

    # c-transform of alpha, then keep the (numerically) tight entries.
    reduced = C - alpha[:, onp.newaxis]
    beta = reduced.min(axis=0)
    rows, cols = onp.nonzero(reduced - beta <= eps)
    k = len(rows)

    # Constraint matrix with one column per support entry: column s has a 1 in
    # row rows[s] (source marginal) and in row m + cols[s] (target marginal).
    # The support may contain more than m + n entries (ties / degeneracy), so
    # M must be (m + n, k) rather than square.
    M = onp.zeros((m + n, k))
    support_idx = onp.arange(k)
    M[rows, support_idx] = 1
    M[cols + m, support_idx] = 1
    ab = onp.concatenate((a, b))

    if method == "pinv":
        # Explicit submodule import: a bare `import scipy` does not expose
        # scipy.linalg on older SciPy releases.
        from scipy.linalg import pinv
        t = pinv(M) @ ab
    elif method == "nnls":
        from scipy.optimize import nnls
        t = nnls(M, ab, maxiter=1000)[0]
    elif method == "pg":
        def fun(x):
            res = jnp.dot(M, x) - ab
            return jnp.dot(res, res)

        pg = jaxopt.ProjectedGradient(
            fun=fun, projection=jaxopt.projection.projection_simplex)
        # One variable per support entry, started at the simplex barycenter.
        t = pg.run(jnp.ones(k) / k).params
    else:
        raise ValueError("Unknown method.")

    # Scatter the support values back into the full (m, n) plan.
    T = onp.zeros((m, n))
    T[rows, cols] = onp.asarray(t)
    return T
if __name__ == '__main__':
    import ot

    # Fake data: random marginals (normalized to sum to 1) and a random cost.
    rng = onp.random.RandomState(0)
    a = rng.rand(3)
    a /= a.sum()
    b = rng.rand(4)
    b /= b.sum()
    C = rng.rand(3, 4)

    # Solve the OT problem with POT; with log=True the returned dict carries
    # the dual potentials ("u" for rows). Only the row potential is needed,
    # since get_plan_from_alpha recomputes the column one itself.
    T, dic = ot.emd(a, b, C, log=True)
    alpha = dic["u"]

    T_rec = get_plan_from_alpha(a, b, C, alpha)
    print("POT:")
    print(T)
    print("Reconstructed:")
    print(T_rec)
    print("Max abs difference:", onp.abs(T - T_rec).max())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment