Skip to content

Instantly share code, notes, and snippets.

@lichengunc
Created April 27, 2018 20:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lichengunc/5c9ebdbdb2f4e2877134cd0ee84b4b36 to your computer and use it in GitHub Desktop.
Save lichengunc/5c9ebdbdb2f4e2877134cd0ee84b4b36 to your computer and use it in GitHub Desktop.
LanguageRankingCriterion
"""
LanguageRankingCriterion: takes [logp0, logp1] as input and computes the max-margin ranking loss.
"""
class LanguageRankingCriterion(nn.Module):
    """Max-margin ranking loss between paired sequence log-likelihoods.

    The input batch is the concatenation of two halves, [logp0, logp1], with
    matching labels [label0, label1].  For each pair the loss is

        max(0, margin + F(logp1, label1) - F(logp0, label0))

    where F is the length-normalised log-likelihood of the labels, ignoring
    <PAD> positions (token id 0).  The result is averaged over the N pairs.
    """

    def __init__(self, margin=1.):
        super(LanguageRankingCriterion, self).__init__()
        self.margin = margin

    @staticmethod
    def _masked_mean_loglik(logp, tgt):
        """Return the mean log-prob of ``tgt`` under ``logp`` per row, skipping <PAD>.

        Inputs:
        - logp : (N, L, M) log-probabilities over the vocabulary
        - tgt  : (N, L) token ids, 0 marks <PAD>
        Output:
        - (N, ) average log-likelihood over the non-<PAD> tokens of each row
        """
        # Gather along the vocabulary axis directly; unlike flattening with
        # view(), this does not require the inputs to be contiguous.
        token_logp = logp.gather(2, tgt.unsqueeze(2)).squeeze(2)  # (N, L)
        mask = (tgt > 0).type_as(logp)  # (N, L), 1 for real tokens
        # Clamp the token count so a row made only of <PAD> contributes a
        # log-likelihood of 0 instead of 0/0 = NaN.
        num_tokens = mask.sum(1).clamp(min=1)  # (N, )
        return (token_logp * mask).sum(1) / num_tokens

    def forward(self, logprobs, target):
        """
        Inputs:
        - logprobs : [logp0, logp1], (2N, L, M)
        - target   : [label0, label1], (2N, LL), where LL >= L
        We split logprobs into two pieces, then compute the max-margin loss.
        Output:
        - loss : mean over N of max(0, margin + F(logp1, label1) - F(logp0, label0))
        Raises:
        - ValueError if the batch dimension of logprobs is odd, because the
          two halves could not be paired.
        """
        batch_size = logprobs.size(0)
        if batch_size % 2 != 0:
            raise ValueError(
                "logprobs must stack two equally sized halves along dim 0, "
                "got batch size %d" % batch_size)
        N = batch_size // 2
        L = logprobs.size(1)

        # target may be padded longer than the decoded length; keep first L steps
        target = target[:, :L]  # (2N, L)

        logll0 = self._masked_mean_loglik(logprobs[:N], target[:N])  # (N, )
        logll1 = self._masked_mean_loglik(logprobs[N:], target[N:])  # (N, )

        # max-margin ranking loss: hinge at zero, then average over the pairs
        hinge = torch.clamp(self.margin + logll1 - logll0, min=0)  # (N, )
        loss = hinge.sum() / N
        return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment