My custom loss for OpenNMT-py
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

import onmt.io
from onmt.Loss import LossComputeBase

# NOTE: `bleu` is assumed to be a differentiable BLEU implementation
# (not included in this gist) returning a (loss, precisions) pair.


class MyLoss(LossComputeBase):
    """
    Custom NMT loss computation.

    Single-step shards fall back to the standard NLL (or label-smoothed
    KL) token-level loss; longer shards are scored with a differentiable
    BLEU-based sequence-level loss.
    """
    def __init__(self, generator, tgt_vocab, normalization="sents",
                 label_smoothing=0.0):
        super(MyLoss, self).__init__(generator, tgt_vocab)
        assert 0.0 <= label_smoothing <= 1.0
        if label_smoothing > 0:
            # When label smoothing is turned on, the KL-divergence between
            # q_{smoothed ground-truth prob.}(w) and
            # p_{prob. computed by model}(w) is minimized.
            # If the label-smoothing value is set to zero, the loss is
            # equivalent to NLLLoss or CrossEntropyLoss.
            # All non-true labels are uniformly set to low confidence.
            self.criterion = nn.KLDivLoss(size_average=False)
            one_hot = torch.randn(1, len(tgt_vocab))
            one_hot.fill_(label_smoothing / (len(tgt_vocab) - 2))
            one_hot[0][self.padding_idx] = 0
            self.register_buffer('one_hot', one_hot)
        else:
            weight = torch.ones(len(tgt_vocab))
            weight[self.padding_idx] = 0
            self.criterion = nn.NLLLoss(weight, size_average=False)
        self.confidence = 1.0 - label_smoothing
        # Index of the end-of-sentence token, needed by _compute_loss to
        # measure sequence lengths (assumed; the base class only sets
        # padding_idx).
        self.eos_idx = tgt_vocab.stoi[onmt.io.EOS_WORD]
        self.print_cnt = 0

    def _make_shard_state(self, batch, output, range_, attns=None):
        return {
            "output": output,
            "target": batch.tgt[range_[0] + 1: range_[1]],
        }

    def _compute_loss(self, batch, output, target):
        # Single-step shard: use the standard (label-smoothed) token loss.
        if target.data.shape[0] == 1:
            scores = self.generator(self._bottle(output))
            gtruth = target.view(-1)
            if self.confidence < 1:
                tdata = gtruth.data
                mask = torch.nonzero(tdata.eq(self.padding_idx)).squeeze()
                likelihood = torch.gather(scores.data, 1, tdata.unsqueeze(1))
                tmp_ = self.one_hot.repeat(gtruth.size(0), 1)
                tmp_.scatter_(1, tdata.unsqueeze(1), self.confidence)
                if mask.dim() > 0:
                    likelihood.index_fill_(0, mask, 0)
                    tmp_.index_fill_(0, mask, 0)
                gtruth = Variable(tmp_, requires_grad=False)
            loss = self.criterion(scores, gtruth)
            if self.confidence < 1:
                loss_data = -likelihood.sum(0)
            else:
                loss_data = loss.data.clone()
            stats = self._stats(loss_data, scores.data, target.view(-1).data)
            return loss, stats

        # Multi-step shard: BLEU-based sequence-level loss.
        # Reshape from (len, batch, dim) to (batch, len, dim).
        scores = self.generator(output.permute(1, 0, 2).contiguous())
        target_new = target.permute(1, 0).contiguous()
        sm = nn.Softmax(dim=2)
        probs = sm(scores)

        def _get_len(a, padding_idx, eos_idx):
            # Position of the first padding or EOS token in each row;
            # the appended column of ones guarantees a match exists.
            coincidence = (a == padding_idx) | (a == eos_idx)
            ones = np.ones((a.shape[0], 1))
            coincidence = np.concatenate((coincidence, ones), 1)
            return np.argmax(coincidence, 1).tolist()  # first occurrence

        target_np = target_new.data.cpu().numpy()
        length = _get_len(target_np, self.padding_idx, self.eos_idx)
        length = np.clip(length, a_min=0, a_max=target_np.shape[1] - 1)
        translation_lengths = _get_len(
            torch.max(probs, 2)[1].data.cpu().numpy(),
            self.padding_idx, self.eos_idx)
        # Pin every hypothesis length to the maximum shard length
        # (a_min == a_max, so the clip is a constant assignment).
        translation_lengths = np.clip(
            translation_lengths,
            a_min=probs.shape[1] - 1, a_max=probs.shape[1] - 1)
        translation_lengths = torch.LongTensor(translation_lengths)
        bleu_loss, precisions = bleu(
            probs[:, :-1, :], target_new.data[:, :-1].tolist(),
            translation_lengths, length, max_order=4, smooth=True)
        if self.print_cnt % 100 == 0:
            print('precisions', precisions)
            print('bleu_loss', bleu_loss)
        self.print_cnt += 1
        loss_data = bleu_loss.data
        stats = self._stats(
            loss_data, self.generator(self._bottle(output)).data,
            target.view(-1).data)
        return bleu_loss, stats
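
As a minimal standalone sketch of the label-smoothing path above (a toy vocabulary of 5 tokens with padding index 0 is assumed; reduction='sum' is the modern spelling of the deprecated size_average=False):

import torch
import torch.nn as nn

vocab_size, pad_idx, smoothing = 5, 0, 0.1
confidence = 1.0 - smoothing

# Template row: smoothing mass spread over all tokens except the true
# label and padding (hence the "- 2" in the denominator in __init__).
one_hot = torch.full((1, vocab_size), smoothing / (vocab_size - 2))
one_hot[0][pad_idx] = 0

targets = torch.tensor([2, 3, 0])  # last position is padding
log_probs = torch.log_softmax(torch.randn(3, vocab_size), dim=1)

smoothed = one_hot.repeat(targets.size(0), 1)
smoothed.scatter_(1, targets.unsqueeze(1), confidence)
smoothed[targets == pad_idx] = 0   # zero out padded positions entirely

# KLDivLoss expects log-probabilities as input and a probability
# distribution as target, matching the class above.
loss = nn.KLDivLoss(reduction='sum')(log_probs, smoothed)
print(loss.item())

Each non-padding row carries `confidence` on the true label and `smoothing / (vocab_size - 2)` everywhere else, so minimizing this KL divergence reproduces label-smoothed cross-entropy.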
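
A hypothetical wiring sketch, assuming the OpenNMT-py 0.x API this gist targets, where MyLoss is a drop-in replacement for onmt.Loss.NMTLossCompute; `model`, `fields`, `opt`, and `use_gpu` are placeholders from a normal training setup:

# Hypothetical usage sketch, not part of the original gist.
tgt_vocab = fields["tgt"].vocab
train_loss = MyLoss(model.generator, tgt_vocab,
                    normalization="sents",
                    label_smoothing=opt.label_smoothing)
if use_gpu:
    train_loss = train_loss.cuda()

From here the trainer drives the shared LossComputeBase entry points, which in turn dispatch to the `_make_shard_state` and `_compute_loss` overrides defined above.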