# Source: GitHub gist by @MiniXC (created November 10, 2020)
# https://gist.github.com/MiniXC/bd9133660f8bc6196a3e0cca5adbd2a6
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
class FocalLoss(nn.Module):
    """Multi-class focal loss, see https://arxiv.org/abs/1708.02002.

    Based on
    https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py

    For a sample with true class ``y`` whose softmax probability is ``pt``,
    the loss is ``-alpha[y] * (1 - pt)**gamma * log(pt)``.

    Args:
        alpha: optional per-class weights of shape ``(C,)``, passed to
            ``nn.NLLLoss(weight=...)``. ``None`` weights every class by 1.
        gamma: focusing exponent. ``gamma=0`` gives (weighted) cross entropy.
        reduction: ``'mean'`` (default; averaged over non-ignored samples),
            ``'sum'``, or ``'none'`` (per-sample losses shaped like ``y``).
        ignore_index: target value whose samples contribute zero loss and are
            left out of the ``'mean'`` denominator.

    Raises:
        ValueError: if ``reduction`` is not one of the supported modes.
    """

    def __init__(self,
                 alpha: "Tensor | None" = None,
                 gamma: float = 2,
                 reduction: str = 'mean',
                 ignore_index: int = -100):
        super().__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}")
        self.gamma = gamma
        self.reduction = reduction
        self.ignore_index = ignore_index
        # Keeping alpha inside the NLLLoss submodule lets it move together
        # with this module (e.g. via .to(device)).
        self.nll_loss = nn.NLLLoss(weight=alpha, reduction='none',
                                   ignore_index=ignore_index)

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Compute the focal loss.

        Args:
            x: unnormalised logits with the class dimension last, shape
                ``(..., C)`` (typically ``(N, C)``).
            y: integer class targets whose shape matches ``x``'s leading
                dimensions.

        Returns:
            A scalar for ``'mean'``/``'sum'``; otherwise a tensor shaped like ``y``.
        """
        target_shape = y.shape
        # Flatten to (N, C) / (N,). NLLLoss expects the class dimension at
        # position 1, while log_softmax below normalises the last dimension.
        x = x.reshape(-1, x.shape[-1])
        y = y.reshape(-1)

        # Weighted cross-entropy term: -alpha[y] * log(pt).
        # NLLLoss yields 0 for ignored targets.
        log_p = F.log_softmax(x, dim=-1)
        ce = self.nll_loss(log_p, y)

        # Log-probability of the true class per row. gather stays on x's device
        # (a torch.arange index would be created on the CPU). Ignored targets
        # get a placeholder class so the lookup cannot go out of range.
        valid = y != self.ignore_index
        safe_y = y.masked_fill(~valid, 0)
        log_pt = log_p.gather(-1, safe_y.unsqueeze(-1)).squeeze(-1)

        # Focal term (1 - pt)^gamma. -expm1(log_pt) equals 1 - exp(log_pt)
        # without cancellation when pt is near 1. The clamp stops round-off
        # from producing a negative base (NaN for non-integer gamma).
        focal_term = (-torch.expm1(log_pt)).clamp(min=0) ** self.gamma

        # Full loss: -alpha * ((1 - pt)^gamma) * log(pt). Ignored rows are 0
        # because ce is 0 there.
        loss = focal_term * ce
        if self.reduction == 'none':
            return loss.reshape(target_shape)
        if self.reduction == 'sum':
            return loss.sum()
        # Mean over the samples that actually count. Equals loss.mean() when
        # nothing is ignored.
        return loss.sum() / valid.sum()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment