Skip to content

Instantly share code, notes, and snippets.

@hardik2396
Last active August 13, 2018 08:00
Show Gist options
  • Save hardik2396/757ebb875edc7d0ef9866d00f51a8e50 to your computer and use it in GitHub Desktop.
Save hardik2396/757ebb875edc7d0ef9866d00f51a8e50 to your computer and use it in GitHub Desktop.
class LabelSmoothing(nn.Module):
    """Label-smoothed KL-divergence loss.

    Builds a smoothed target distribution for each row of the input and
    returns the summed KL divergence between that distribution and the
    input.  The target class receives ``1 - smoothing``; the remaining
    ``smoothing`` mass is spread evenly over every class except the target
    and the padding class.  Rows whose target is the padding index
    contribute nothing to the loss.

    Args:
        size: Number of classes; must equal ``x.size(1)`` in ``forward``.
        padding_idx: Class index that marks padding.  Its column is always
            given zero probability, and rows whose target equals it are
            zeroed out entirely.
        smoothing: Probability mass taken from the target class and spread
            over the other non-padding classes.

    Attributes:
        true_dist: The smoothed target distribution built by the most
            recent ``forward`` call (``None`` before the first call).
    """

    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        # ``reduction="sum"`` is the current spelling of the deprecated
        # ``size_average=False``: the loss is summed, not averaged.
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        """Return the summed KL divergence between ``x`` and smoothed targets.

        Args:
            x: Tensor of shape ``(N, size)``.  ``nn.KLDivLoss`` interprets
                its input as log-probabilities, so ``x`` should be the
                output of a log-softmax (or equivalent).
            target: Integer tensor of shape ``(N,)`` holding class indices.

        Returns:
            A scalar tensor with the loss summed over all elements.  The
            smoothed distribution is also stored on ``self.true_dist``.
        """
        assert x.size(1) == self.size

        # Two classes never share the smoothing mass: the target itself and
        # the padding class.  With no other classes left (size <= 2) every
        # cell is overwritten below, so any fill value works; use 0.0
        # instead of dividing by zero.
        other_classes = self.size - 2
        if other_classes > 0:
            fill_value = self.smoothing / other_classes
        else:
            fill_value = 0.0

        # A fresh tensor with x's shape, dtype and device that is not part
        # of the autograd graph, so no gradient flows into the targets.
        true_dist = torch.full_like(x, fill_value)
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0

        # Zero out whole rows whose target is padding.  A boolean mask works
        # for zero, one, or many padded rows without any shape special-casing.
        padded_rows = (target == self.padding_idx).unsqueeze(1)
        true_dist.masked_fill_(padded_rows, 0.0)

        self.true_dist = true_dist
        return self.criterion(x, true_dist)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment