PyTorch workaround for masking cross entropy loss
import torch
import torch.nn.functional as functional


def _sequence_mask(sequence_length, max_len=None):
    # Build a (batch, max_len) boolean mask that is True at valid (non-padded) steps.
    if max_len is None:
        max_len = int(sequence_length.max())
    batch_size = sequence_length.size(0)
    # torch.arange replaces the deprecated torch.range; it is created directly
    # on the device of sequence_length, so no explicit .cuda() call is needed.
    seq_range = torch.arange(max_len, device=sequence_length.device).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand))
    return seq_range_expand < seq_length_expand


def compute_loss(logits, target, length):
    """
    Args:
        logits: A FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized score (logit) for each class.
        target: A LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A LongTensor of size (batch,)
            which contains the length of each sequence in the batch.
    Returns:
        loss: An average loss value masked by the length.
    """
    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = functional.log_softmax(logits_flat, dim=1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len)
    mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
    # Zero out the padded steps, then average over the valid steps only.
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss
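
To make the mask concrete, here is a minimal sketch of what it looks like for a tiny batch. The helper name `sequence_mask` is hypothetical, and it uses broadcasting rather than the explicit `expand` calls above; it is meant as an equivalent illustration, not as the gist's own code.

```python
import torch


def sequence_mask(lengths, max_len=None):
    # Hypothetical helper: a broadcasting variant of the length mask.
    if max_len is None:
        max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)  # (max_len,)
    # (1, max_len) < (batch, 1) broadcasts to (batch, max_len)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)


lengths = torch.tensor([3, 1, 2])
mask = sequence_mask(lengths, max_len=4)
print(mask)
# Row i has lengths[i] leading True values; the remaining steps are padding.
```

The number of True entries equals `lengths.sum()`, which is why the loss is divided by that sum.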

jaypatravali commented May 2, 2017

Does the Variable length have size B x H x W for an image in a batch (B = batch size, H = height, W = width)?

Thanks,

vijendra-rana commented Jun 15, 2017

Great approach, thanks 👍

viig99 commented Jan 18, 2018

Hi, with the introduction of the reduce=False argument, what changes would be needed to simplify the masked cross entropy?

jihunchoi (Author) commented Feb 13, 2018

@viig99
Hi, it seems that functional.cross_entropy still doesn't support >2D input.
I think the log_softmax + gather calls can be merged into a single cross_entropy call with reduce=False, and I expect there might be some performance gain.
I will update this gist soon.
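
A minimal sketch of that merge, under some assumptions: it uses `reduction="none"`, which is the current spelling of the older `reduce=False` argument, and the function name `masked_cross_entropy` is hypothetical. The input is still flattened to 2D before the call, as in the gist, so it does not rely on any >2D support.

```python
import torch
import torch.nn.functional as F


def masked_cross_entropy(logits, target, length):
    # logits: (batch, max_len, num_classes)
    # target: (batch, max_len), length: (batch,)
    batch, max_len, num_classes = logits.shape
    # Per-step losses, no averaging: shape (batch * max_len,) -> (batch, max_len)
    losses = F.cross_entropy(
        logits.reshape(-1, num_classes), target.reshape(-1), reduction="none"
    ).view(batch, max_len)
    # True at valid steps, False at padded steps
    positions = torch.arange(max_len, device=length.device)
    mask = positions.unsqueeze(0) < length.unsqueeze(1)
    # Average over valid steps only
    return (losses * mask.float()).sum() / length.float().sum()


torch.manual_seed(0)
logits = torch.randn(2, 4, 5)
target = torch.randint(0, 5, (2, 4))
length = torch.tensor([4, 2])
loss = masked_cross_entropy(logits, target, length)
print(loss)
```

This should give the same value as the log_softmax + gather version, since cross_entropy computes the negative log-probability of the target class at each step.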

emanjavacas commented Apr 4, 2018

I (genuinely) wonder how this is different from using:

weight = torch.ones(vocab_size)
weight[pad_idx] = 0.0
crit = nn.CrossEntropyLoss(weight=weight)
crit(output, targets)

I seem to get the same numbers (assuming that you have padded every sequence with pad_idx up to the maximum sentence length in the batch).

lzfelix commented Jul 24, 2018

Interesting solution, @emanjavacas!

funkindy commented Oct 8, 2018

@emanjavacas, @lzfelix
Yeah, it's a good one, but shouldn't this be even more "correct"?

criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

I hope ignore_index was created for masking and for no other purpose?
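
As a rough check of the `ignore_index` suggestion, the sketch below compares it against length-based masking. It assumes every padded step holds `pad_idx` and that `pad_idx` never occurs as a real label; the values of `pad_idx` and `num_classes` are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
pad_idx, num_classes = 0, 5          # hypothetical values
logits = torch.randn(2, 4, num_classes)
length = torch.tensor([4, 2])

# Real labels are drawn from 1..4, so pad_idx is never a genuine target.
target = torch.randint(1, num_classes, (2, 4))
mask = torch.arange(4).unsqueeze(0) < length.unsqueeze(1)
target = target.masked_fill(~mask, pad_idx)  # pad the tail of each sequence

# Variant 1: let the criterion skip padded targets.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
loss_ignore = criterion(logits.view(-1, num_classes), target.view(-1))

# Variant 2: explicit length mask, averaged over valid steps.
per_step = F.cross_entropy(
    logits.view(-1, num_classes), target.view(-1), reduction="none"
).view(2, 4)
loss_masked = (per_step * mask.float()).sum() / length.float().sum()

print(loss_ignore, loss_masked)
```

Under these assumptions the two values agree, because the default mean reduction with `ignore_index` averages over the non-ignored targets only.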

danielyan86129 commented Sep 9, 2021

> I (genuinely) wonder how this is different from using:
>
> weight = torch.ones(vocab_size)
> weight[pad_idx] = 0.0
> crit = nn.CrossEntropyLoss(weight=weight)
> crit(output, targets)
>
> I seem to get the same numbers (assuming that you have padded every sequence with pad_idx up to the maximum sentence length in the batch).

The masking in the proposed gist is per-sample, i.e. it tells you whether each sample (each time step of each sequence) should contribute to the loss, while your weight here is per-class.
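
A small sketch of where the two approaches diverge. It assumes a hypothetical setup in which class 0 is both a genuine label and the value used to fill padded steps, so a per-class weight of zero cannot tell the two apart.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 3
logits = torch.randn(2, 3, num_classes)
# Class 0 is a real label here AND the filler for padded steps.
target = torch.tensor([[0, 2, 1],
                       [1, 0, 0]])
length = torch.tensor([3, 1])  # the trailing zeros in row 2 are padding

# Per-step mask: keeps every valid step, including the real 0 at target[0, 0].
mask = torch.arange(3).unsqueeze(0) < length.unsqueeze(1)
per_step = F.cross_entropy(
    logits.view(-1, num_classes), target.view(-1), reduction="none"
).view(2, 3)
loss_masked = (per_step * mask.float()).sum() / length.float().sum()

# Per-class weight: drops every step whose target is class 0.
weight = torch.ones(num_classes)
weight[0] = 0.0
loss_weighted = F.cross_entropy(
    logits.view(-1, num_classes), target.view(-1), weight=weight
)

kept_by_mask = int(mask.sum())             # valid steps
kept_by_weight = int((target != 0).sum())  # steps with a nonzero-weight class
print(kept_by_mask, kept_by_weight)
```

Here the mask keeps 4 steps while the weight keeps only 3, so the two losses average over different sets of steps. When padding uses a dedicated index that never appears as a real label, the sets coincide.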
