@scottjmaddox
Created July 14, 2024 02:50
PyTorch implementation of the outlier suppression loss described in the paper "Improving generalization by loss modification" by Michael Tetelman
# PyTorch implementation of the outlier suppression loss described in the paper
# "Improving generalization by loss modification" by Michael Tetelman
# https://openreview.net/forum?id=vHOO1lxggJ
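import torch.nn.functional as F  # softplus, nll_loss and log_softmax live in torch.nn.functional
from torch import nn


def outlier_supression_loss(input, target):
    # Per-sample negative log-likelihood of the target class, passed through softplus, then averaged.
    nll = F.nll_loss(F.log_softmax(input, dim=-1), target, reduction='none')
    return F.softplus(nll).mean()


class OutlierSuppressionLoss(nn.Module):
    def forward(self, input, target):
        return outlier_supression_loss(input, target)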
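
The per-sample quantity computed above can be checked by hand without PyTorch. The following is a minimal standard-library sketch, not part of the original gist: the helper names `log_softmax`, `softplus`, and `per_sample_loss` and the example logits are hypothetical, chosen only for illustration. It relies on the identity that, if p is the softmax probability of the target class, then softplus(-log p) = log(1 + 1/p).

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax over a plain list of floats.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def softplus(x):
    # log(1 + exp(x)), arranged to avoid overflow for large x.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def per_sample_loss(logits, target):
    # softplus applied to the negative log-likelihood of the target class.
    nll = -log_softmax(logits)[target]
    return softplus(nll)


logits = [2.0, 0.5, -1.0]  # hypothetical logits for a single sample
target = 0
p = math.exp(log_softmax(logits)[target])  # softmax probability of the target
loss = per_sample_loss(logits, target)

# softplus(-log p) is the same number as log(1 + 1/p).
assert math.isclose(loss, math.log(1.0 + 1.0 / p))
```

Averaging `per_sample_loss` over a batch would correspond to the `.mean()` call in the PyTorch version, assuming the class dimension is the last one.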