@ajbrock
Created November 13, 2017 16:44
# PyTorch code for implementing the mixture of softmaxes (MoS) layer from
# "Breaking the Softmax Bottleneck: A High-Rank RNN Language Model"
# https://arxiv.org/abs/1711.03953
# This snippet lives inside a model's forward(); it assumes `import torch`,
# `import torch.nn.functional as F`, and that `self.fc` projects to
# n_components * nClasses + n_components units (one block of class logits
# per component, plus the mixture-prior logits).
context = self.fc(out)
# Non-log version: mix the component softmaxes with the prior weights, then take the log.
priors = F.softmax(context[:, -self.n_components:], dim=1)                          # (B, K)
mixtures = torch.stack([priors[:, i].unsqueeze(1)
                        * F.softmax(context[:, i * self.nClasses:(i + 1) * self.nClasses], dim=1)
                        for i in range(self.n_components)], 1)                      # (B, K, V)
out = torch.log(mixtures.sum(1))                                                    # (B, V) log-probabilities
# Log version
# log_priors = F.log_softmax(context[:, -self.n_components:], dim=1).unsqueeze(2)   # (B, K, 1)
# log_mixtures = torch.stack([F.log_softmax(context[:, i * self.nClasses:(i + 1) * self.nClasses], dim=1)
#                             for i in range(self.n_components)], 1)                # (B, K, V)
# out = torch.log(torch.exp(log_priors + log_mixtures).sum(1))
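
For reference, here is a minimal self-contained module along the same lines. The class name MixtureOfSoftmaxes and its constructor arguments are illustrative assumptions, not names from the gist or from the paper's released code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MixtureOfSoftmaxes(nn.Module):
    """Mixture-of-softmaxes output layer (illustrative sketch)."""
    def __init__(self, in_features, n_classes, n_components):
        super(MixtureOfSoftmaxes, self).__init__()
        self.n_classes = n_classes
        self.n_components = n_components
        # One block of class logits per mixture component, plus K prior logits.
        self.fc = nn.Linear(in_features, n_components * n_classes + n_components)

    def forward(self, h):
        context = self.fc(h)                                                   # (B, K*V + K)
        priors = F.softmax(context[:, -self.n_components:], dim=1)             # (B, K)
        components = torch.stack(
            [F.softmax(context[:, i * self.n_classes:(i + 1) * self.n_classes], dim=1)
             for i in range(self.n_components)], dim=1)                        # (B, K, V)
        # Convex combination of the K softmaxes, returned as log-probabilities.
        return torch.log((priors.unsqueeze(2) * components).sum(1))

# Example usage (shapes only):
# mos = MixtureOfSoftmaxes(in_features=650, n_classes=10000, n_components=15)
# log_probs = mos(torch.randn(32, 650))   # -> (32, 10000); rows sum to 1 after exp()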

Smerity commented Nov 14, 2017

For the log version, you can use the logsumexp trick to avoid numerical stability issues =]

https://en.wikipedia.org/wiki/LogSumExp
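
A sketch of that suggestion, assuming a PyTorch version that provides torch.logsumexp (variable names mirror the commented-out log version above):

# Numerically stable log version: logsumexp over the component dimension
# replaces the explicit log(exp(...).sum(1)).
log_priors = F.log_softmax(context[:, -self.n_components:], dim=1).unsqueeze(2)     # (B, K, 1)
log_mixtures = torch.stack(
    [F.log_softmax(context[:, i * self.nClasses:(i + 1) * self.nClasses], dim=1)
     for i in range(self.n_components)], 1)                                          # (B, K, V)
out = torch.logsumexp(log_priors + log_mixtures, dim=1)                              # (B, V)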
