
@pbamotra
Created April 10, 2019 04:08
Pytorch implementation of sigsoftmax - https://arxiv.org/pdf/1805.10829.pdf
import torch

def logsigsoftmax(logits):
    """
    Computes log-sigsoftmax from the paper - https://arxiv.org/pdf/1805.10829.pdf
    """
    # Subtract the row-wise max before exponentiating for numerical stability
    # (log-sum-exp trick); the max_values term cancels in the normalized result.
    max_values = torch.max(logits, 1, keepdim=True)[0]
    exp_logits_sigmoided = torch.exp(logits - max_values) * torch.sigmoid(logits)
    sum_exp_logits_sigmoided = exp_logits_sigmoided.sum(1, keepdim=True)
    log_probs = (logits - max_values + torch.log(torch.sigmoid(logits))
                 - torch.log(sum_exp_logits_sigmoided))
    return log_probs
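As a quick sanity check, the function above can be verified against the paper's plain definition, sigsoftmax(z) = exp(z) * sigmoid(z) / sum(exp(z) * sigmoid(z)): exponentiating the returned log-probabilities should give rows that sum to 1 and match the unstabilized formula. A minimal sketch (the function is repeated here so the snippet runs standalone):

```python
import torch

def logsigsoftmax(logits):
    # Stabilized log-sigsoftmax, same computation as the gist above.
    max_values = torch.max(logits, 1, keepdim=True)[0]
    exp_logits_sigmoided = torch.exp(logits - max_values) * torch.sigmoid(logits)
    sum_exp = exp_logits_sigmoided.sum(1, keepdim=True)
    return logits - max_values + torch.log(torch.sigmoid(logits)) - torch.log(sum_exp)

torch.manual_seed(0)
logits = torch.randn(4, 10)
probs = logsigsoftmax(logits).exp()

# Each row is a valid probability distribution.
assert torch.allclose(probs.sum(1), torch.ones(4), atol=1e-5)

# Matches the unstabilized definition exp(z) * sigmoid(z) / sum(exp(z) * sigmoid(z)).
direct = torch.exp(logits) * torch.sigmoid(logits)
direct = direct / direct.sum(1, keepdim=True)
assert torch.allclose(probs, direct, atol=1e-5)
```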
@quettabit

@pbamotra May I ask why you calculate max_values and what its purpose is? I don't see it in the definition given in the paper.

@pbamotra
Author

@Stonesjtu

I think you can also get the log_probs by:

sigmoid_logits = logits.sigmoid().log()
sigsoftmax_logits = logits + sigmoid_logits
return sigsoftmax_logits.log_softmax(dim=1)
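This identity works because exp(z) * sigmoid(z) = exp(z + log sigmoid(z)), so normalizing the product is just a softmax over z + log(sigmoid(z)). A small check that this shortcut agrees with the explicit computation from the gist (using `F.logsigmoid`, which is numerically safer than `.sigmoid().log()`):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)

# Shortcut: exp(z) * sigmoid(z) = exp(z + log sigmoid(z)),
# so log-sigsoftmax is log_softmax over the shifted logits.
shortcut = F.log_softmax(logits + F.logsigmoid(logits), dim=1)

# Reference: the explicit stabilized computation from the gist.
max_values = logits.max(1, keepdim=True)[0]
num = torch.exp(logits - max_values) * torch.sigmoid(logits)
explicit = (logits - max_values + torch.log(torch.sigmoid(logits))
            - torch.log(num.sum(1, keepdim=True)))

assert torch.allclose(shortcut, explicit, atol=1e-5)
```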
