@f1recracker
Last active September 3, 2023 09:11
PyTorch implementation of focal loss that is drop-in compatible with torch.nn.CrossEntropyLoss
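Background: focal loss (Lin et al., 2017, "Focal Loss for Dense Object Detection") scales the per-sample cross entropy by (1 - p_t)^gamma, where p_t is the predicted probability of the true class, i.e. FL(p_t) = -(1 - p_t)^gamma * log(p_t). Well-classified examples (p_t close to 1) are down-weighted so training concentrates on hard examples; gamma = 0 recovers plain cross entropy, and the optional alpha maps to the per-class weight argument of torch.nn.CrossEntropyLoss.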
# pylint: disable=arguments-differ
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.CrossEntropyLoss):
    ''' Focal loss for classification tasks on imbalanced datasets '''

    def __init__(self, gamma, alpha=None, ignore_index=-100, reduction='none'):
        super().__init__(weight=alpha, ignore_index=ignore_index, reduction='none')
        # self.reduction holds the reduction applied after the focal term in forward().
        self.reduction = reduction
        self.gamma = gamma

    def forward(self, input_, target):
        # Per-sample cross entropy, computed with reduction='none' regardless of the
        # reduction requested by the caller, so the focal term can re-weight each sample.
        cross_entropy = F.cross_entropy(
            input_, target, weight=self.weight,
            ignore_index=self.ignore_index, reduction='none')
        # Temporarily mask out ignore_index to '0' for valid gather indices.
        # These entries do not contribute to the final loss because their
        # cross_entropy term is already zero.
        target = target * (target != self.ignore_index).long()
        # Probability of the true class; squeeze so it matches the (N,)-shaped
        # per-sample cross entropy before multiplying.
        input_prob = torch.gather(F.softmax(input_, 1), 1, target.unsqueeze(1)).squeeze(1)
        loss = torch.pow(1 - input_prob, self.gamma) * cross_entropy
        return (torch.mean(loss) if self.reduction == 'mean'
                else torch.sum(loss) if self.reduction == 'sum'
                else loss)
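A minimal usage sketch, assuming the FocalLoss class above is in scope; the linear model, batch shapes, learning rate, and gamma=2.0 are illustrative placeholders rather than part of the gist. The loss is constructed and called exactly like torch.nn.CrossEntropyLoss:

import torch
import torch.nn as nn

model = nn.Linear(16, 4)  # stand-in for a real classifier
criterion = FocalLoss(gamma=2.0, reduction='mean')  # drop-in for nn.CrossEntropyLoss
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(32, 16)
targets = torch.randint(0, 4, (32,))

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()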
@YoelShoshan
Perhaps there is an issue with this line: CrossEntropyLoss returns a scalar by default (reduction='mean'), so you may miss the chance to modify the individual samples' losses based on the focal term?
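For what it is worth, a quick sanity check, assuming plain (N, C) logits and the FocalLoss class above: with gamma=0 the focal factor (1 - p)^gamma is 1, so the result should match nn.CrossEntropyLoss, and with the default reduction='none' one loss value per sample is returned, which is what the focal term needs in order to re-weight individual samples.

import torch
import torch.nn as nn

logits = torch.randn(8, 5)
targets = torch.randint(0, 5, (8,))

# gamma=0 should reproduce plain cross entropy on the same inputs.
focal = FocalLoss(gamma=0.0, reduction='mean')
ce = nn.CrossEntropyLoss(reduction='mean')
print(torch.allclose(focal(logits, targets), ce(logits, targets)))  # expected: True

# Default reduction='none' keeps one loss entry per sample.
per_sample = FocalLoss(gamma=2.0)(logits, targets)
print(per_sample.shape)  # torch.Size([8])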
