@thomasahle
Created March 15, 2021 15:19
Code from "Differentiable Top-k Operator with Optimal Transport"
def sinkhorn_forward(C, mu, nu, epsilon, max_iter):
    """Plain Sinkhorn iterations on the Gibbs kernel exp(-C/epsilon)."""
    bs, n, k_ = C.size()

    v = torch.ones([bs, 1, k_])/(k_)
    G = torch.exp(-C/epsilon)
    if torch.cuda.is_available():
        v = v.cuda()

    for i in range(max_iter):
        u = mu/(G*v).sum(-1, keepdim=True)
        v = nu/(G*u).sum(-2, keepdim=True)

    Gamma = u*G*v  # transport plan, [bs, n, k+1]
    return Gamma

def sinkhorn_forward_stablized(C, mu, nu, epsilon, max_iter):
    """Sinkhorn iterations in the log domain, numerically stable for small epsilon."""
    bs, n, k_ = C.size()
    k = k_-1

    f = torch.zeros([bs, n, 1])
    g = torch.zeros([bs, 1, k+1])
    if torch.cuda.is_available():
        f = f.cuda()
        g = g.cuda()

    epsilon_log_mu = epsilon*torch.log(mu)
    epsilon_log_nu = epsilon*torch.log(nu)

    def min_epsilon_row(Z, epsilon):
        return -epsilon*torch.logsumexp((-Z)/epsilon, -1, keepdim=True)

    def min_epsilon_col(Z, epsilon):
        return -epsilon*torch.logsumexp((-Z)/epsilon, -2, keepdim=True)

    for i in range(max_iter):
        f = min_epsilon_row(C-g, epsilon)+epsilon_log_mu
        g = min_epsilon_col(C-f, epsilon)+epsilon_log_nu

    Gamma = torch.exp((-C+f+g)/epsilon)
    return Gamma
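
# Note (sketch): the log-domain updates above compute the same fixed point as the
# plain Sinkhorn updates u = mu/(G v), v = nu/(G^T u), with u = exp(f/epsilon),
# v = exp(g/epsilon), G = exp(-C/epsilon):
#   f_i <- -epsilon * logsumexp_j( -(C_ij - g_j)/epsilon ) + epsilon*log(mu_i)
#   g_j <- -epsilon * logsumexp_i( -(C_ij - f_i)/epsilon ) + epsilon*log(nu_j)
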
def sinkhorn_backward(grad_output_Gamma, Gamma, mu, nu, epsilon):
    nu_ = nu[:, :, :-1]
    Gamma_ = Gamma[:, :, :-1]
    bs, n, k_ = Gamma.size()

    inv_mu = 1./(mu.view([1, -1]))  # [1, n]
    Kappa = torch.diag_embed(nu_.squeeze(-2)) \
            - torch.matmul(Gamma_.transpose(-1, -2) * inv_mu.unsqueeze(-2), Gamma_)  # [bs, k, k]
    inv_Kappa = torch.inverse(Kappa)  # [bs, k, k]
    Gamma_mu = inv_mu.unsqueeze(-1)*Gamma_
    L = Gamma_mu.matmul(inv_Kappa)  # [bs, n, k]
    G1 = grad_output_Gamma * Gamma  # [bs, n, k+1]

    g1 = G1.sum(-1)
    G21 = (g1*inv_mu).unsqueeze(-1)*Gamma  # [bs, n, k+1]
    g1_L = g1.unsqueeze(-2).matmul(L)  # [bs, 1, k]
    G22 = g1_L.matmul(Gamma_mu.transpose(-1, -2)).transpose(-1, -2)*Gamma  # [bs, n, k+1]
    G23 = - F.pad(g1_L, pad=(0, 1), mode='constant', value=0)*Gamma  # [bs, n, k+1]
    G2 = G21 + G22 + G23  # [bs, n, k+1]
    del g1, G21, G22, G23, Gamma_mu

    g2 = G1.sum(-2).unsqueeze(-1)  # [bs, k+1, 1]
    g2 = g2[:, :-1, :]  # [bs, k, 1]
    G31 = - L.matmul(g2)*Gamma  # [bs, n, k+1]
    G32 = F.pad(inv_Kappa.matmul(g2).transpose(-1, -2), pad=(0, 1), mode='constant', value=0)*Gamma  # [bs, n, k+1]
    G3 = G31 + G32  # [bs, n, k+1]

    grad_C = (-G1+G2+G3)/epsilon  # [bs, n, k+1]
    return grad_C

class TopKFunc(Function):
    @staticmethod
    def forward(ctx, C, mu, nu, epsilon, max_iter):
        with torch.no_grad():
            if epsilon > 1e-2:
                Gamma = sinkhorn_forward(C, mu, nu, epsilon, max_iter)
                if bool(torch.any(Gamma != Gamma)):
                    print('Nan appeared in Gamma, re-computing...')
                    Gamma = sinkhorn_forward_stablized(C, mu, nu, epsilon, max_iter)
            else:
                Gamma = sinkhorn_forward_stablized(C, mu, nu, epsilon, max_iter)
            ctx.save_for_backward(mu, nu, Gamma)
            ctx.epsilon = epsilon
        return Gamma

    @staticmethod
    def backward(ctx, grad_output_Gamma):
        epsilon = ctx.epsilon
        mu, nu, Gamma = ctx.saved_tensors
        # mu    [1, n, 1]
        # nu    [1, 1, k+1]
        # Gamma [bs, n, k+1]
        with torch.no_grad():
            grad_C = sinkhorn_backward(grad_output_Gamma, Gamma, mu, nu, epsilon)
        return grad_C, None, None, None, None

class TopK_custom(torch.nn.Module):
    def __init__(self, k, epsilon=0.1, max_iter=200):
        super(TopK_custom, self).__init__()
        self.k = k
        self.epsilon = epsilon
        self.anchors = torch.FloatTensor([k-i for i in range(k+1)]).view([1, 1, k+1])
        self.max_iter = max_iter
        if torch.cuda.is_available():
            self.anchors = self.anchors.cuda()

    def forward(self, scores):
        bs, n = scores.size()
        scores = scores.view([bs, n, 1])

        # replace any -inf scores with a finite value below the current minimum
        scores_ = scores.clone().detach()
        max_scores = torch.max(scores_).detach()
        scores_[scores_ == float('-inf')] = float('inf')
        min_scores = torch.min(scores_).detach()
        filled_value = min_scores - (max_scores - min_scores)
        mask = scores == float('-inf')
        scores = scores.masked_fill(mask, filled_value)

        C = (scores - self.anchors)**2
        C = C / (C.max().detach())
        mu = torch.ones([1, n, 1], requires_grad=False)/n
        nu = [1./n for _ in range(self.k)]
        nu.append((n-self.k)/n)
        nu = torch.FloatTensor(nu).view([1, 1, self.k+1])
        if torch.cuda.is_available():
            mu = mu.cuda()
            nu = nu.cuda()

        Gamma = TopKFunc.apply(C, mu, nu, self.epsilon, self.max_iter)
        A = Gamma[:, :, :self.k]*n
        return A, None
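
A minimal usage sketch, assuming the imports below (they are not part of the gist itself; in particular, Function appears to be torch.autograd.Function, as discussed in the comments):

    # Imports the code above appears to assume (they would go at the top of the file)
    import torch
    import torch.nn.functional as F
    from torch.autograd import Function

    # Hypothetical example: soft top-3 over a single batch of 5 scores
    scores = torch.tensor([[0.1, 2.0, -1.0, 3.0, 0.5]])   # shape [bs=1, n=5]
    A, _ = TopK_custom(k=3, epsilon=0.1, max_iter=200)(scores)
    print(A.shape)     # [bs, n, k]: soft assignment of each item to each top-k position
    print(A.sum(-1))   # summing over positions gives approximate inclusion probabilities
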
@JoelNiklaus

Hi,
Thank you very much for providing this code.
Unfortunately, I don't see any imports, like in the paper (https://arxiv.org/pdf/2002.06504.pdf). Do you know how to import "Function" on line 77?

@thomasahle (Author)

Maybe torch.autograd.Function?

@JoelNiklaus

Yes, that seemed to work. Thank you very much.

@JoelNiklaus

I am actually interested in the topk indices and not the topk values. However, this function returns None for the indices on line 146. Do you know by any chance if it is possible to get the indices in a differentiable manner as well?

@thomasahle (Author)

The way I understand it, the method returns a matrix of (item, position) probabilities.
For example, if I take the probabilities p = [0.01, 0.1, 0.04, 0.5, 0.24] and compute the top-3:

    p = torch.tensor([[0.01, 0.1, 0.04, 0.5, 0.24]])
    ps2, _ = TopK_custom(3)(torch.log(p))
    print(f'{ps2=}')
    print(ps2.sum(dim=2))

I get

ps2=tensor([[[0.0307, 0.0824, 0.1877],
         [0.1803, 0.2183, 0.2244],
         [0.0960, 0.1596, 0.2252],
         [0.4005, 0.2780, 0.1638],
         [0.2926, 0.2618, 0.1988]]])
tensor([[0.3008, 0.6230, 0.4807, 0.8423, 0.7532]])

The matrix says that the first item (with initial probability 0.01) has a 3% chance of being ranked first, an 8% chance of being ranked second, and so on.
If you sum along the position axis (dim=2, as above) you get the inclusion probabilities.
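
If you just need the hard indices for inspection (that selection step itself is not differentiable), a small sketch based on the ps2 above:

    incl = ps2.sum(dim=2)                     # inclusion probabilities, shape [bs, n]
    hard_idx = incl.topk(3, dim=1).indices    # ordinary (non-differentiable) top-k indices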

If you are just interested in differentiable inclusion probabilities, you can use my simple differentiable top-k code here: https://gist.github.com/thomasahle/4c1e85e5842d01b007a8d10f5fed3a18

@JoelNiklaus

Great, thank you!

@Hosein47 commented Mar 9, 2023

Is this function invertible?
