@samson-wang
Last active April 2, 2024 11:56
A really simple PyTorch implementation of focal loss for both sigmoid and softmax predictions.
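Both functions implement the focal loss from Lin et al., "Focal Loss for Dense Object Detection": FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. The (1 - p_t)^gamma factor down-weights well-classified examples, and alpha balances the positive and negative terms.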
import torch


def sigmoid_focal_loss(logits, target, gamma=2., alpha=0.25):
    # Focal loss on per-class sigmoid predictions.
    # logits: (N, num_classes); target: (N,) integer class labels.
    # Entries with negative labels contribute nothing to the negative term.
    num_classes = logits.shape[1]
    dtype = target.dtype
    device = target.device
    # One row of class indices [0 .. num_classes - 1] to compare against each label.
    class_range = torch.arange(0, num_classes, dtype=dtype, device=device).unsqueeze(0)
    t = target.unsqueeze(1)
    p = torch.sigmoid(logits)
    # Positive term: -alpha * (1 - p)^gamma * log(p) at the true class.
    term1 = (1 - p) ** gamma * torch.log(p)
    # Negative term: -(1 - alpha) * p^gamma * log(1 - p) at all other classes.
    term2 = p ** gamma * torch.log(1 - p)
    return torch.mean(
        -(t == class_range).float() * term1 * alpha
        - ((t != class_range) * (t >= 0)).float() * term2 * (1 - alpha)
    )


def softmax_focal_loss(x, target, gamma=2., alpha=0.25):
    # Focal loss on softmax predictions.
    # x: (N, num_classes); target: (N,) integer class labels.
    n = x.shape[0]
    device = target.device
    range_n = torch.arange(0, n, dtype=torch.int64, device=device)
    pos_num = float(x.shape[1])
    p = torch.softmax(x, dim=1)
    # Probability assigned to the true class of each sample.
    p = p[range_n, target]
    loss = -(1 - p) ** gamma * alpha * torch.log(p)
    # Note: the sum is normalized by the number of classes, not the batch size.
    return torch.sum(loss) / pos_num
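A minimal usage sketch (the batch size of 4 and the 5-class labels below are illustrative assumptions, not part of the gist):

import torch

# Hypothetical example: batch of 4 samples, 5 classes.
logits = torch.randn(4, 5)
target = torch.tensor([0, 2, 4, 1])  # integer class labels in [0, 5)

loss_sig = sigmoid_focal_loss(logits, target)    # defaults: gamma=2., alpha=0.25
loss_soft = softmax_focal_loss(logits, target)
print(loss_sig.item(), loss_soft.item())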
opeide commented Nov 28, 2023

You have swapped your alpha and gamma values.
