PyTorch implementation of focal loss that is drop-in compatible with torch.nn.CrossEntropyLoss
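For reference, the focal loss of Lin et al. (2017) scales the standard cross-entropy by a modulating factor, FL(p_t) = -(1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class; larger gamma down-weights well-classified examples. The optional alpha below maps onto CrossEntropyLoss's per-class weight argument.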
# pylint: disable=arguments-differ
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.CrossEntropyLoss):
    ''' Focal loss for classification tasks on imbalanced datasets '''

    def __init__(self, gamma, alpha=None, ignore_index=-100, reduction='none'):
        # The parent is constructed with reduction='none' so that forward()
        # receives one cross-entropy value per sample; the reduction requested
        # by the caller is applied manually at the end of forward().
        super().__init__(weight=alpha, ignore_index=ignore_index, reduction='none')
        self.reduction = reduction
        self.gamma = gamma

    def forward(self, input_, target):
        cross_entropy = super().forward(input_, target)
        # Temporarily map the ignore index to 0 so every entry is a valid
        # gather index. These positions do not contribute to the final loss
        # because their per-sample cross-entropy is already zero.
        target = target * (target != self.ignore_index).long()
        # p_t: predicted probability of the true class, shape (N,).
        # The squeeze is needed so the focal factor multiplies elementwise
        # against the (N,)-shaped cross_entropy instead of broadcasting to
        # an (N, N) matrix.
        input_prob = torch.gather(F.softmax(input_, dim=1), 1,
                                  target.unsqueeze(1)).squeeze(1)
        # Scale each sample's cross-entropy by the focal term (1 - p_t)^gamma.
        loss = torch.pow(1 - input_prob, self.gamma) * cross_entropy
        if self.reduction == 'mean':
            return torch.mean(loss)
        if self.reduction == 'sum':
            return torch.sum(loss)
        return loss
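Since FocalLoss subclasses nn.CrossEntropyLoss, it can replace the criterion in an existing training loop without other changes. A minimal sketch (the linear model, batch shapes, and class weights below are placeholders, not part of the gist):

import torch
import torch.nn as nn

model = nn.Linear(10, 3)  # placeholder classifier: 10 features, 3 classes
# Optional per-class weights (alpha), same semantics as CrossEntropyLoss's weight.
criterion = FocalLoss(gamma=2.0, alpha=torch.tensor([1.0, 2.0, 2.0]),
                      reduction='mean')

inputs = torch.randn(8, 10)            # dummy batch of 8 samples
targets = torch.randint(0, 3, (8,))    # dummy integer class labels

loss = criterion(model(inputs), targets)
loss.backward()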
Perhaps there is an issue with the super().forward(input_, target) call: CrossEntropyLoss returns a scalar by default (reduction='mean'), so wouldn't that lose the chance to modulate individual samples' losses with the focal term?
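For what it's worth, the __init__ above passes reduction='none' to the parent class, so super().forward does return per-sample losses here. A quick check along those lines (the tensor values are made up):

import torch

criterion = FocalLoss(gamma=2.0)  # default reduction='none'
logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 2])
print(criterion(logits, targets).shape)  # torch.Size([4]): one loss per sample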