Skip to content

Instantly share code, notes, and snippets.

@tejaskhot
Created November 25, 2018 05:46
Show Gist options
  • Save tejaskhot/019f6babe9f358d75d8036fcd907d212 to your computer and use it in GitHub Desktop.
Save tejaskhot/019f6babe9f358d75d8036fcd907d212 to your computer and use it in GitHub Desktop.
PyTorch implementation of focal loss, which addresses the class-imbalance problem when training one-stage dense detectors.
class FocalLoss(nn.Module):
    """Binary focal loss (Lin et al., "Focal Loss for Dense Object Detection").

    Scales each element's binary cross-entropy by
    ``alpha * (1 - p_t) ** gamma``, where ``p_t = exp(-BCE)``. This
    down-weights well-classified examples so training focuses on hard ones.

    Args:
        alpha: Constant multiplier applied to every element's loss. NOTE:
            unlike the paper's class-balanced ``alpha_t``, the same value is
            used for positive and negative targets.
        gamma: Focusing parameter; ``gamma=0`` reduces to alpha-scaled BCE.
        logits: If True, ``inputs`` are raw logits and
            ``binary_cross_entropy_with_logits`` is used. Otherwise ``inputs``
            must already be probabilities in [0, 1].
        reduce: If True, ``forward`` returns the mean over all elements.
            Otherwise it returns the unreduced per-element loss.
    """

    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        """Compute the focal loss of ``inputs`` against ``targets``.

        Args:
            inputs: Predictions. These are logits or probabilities, depending
                on ``self.logits``.
            targets: Target values with the same size as ``inputs``.

        Returns:
            A scalar tensor (the mean) if ``self.reduce`` is True. Otherwise a
            tensor of per-element losses shaped like ``inputs``.
        """
        # reduction="none" replaces the legacy ``reduce=False`` argument,
        # which PyTorch has deprecated and which triggers a UserWarning.
        if self.logits:
            bce_loss = F.binary_cross_entropy_with_logits(
                inputs, targets, reduction="none"
            )
        else:
            bce_loss = F.binary_cross_entropy(inputs, targets, reduction="none")
        # For hard 0/1 targets BCE = -log(p_t), so exp(-BCE) recovers p_t.
        pt = torch.exp(-bce_loss)
        f_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        if self.reduce:
            return torch.mean(f_loss)
        return f_loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment