Skip to content

Instantly share code, notes, and snippets.

@janhuenermann

janhuenermann/sinkhorn.py

Last active May 18, 2021
Embed
What would you like to do?
Sinkhorn Optimal Transport Algorithm in PyTorch
import torch
@torch.jit.script
def log_optimal_transport(Z: torch.Tensor, iters: int) -> torch.Tensor:
    """Run Sinkhorn iterations in log space and return the log transport plan.

    Both marginals are uniform: every row is assigned mass ``1 / m`` and every
    column mass ``1 / n``, where ``m`` and ``n`` are the sizes of the last two
    axes of ``Z``.

    Args:
        Z: Score tensor whose last two axes are ``(m, n)``. Any leading axes
            are treated as batch dimensions and solved independently.
        iters: Number of alternating column/row normalisation steps. With
            ``iters <= 0`` the loop does not run and ``Z`` is returned as is.

    Returns:
        A tensor shaped like ``Z`` holding ``Z + u + v`` with the dual
        potentials ``u`` and ``v`` broadcast over columns and rows; apply
        ``.exp()`` to obtain the transport plan itself.
    """
    # Take the problem size from the trailing two axes so that inputs with
    # leading batch dimensions work; unpacking the full shape only succeeds
    # for a strictly 2-D tensor.
    m = Z.shape[-2]
    n = Z.shape[-1]
    batch_shape = Z.shape[:-2]

    # Log of the uniform target marginals, matched to Z's dtype and device.
    log_mu = -torch.tensor(m).to(Z).log().expand(batch_shape + [m])
    log_nu = -torch.tensor(n).to(Z).log().expand(batch_shape + [n])

    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
    for _ in range(iters):
        # Rescale columns to match log_nu, then rows to match log_mu. The row
        # update comes last, so row sums of the result match the row marginal.
        v = log_nu - torch.logsumexp(Z + u.unsqueeze(-1), dim=-2)
        u = log_mu - torch.logsumexp(Z + v.unsqueeze(-2), dim=-1)
    return Z + u.unsqueeze(-1) + v.unsqueeze(-2)
# Example
# Define score (note: score = -cost)
score = torch.tensor([
    [5.0, -5.0, 5.0, 0.0, 0.0],
    [0.0, 5.0, -5.0, 5.0, 0.0],
    [0.0, 0.0, 5.0, -5.0, 5.0],
    [0.0, 0.0, 0.0, 5.0, -5.0],
    [0.0, 0.0, 0.0, 0.0, 5.0],
])
# The cost matrix follows from the definition above; it has to be defined
# explicitly because the EMD computation below reads it.
cost = -score
# Calculate optimal transport in log space
log_T = log_optimal_transport(score, 32)
# The optimal flow/transport is then T
T = log_T.exp()
print(T)
# [[0.1619, 0.0000, 0.0379, 0.0001, 0.0001],
# [0.0017, 0.1750, 0.0000, 0.0231, 0.0001],
# [0.0044, 0.0030, 0.1546, 0.0000, 0.0379],
# [0.0130, 0.0089, 0.0030, 0.1750, 0.0000],
# [0.0190, 0.0130, 0.0044, 0.0017, 0.1619]]
# T.sum(-1) adds up the entries along the column axis (one total per row).
print("Col sum = {}".format(T.sum(-1)))
# Col sum = tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
# T.sum(-2) adds up the entries along the row axis (one total per column).
print("Row sum = {}".format(T.sum(-2)))
# Row sum = tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
# Earth movers distance is given by
EMD = torch.sum(cost * T)
print("EMD = {}".format(EMD))
# EMD is approximately -4.6371: the plan collects a total score of about
# 4.6371, and cost = -score flips the sign.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment