Skip to content

Instantly share code, notes, and snippets.

@louity
Last active February 24, 2020 15:24
Show Gist options
  • Save louity/0629e2d12ff4ac96573b3a541246e162 to your computer and use it in GitHub Desktop.
Save louity/0629e2d12ff4ac96573b3a541246e162 to your computer and use it in GitHub Desktop.
Minimal logsumexp sinkhorn
def sinkhorn_logsumexp(cost_matrix, reg=1e-1, maxiter=30, momentum=0.):
    r"""Entropic optimal-transport cost between uniform marginals (log-domain Sinkhorn).

    Log-domain version of the Sinkhorn distance algorithm
    ( https://arxiv.org/abs/1306.0895 ).
    Inspired by https://github.com/gpeyre/SinkhornAutoDiff/blob/master/sinkhorn_pointcloud.py .

    The marginals are uniform: ``mu_i = 1/m`` and ``nu_j = 1/n``. Instead of
    multiplicative scalings, the dual potentials ``u`` (length m) and ``v``
    (length n) are updated with ``logsumexp``, which avoids underflow of
    ``exp(-C / reg)`` when ``reg`` is small.

    Args:
        cost_matrix: 2-D tensor of shape ``(m, n)`` holding the costs ``c_ij``.
            All intermediate tensors are created on its device, and in its
            dtype when it is floating point (float32 otherwise).
        reg: entropic regularization strength ``epsilon``; expected to be > 0.
        maxiter: exact number of Sinkhorn iterations (no early stopping).
        momentum: over-relaxation factor. When ``> 0``, after each full
            iteration every potential is extrapolated as
            ``-momentum * previous + (1 + momentum) * updated``.

    Returns:
        0-dim tensor ``sum_ij pi_ij * c_ij`` where the transport plan is
        ``pi_ij = exp((-c_ij + u_i + v_j) / reg)``.
    """
    import torch  # the snippet has no module-level import of torch

    m, n = cost_matrix.size()
    # Create the marginals where the costs live. Unconditionally moving them to
    # CUDA would break CPU cost matrices on machines that have a GPU.
    dtype = cost_matrix.dtype if cost_matrix.is_floating_point() else torch.float32
    mu = torch.full((m,), 1. / m, dtype=dtype, device=cost_matrix.device)
    nu = torch.full((n,), 1. / n, dtype=dtype, device=cost_matrix.device)
    log_mu, log_nu = torch.log(mu), torch.log(nu)  # loop invariants

    def M(u, v):
        r"""Modified cost for logarithmic updates: $M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$."""
        return (-cost_matrix + u.unsqueeze(1) + v.unsqueeze(0)) / reg

    u, v = torch.zeros_like(mu), torch.zeros_like(nu)
    # Actual Sinkhorn loop
    for _ in range(maxiter):
        u_prev, v_prev = u, v
        # Row update: afterwards exp(M(u, v)) has row sums equal to mu.
        u = reg * (log_mu - torch.logsumexp(M(u, v), dim=1)) + u
        # Column update (uses the freshly updated u): column sums become nu.
        v = reg * (log_nu - torch.logsumexp(M(u, v).t(), dim=1)) + v
        if momentum > 0.:
            u = -momentum * u_prev + (1 + momentum) * u
            v = -momentum * v_prev + (1 + momentum) * v
    # Transport plan pi = diag(a) K diag(b) with K = exp(-C/reg),
    # a = exp(u/reg), b = exp(v/reg).
    pi = torch.exp(M(u, v))
    cost = torch.sum(pi * cost_matrix)  # Sinkhorn cost
    return cost
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment