@Kulbear
Created February 17, 2019 20:08
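import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Binary focal loss on raw logits: (1 - p_t)^gamma * BCE."""

    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, logit, target):
        target = target.float()
        # Numerically stable binary cross-entropy with logits:
        # max_val + log(exp(-max_val) + exp(-logit - max_val)) == log(1 + exp(-logit))
        max_val = (-logit).clamp(min=0)
        loss = logit - logit * target + max_val + \
            ((-max_val).exp() + (-logit - max_val).exp()).log()
        # invprobs = log(1 - p_t), where p_t is the predicted probability of the true label
        invprobs = F.logsigmoid(-logit * (target * 2.0 - 1.0))
        # Modulating factor (1 - p_t)^gamma down-weights easy, well-classified examples
        loss = (invprobs * self.gamma).exp() * loss
        # For (batch, classes) inputs, sum over classes before averaging over the batch
        if loss.dim() == 2:
            loss = loss.sum(dim=1)
        return loss.mean()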
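The `forward` method computes the textbook focal loss FL(p_t) = -(1 - p_t)^gamma * log(p_t) in a numerically stable form: the first expression is binary cross-entropy on logits, and `exp(gamma * logsigmoid(-logit * (2t - 1)))` equals (1 - p_t)^gamma. The sketch below is a dependency-free, scalar re-implementation (pure Python, `math` only, no PyTorch) intended only to check that reading. It assumes binary 0/1 targets, and the helper names `stable_bce`, `focal_term`, `focal_term_naive`, and `focal_loss` are hypothetical, not part of the gist or of any library.

```python
import math


def sigmoid(x):
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def log_sigmoid(z):
    # log(sigmoid(z)) without overflow for large |z|
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def stable_bce(logit, target):
    # Same expression as the gist's first loss line (hypothetical scalar helper)
    max_val = max(-logit, 0.0)
    return (logit - logit * target + max_val
            + math.log(math.exp(-max_val) + math.exp(-logit - max_val)))


def focal_term(logit, target, gamma=2.0):
    # Mirrors the gist: exp(gamma * log(1 - p_t)) * BCE
    log_inv_prob = log_sigmoid(-logit * (2.0 * target - 1.0))
    return math.exp(gamma * log_inv_prob) * stable_bce(logit, target)


def focal_term_naive(logit, target, gamma=2.0):
    # Textbook form -(1 - p_t)^gamma * log(p_t); fine for moderate logits only
    p = sigmoid(logit)
    p_t = p if target == 1 else 1.0 - p
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


def focal_loss(logits, targets, gamma=2.0):
    # logits/targets are lists of rows (batch x classes): sum over classes,
    # then average over the batch, matching the gist's reduction for 2-D inputs
    per_row = [sum(focal_term(x, t, gamma) for x, t in zip(row_x, row_t))
               for row_x, row_t in zip(logits, targets)]
    return sum(per_row) / len(per_row)
```

Setting `gamma=0` makes the modulating factor 1, so the loss reduces to plain binary cross-entropy; larger `gamma` shrinks the contribution of examples the model already classifies confidently.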