Skip to content

Instantly share code, notes, and snippets.

@e96031413
Created January 20, 2022 11:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save e96031413/a2beefed5e48aad5839756480d68239f to your computer and use it in GitHub Desktop.
Save e96031413/a2beefed5e48aad5839756480d68239f to your computer and use it in GitHub Desktop.
# LabelSmoothing.py
# Source: https://www.aiuai.cn/aifarm1333.html (Example 3)
# From: GitHub - NVIDIA/DeepLearningExamples/PyTorch/Classification
# smoothing.py
import torch
import torch.nn as nn
# Standard (general-purpose) LabelSmoothing implementation
class LabelSmoothing(nn.Module):
    """Negative log-likelihood loss with label smoothing.

    For every position the loss mixes the NLL of the target class with the
    mean negative log-probability over all classes::

        loss = (1 - smoothing) * -log p[target] + smoothing * mean_c(-log p[c])

    where ``p = softmax(x)`` over the last dimension. The per-position losses
    are averaged into a single scalar.
    """

    def __init__(self, smoothing=0.0):
        """Create the loss module.

        :param smoothing: label smoothing factor. With 0.0 the loss reduces to
            mean NLL of ``log_softmax(x)``, i.e. standard cross-entropy.
        """
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        """Compute the label-smoothed loss.

        :param x: unnormalised scores with the class dimension last, e.g.
            shape ``(N, C)`` or ``(N, ..., C)``.
        :param target: int64 class indices whose shape is ``x.shape[:-1]``,
            e.g. ``(N,)`` or ``(N, ...)``.
        :return: scalar tensor, the loss averaged over all positions.
        """
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
        # Gather the target log-probability along the class (last) dimension.
        # Using -1 (not 1) for unsqueeze/squeeze keeps the index aligned with
        # the gather dimension for inputs with more than two dimensions; for
        # 2-D input the result is identical.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        # Cross-entropy against a uniform distribution over the classes.
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
# main.py
import torch
import torch.nn as nn
from smoothing import LabelSmoothing
def add_parser_arguments(parser):
    """Register the label-smoothing command-line option on *parser*.

    Adds ``--label-smoothing`` (read back as ``args.label_smoothing``), a
    float option that defaults to 0.0.
    """
    smoothing_option = {
        "default": 0.0,
        "type": float,
        "metavar": 'S',
        "help": 'label smoothing',
    }
    parser.add_argument('--label-smoothing', **smoothing_option)
def main(args):
    """Select the training criterion and apply it in a training step (excerpt).

    Uses ``nn.CrossEntropyLoss`` by default, or ``LabelSmoothing`` when
    ``args.label_smoothing`` is greater than 0.0.
    """
    # ``loss`` is a zero-argument factory in both branches, so the criterion
    # is constructed the same way below.
    loss = nn.CrossEntropyLoss
    if args.label_smoothing > 0.0:
        loss = lambda: LabelSmoothing(args.label_smoothing)
    criterion = loss()
    # NOTE(review): the lines below are an excerpt of a training step.
    # ``output``, ``nl``, ``target`` and ``optimizer`` are not defined in this
    # snippet; presumably they come from the surrounding training loop. As
    # written, calling main() raises NameError here — confirm against the
    # full script.
    # Compute the loss with the selected criterion. Note that ``loss`` is
    # rebound here from the factory to the resulting loss tensor.
    loss = criterion(output[:nl], target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment