import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Focal Loss for multi-label classification.

    Implementation adapted from https://amaarora.github.io/2020/06/29/FocalLoss.html
    Paper: https://arxiv.org/abs/1708.02002
    """

    def __init__(self, alpha=.25, gamma=2, reduction='none', pos_weight=None):
        super(FocalLoss, self).__init__()
        # Register alpha as a buffer so it follows the module across devices
        # (e.g. `.cuda()` / `.to(device)`). Index 0 weights targets == 0,
        # index 1 weights targets == 1.
        self.register_buffer('alpha', torch.tensor([alpha, 1 - alpha]))
        self.gamma = gamma
        self.reduction = reduction
        self.pos_weight = pos_weight

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets,
                                                      reduction=self.reduction,
                                                      pos_weight=self.pos_weight)
        targets = targets.type(torch.long)
        # Per-element alpha weight, reshaped to match `targets` so it broadcasts
        # correctly against `BCE_loss` when reduction='none'.
        at = self.alpha.gather(0, targets.view(-1)).view_as(targets)
        pt = torch.exp(-BCE_loss)  # model's probability of the true class
        F_loss = at * (1 - pt)**self.gamma * BCE_loss
        return F_loss.mean()
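
### Single-logit sanity check (sketch)
# A minimal, hypothetical check (not part of the original gist) of what
# `forward` computes for one logit/target pair with the defaults
# alpha=0.25, gamma=2: the result should equal alpha_t * (1 - p_t)**gamma * BCE,
# where alpha_t is looked up from [alpha, 1 - alpha] by the integer target
# (so positives are weighted by 1 - alpha in this adaptation).
x, t = torch.tensor([0.3]), torch.tensor([1.0])
p = torch.sigmoid(x)
bce = -(t * torch.log(p) + (1 - t) * torch.log(1 - p))
pt = torch.exp(-bce)
at = torch.tensor([0.25, 0.75])[t.long()]
assert torch.isclose(FocalLoss()(x, t), (at * (1 - pt)**2 * bce).mean())
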
### Multi-Label Classification Example
batch_size = 64
vocab_size = 10
pos_weight = torch.ones(vocab_size)

# Random logits and soft targets, used purely as a shape / reduction sanity
# check (the random targets all floor to 0 when cast to long inside `forward`,
# so the alpha weighting is uniform here).
preds, yb = torch.rand(batch_size, vocab_size), torch.rand(batch_size, vocab_size)

focal_loss = FocalLoss(reduction='mean', pos_weight=pos_weight)
# With alpha=1 and gamma=0, focal loss reduces to plain BCE-with-logits.
focal_loss_BCE = FocalLoss(alpha=1, gamma=0,
                           reduction='mean', pos_weight=pos_weight)
BCEloss = nn.BCEWithLogitsLoss(reduction='mean', pos_weight=pos_weight)

focal_loss(preds, yb)  # inputs (logits) first, targets second
assert torch.allclose(focal_loss_BCE(preds, yb), BCEloss(preds, yb))
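
### Training-loop usage (sketch)
# A hypothetical, minimal training step showing where FocalLoss fits; the tiny
# linear `model`, the optimizer, and the random multi-hot `targets` below are
# placeholders, not part of the original gist.
model = nn.Linear(32, vocab_size)
criterion = FocalLoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

xb = torch.randn(batch_size, 32)                                 # dummy features
targets = torch.randint(0, 2, (batch_size, vocab_size)).float()  # multi-hot labels

loss = criterion(model(xb), targets)  # logits vs. 0/1 targets
loss.backward()
optimizer.step()
optimizer.zero_grad()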