# PyTorch code for implementing the mixture of softmaxes layer from
# "Breaking the Softmax Bottleneck: A High-Rank RNN Language Model"
# https://arxiv.org/abs/1711.03953
context = self.fc(out)
# Non-log version: the last n_components logits give the mixture weights
# (priors); each component applies its own softmax over its nClasses logits.
priors = F.softmax(context[:, -self.n_components:], dim=1)
mixtures = torch.stack([priors[:, i].unsqueeze(1)
                        * F.softmax(context[:, i * self.nClasses:(i + 1) * self.nClasses], dim=1)
                        for i in range(self.n_components)], 1)
out = torch.log(mixtures.sum(1))
# Log version (numerically stabler): add the log prior to each component's
# log-softmax, then combine with logsumexp instead of exponentiating by hand.
# log_priors = F.log_softmax(context[:, -self.n_components:], dim=1)
# log_mixtures = torch.stack([log_priors[:, i].unsqueeze(1)
#                             + F.log_softmax(context[:, i * self.nClasses:(i + 1) * self.nClasses], dim=1)
#                             for i in range(self.n_components)], 1)
# out = torch.logsumexp(log_mixtures, dim=1)
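
# --- Self-contained sketch (not part of the original snippet) ---
# The code above references self.fc, self.n_components, and self.nClasses
# without defining them. Below is a minimal, runnable sketch of how such a
# layer might be wired up; the class name MixtureOfSoftmaxes and the
# constructor arguments are illustrative assumptions, not the paper's code.
# The layer computes P(y|h) = sum_k pi_k(h) * softmax(W_k h + b_k)_y, whose
# log-probability matrix can have higher rank than a single softmax allows.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixtureOfSoftmaxes(nn.Module):
    def __init__(self, hidden_dim, nClasses, n_components):
        super().__init__()
        self.nClasses = nClasses
        self.n_components = n_components
        # One linear layer emits all component logits plus the prior logits.
        self.fc = nn.Linear(hidden_dim, n_components * nClasses + n_components)

    def forward(self, out):
        context = self.fc(out)
        log_priors = F.log_softmax(context[:, -self.n_components:], dim=1)
        log_mixtures = torch.stack(
            [log_priors[:, i].unsqueeze(1)
             + F.log_softmax(context[:, i * self.nClasses:(i + 1) * self.nClasses], dim=1)
             for i in range(self.n_components)], dim=1)
        # log sum_k exp(log pi_k + log p_k(y)): stable mixture in log space.
        return torch.logsumexp(log_mixtures, dim=1)

# Usage sketch: log-probabilities over a 10-way vocabulary.
mos = MixtureOfSoftmaxes(hidden_dim=32, nClasses=10, n_components=4)
log_probs = mos(torch.randn(8, 32))  # shape (8, 10)
assert torch.allclose(log_probs.exp().sum(1), torch.ones(8), atol=1e-5)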