@thomasahle
Created March 15, 2021 15:19
Code from Differentiable Top-k Operator with Optimal Transport
import torch
import torch.nn.functional as F
from torch.autograd import Function


def sinkhorn_forward(C, mu, nu, epsilon, max_iter):
    # Plain Sinkhorn iterations on the Gibbs kernel G = exp(-C/epsilon).
    bs, n, k_ = C.size()
    v = torch.ones([bs, 1, k_])/(k_)
    G = torch.exp(-C/epsilon)
    if torch.cuda.is_available():
        v = v.cuda()
    for i in range(max_iter):
        u = mu/(G*v).sum(-1, keepdim=True)
        v = nu/(G*u).sum(-2, keepdim=True)
    Gamma = u*G*v
    return Gamma

def sinkhorn_forward_stablized(C, mu, nu, epsilon, max_iter):
    # Log-domain (stabilized) Sinkhorn iterations, used when epsilon is small.
    bs, n, k_ = C.size()
    k = k_-1
    f = torch.zeros([bs, n, 1])
    g = torch.zeros([bs, 1, k+1])
    if torch.cuda.is_available():
        f = f.cuda()
        g = g.cuda()
    epsilon_log_mu = epsilon*torch.log(mu)
    epsilon_log_nu = epsilon*torch.log(nu)

    def min_epsilon_row(Z, epsilon):
        return -epsilon*torch.logsumexp((-Z)/epsilon, -1, keepdim=True)

    def min_epsilon_col(Z, epsilon):
        return -epsilon*torch.logsumexp((-Z)/epsilon, -2, keepdim=True)

    for i in range(max_iter):
        f = min_epsilon_row(C-g, epsilon)+epsilon_log_mu
        g = min_epsilon_col(C-f, epsilon)+epsilon_log_nu
    Gamma = torch.exp((-C+f+g)/epsilon)
    return Gamma

def sinkhorn_backward(grad_output_Gamma, Gamma, mu, nu, epsilon):
    # Analytic gradient of the transport plan Gamma w.r.t. the cost matrix C.
    nu_ = nu[:,:,:-1]
    Gamma_ = Gamma[:,:,:-1]
    bs, n, k_ = Gamma.size()
    inv_mu = 1./(mu.view([1,-1]))  #[1, n]
    Kappa = torch.diag_embed(nu_.squeeze(-2)) \
        - torch.matmul(Gamma_.transpose(-1, -2) * inv_mu.unsqueeze(-2), Gamma_)  #[bs, k, k]
    inv_Kappa = torch.inverse(Kappa)  #[bs, k, k]
    Gamma_mu = inv_mu.unsqueeze(-1)*Gamma_
    L = Gamma_mu.matmul(inv_Kappa)  #[bs, n, k]
    G1 = grad_output_Gamma * Gamma  #[bs, n, k+1]

    g1 = G1.sum(-1)
    G21 = (g1*inv_mu).unsqueeze(-1)*Gamma  #[bs, n, k+1]
    g1_L = g1.unsqueeze(-2).matmul(L)  #[bs, 1, k]
    G22 = g1_L.matmul(Gamma_mu.transpose(-1,-2)).transpose(-1,-2)*Gamma  #[bs, n, k+1]
    G23 = - F.pad(g1_L, pad=(0, 1), mode='constant', value=0)*Gamma  #[bs, n, k+1]
    G2 = G21 + G22 + G23  #[bs, n, k+1]
    del g1, G21, G22, G23, Gamma_mu

    g2 = G1.sum(-2).unsqueeze(-1)  #[bs, k+1, 1]
    g2 = g2[:,:-1,:]  #[bs, k, 1]
    G31 = - L.matmul(g2)*Gamma  #[bs, n, k+1]
    G32 = F.pad(inv_Kappa.matmul(g2).transpose(-1,-2), pad=(0, 1), mode='constant', value=0)*Gamma  #[bs, n, k+1]
    G3 = G31 + G32  #[bs, n, k+1]

    grad_C = (-G1+G2+G3)/epsilon  #[bs, n, k+1]
    return grad_C

class TopKFunc(Function):
    @staticmethod
    def forward(ctx, C, mu, nu, epsilon, max_iter):
        with torch.no_grad():
            if epsilon > 1e-2:
                Gamma = sinkhorn_forward(C, mu, nu, epsilon, max_iter)
                if bool(torch.any(Gamma != Gamma)):
                    print('Nan appeared in Gamma, re-computing...')
                    Gamma = sinkhorn_forward_stablized(C, mu, nu, epsilon, max_iter)
            else:
                Gamma = sinkhorn_forward_stablized(C, mu, nu, epsilon, max_iter)
            ctx.save_for_backward(mu, nu, Gamma)
            ctx.epsilon = epsilon
        return Gamma

    @staticmethod
    def backward(ctx, grad_output_Gamma):
        epsilon = ctx.epsilon
        mu, nu, Gamma = ctx.saved_tensors
        # mu    [1, n, 1]
        # nu    [1, 1, k+1]
        # Gamma [bs, n, k+1]
        with torch.no_grad():
            grad_C = sinkhorn_backward(grad_output_Gamma, Gamma, mu, nu, epsilon)
        return grad_C, None, None, None, None

class TopK_custom(torch.nn.Module):
    def __init__(self, k, epsilon=0.1, max_iter=200):
        super(TopK_custom, self).__init__()
        self.k = k
        self.epsilon = epsilon
        self.anchors = torch.FloatTensor([k-i for i in range(k+1)]).view([1, 1, k+1])
        self.max_iter = max_iter
        if torch.cuda.is_available():
            self.anchors = self.anchors.cuda()

    def forward(self, scores):
        bs, n = scores.size()
        scores = scores.view([bs, n, 1])

        # replace -inf scores with a finite value below the minimum finite score
        scores_ = scores.clone().detach()
        max_scores = torch.max(scores_).detach()
        scores_[scores_ == float('-inf')] = float('inf')
        min_scores = torch.min(scores_).detach()
        filled_value = min_scores - (max_scores - min_scores)
        mask = scores == float('-inf')
        scores = scores.masked_fill(mask, filled_value)

        C = (scores - self.anchors)**2
        C = C / (C.max().detach())
        mu = torch.ones([1, n, 1], requires_grad=False)/n
        nu = [1./n for _ in range(self.k)]
        nu.append((n - self.k)/n)
        nu = torch.FloatTensor(nu).view([1, 1, self.k+1])
        if torch.cuda.is_available():
            mu = mu.cuda()
            nu = nu.cuda()
        Gamma = TopKFunc.apply(C, mu, nu, self.epsilon, self.max_iter)
        A = Gamma[:, :, :self.k]*n
        return A, None
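
A minimal usage sketch (the score values below are just illustrative; defaults are epsilon=0.1, max_iter=200):

    # soft top-2 over a single batch of 4 scores
    scores = torch.tensor([[0.3, 2.0, -1.0, 0.7]])  # shape [bs=1, n=4]
    A, _ = TopK_custom(2)(scores)                   # A: [1, 4, 2] item-vs-slot probabilities
    print(A.sum(dim=1))                             # each of the k columns sums to ~1
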
@thomasahle (Author)

The way I understand it, the method returns a matrix of (item, position) probabilities.
For example, if I take the probabilities p = [0.01, 0.1, 0.04, 0.5, 0.24] and compute the top-3:

    p = torch.tensor([[0.01, 0.1, 0.04, 0.5, 0.24]])
    ps2, _ = TopK_custom(3)(torch.log(p))
    print(f'{ps2=}')
    print(ps2.sum(dim=2))

I get

ps2=tensor([[[0.0307, 0.0824, 0.1877],
         [0.1803, 0.2183, 0.2244],
         [0.0960, 0.1596, 0.2252],
         [0.4005, 0.2780, 0.1638],
         [0.2926, 0.2618, 0.1988]]])
tensor([[0.3008, 0.6230, 0.4807, 0.8423, 0.7532]])

The matrix says that the first item (with initial probability 0.01) has a 3% chance of being in the first position, an 8% chance of being in the second position, and so on.
If you sum over the position axis (dim=2), you get the inclusion probabilities.

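As a sanity check, these inclusion probabilities should sum to approximately k = 3, since each of the k position columns of the output carries total mass 1:

    incl = ps2.sum(dim=2)   # [bs, n] inclusion probabilities
    print(incl.sum(dim=1))  # ~tensor([3.])
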
If you are just interested in differentiable inclusion probabilities, you can use my simple differentiable top-k code here: https://gist.github.com/thomasahle/4c1e85e5842d01b007a8d10f5fed3a18

@JoelNiklaus

Great, thank you!

@Hosein47 commented Mar 9, 2023

Is this function invertible?
