Skip to content

Instantly share code, notes, and snippets.

@rosinality
Last active July 21, 2021 03:42
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save rosinality/0cdd8d6adb8463961f50bd1845faddf8 to your computer and use it in GitHub Desktop.
Save rosinality/0cdd8d6adb8463961f50bd1845faddf8 to your computer and use it in GitHub Desktop.
Adaptive Softmax implementation for PyTorch
import torch
from torch import nn
from torch.autograd import Variable
class AdaptiveSoftmax(nn.Module):
    """Adaptive softmax output layer.

    The class range ``[0, cutoff[-1])`` is split by ``cutoff`` into a head
    shortlist ``[0, cutoff[0])`` and tail clusters
    ``[cutoff[i], cutoff[i + 1])``. The head layer scores every shortlist
    class plus one extra logit per tail cluster. Each tail cluster has its own
    two-layer, bias-free projection.

    Args:
        input_size: Feature size of the rows passed to ``forward`` and
            ``log_prob``.
        cutoff: Increasing class boundaries; ``cutoff[-1]`` is the total
            number of classes.
    """

    def __init__(self, input_size, cutoff):
        super().__init__()
        self.input_size = input_size
        self.cutoff = cutoff
        # Shortlist classes plus one "cluster" logit per tail cluster.
        self.output_size = cutoff[0] + len(cutoff) - 1
        self.head = nn.Linear(input_size, self.output_size)
        self.tail = nn.ModuleList()
        for i in range(len(cutoff) - 1):
            # The hidden size shrinks by a factor of 4 for each later cluster.
            hidden = input_size // 4 ** i
            seq = nn.Sequential(
                nn.Linear(input_size, hidden, False),
                nn.Linear(hidden, cutoff[i + 1] - cutoff[i], False),
            )
            self.tail.append(seq)

    def reset(self, init=0.1):
        """Re-initialise the head weight and all tail weights
        uniformly in ``[-init, init]``. The head bias is left unchanged."""
        self.head.weight.data.uniform_(-init, init)
        for tail in self.tail:
            tail[0].weight.data.uniform_(-init, init)
            tail[1].weight.data.uniform_(-init, init)

    def set_target(self, target):
        """Record which batch rows belong to each tail cluster.

        Must be called before ``forward`` for the same batch. Afterwards,
        ``self.id[i]`` is a 1-D tensor of row indices whose target lies in
        ``[cutoff[i], cutoff[i + 1])``, or ``None`` when no row does.
        """
        self.id = []
        for i in range(len(self.cutoff) - 1):
            mask = (target >= self.cutoff[i]) & (target < self.cutoff[i + 1])
            if mask.any():
                self.id.append(mask.nonzero().squeeze(1))
            else:
                self.id.append(None)

    def forward(self, input):
        """Return ``[head_logits, tail_logits_0, tail_logits_1, ...]``.

        ``head_logits`` covers every row of ``input``. Tail ``i`` is evaluated
        only on the rows recorded by ``set_target``; it is ``None`` if there
        were none. The rows of ``input`` are assumed to line up with the
        entries of the target given to ``set_target``.
        """
        output = [self.head(input)]
        for tail, rows in zip(self.tail, self.id):
            if rows is None:
                output.append(None)
            else:
                output.append(tail(input.index_select(0, rows)))
        return output

    def log_prob(self, input):
        """Return full log-probabilities of shape ``(batch, cutoff[-1])``.

        Shortlist columns hold the head log-softmax. The columns of tail
        cluster ``i`` hold ``log p(cluster i) + log p(class | cluster i)``.
        The result is not tracked by autograd. It is allocated on the same
        device, with the same dtype, as the head output, so CPU and GPU
        inputs both work.
        """
        with torch.no_grad():
            head_out = self.head(input)
            batch_size = head_out.size(0)
            lsm_head = nn.functional.log_softmax(head_out, dim=1)
            prob = head_out.new_zeros((batch_size, self.cutoff[-1]))
            shortlist = self.cutoff[0]
            prob[:, :shortlist] = lsm_head[:, :shortlist]
            for i, tail in enumerate(self.tail):
                start, end = self.cutoff[i], self.cutoff[i + 1]
                # The cluster's head log-prob broadcasts over its classes.
                cluster_lp = lsm_head[:, shortlist + i].unsqueeze(1)
                lsm_tail = nn.functional.log_softmax(tail(input), dim=1)
                prob[:, start:end] = cluster_lp + lsm_tail
        return prob
class AdaptiveLoss(nn.Module):
    """Cross-entropy loss for the outputs of ``AdaptiveSoftmax.forward``.

    Targets are remapped so that head rows whose class falls in tail cluster
    ``i`` point at the head's cluster logit ``cutoff[0] + i``. Rows in a tail
    cluster are also scored by that cluster's logits, using cluster-local
    indices. Per-output losses are summed and then divided by the batch size.

    Args:
        cutoff: The same class boundaries given to ``AdaptiveSoftmax``.
    """

    def __init__(self, cutoff):
        super().__init__()
        self.cutoff = cutoff
        # One summed criterion per output: the head plus len(cutoff) - 1 tails.
        self.criterions = nn.ModuleList(
            [nn.CrossEntropyLoss(reduction='sum') for _ in self.cutoff]
        )

    def remap_target(self, target):
        """Split ``target`` into per-output targets.

        Returns ``[head_target, tail_target_0, ...]``. ``head_target`` is a
        copy of ``target`` in which classes of tail cluster ``i`` are replaced
        by ``cutoff[0] + i``. ``tail_target_i`` holds the offsets
        ``class - cutoff[i]`` of the rows in that cluster, or ``None`` if
        there are none. ``target`` itself is not modified.
        """
        head_target = target.clone()
        new_target = [head_target]
        for i in range(len(self.cutoff) - 1):
            lo, hi = self.cutoff[i], self.cutoff[i + 1]
            mask = (target >= lo) & (target < hi)
            head_target[mask] = self.cutoff[0] + i
            new_target.append(target[mask] - lo if mask.any() else None)
        return new_target

    def forward(self, input, target):
        """Return the summed cross-entropy divided by the batch size.

        Args:
            input: List ``[head_logits, tail_logits_0, ...]`` as returned by
                ``AdaptiveSoftmax.forward``; ``None`` entries are skipped.
            target: Class indices in ``[0, cutoff[-1])``, one per head row.

        Raises:
            ValueError: If a non-None output has no matching target, or if a
                remapped target lies outside ``[0, n_classes)`` for its output.
        """
        batch_size = input[0].size(0)
        target = self.remap_target(target.data)
        total = 0.0
        for i, logits in enumerate(input):
            if logits is None:
                continue
            tgt = target[i]
            if tgt is None:
                raise ValueError(
                    f"output {i} is not None but no target falls in it"
                )
            # Class indices must be strictly below the number of logits.
            if tgt.min() < 0 or tgt.max() >= logits.size(1):
                raise ValueError(
                    f"target for output {i} out of range [0, {logits.size(1)})"
                )
            total = total + self.criterions[i](logits, tgt)
        return total / batch_size
@jerrybai1995
Copy link

Line 48: Shouldn't it be output.append(self.tail[i](input.index_select(1, self.id[i]))) instead of dim 0? I'm assuming input has dimension N x input_size.

@temporaer
Copy link

Line 108: Should be if input[i] is not None:

@rosinality
Copy link
Author

Sorry for the late reply... I found that comments on gists don't trigger any notifications.

jerrybai1995: dim 0 is intended. index_select(0, id) picks the rows of input (the N batch elements) whose indices are given in the id tensor.
temporaer: Yes, I corrected it.

@songyuzhou324
Copy link

I got the following error:

File "text8.py", line 121, in <module>
train()
File "text8.py", line 78, in train
loss = criterion(output, Y_var)
File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
result = self.forward(*input, **kwargs)
File "/AdaptiveSoftmaxPyTorch/adasoft.py", line 114, in forward
output += self.criterion(input[i], Variable(target[i]))
File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
result = self.forward(*input, **kwargs)
File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 71, in forward
raise NotImplementedError
NotImplementedError

Any idea?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment