Skip to content

Instantly share code, notes, and snippets.

@snakers4
Created July 21, 2018 11:02
Show Gist options
  • Star 7 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save snakers4/5739ade67e54230aba9bd8a468a3b7be to your computer and use it in GitHub Desktop.
Save snakers4/5739ade67e54230aba9bd8a468a3b7be to your computer and use it in GitHub Desktop.
Multi-class classification focal loss
import torch
import torch.nn as nn
import torch.nn.functional as F
# Focal loss implementation inspired by
# https://github.com/c0nn3r/RetinaNet/blob/master/focal_loss.py
# https://github.com/doiken23/pytorch_toolbox/blob/master/focalloss2d.py
class MultiClassBCELoss(nn.Module):
    """Binary cross-entropy on logits with optional weighting and focal terms.

    Each of the ``Classes`` columns is treated as an independent binary
    problem. Two optional behaviours can be switched on at construction:

    * ``use_weight_mask``: every element's loss is multiplied by the matching
      entry of the ``weights`` tensor passed to :meth:`forward`.
    * ``use_focal_weights``: every element's loss is multiplied by
      ``(1 - pt) ** focus_param`` where ``pt = exp(-bce)`` is computed for
      that element, and the averaged result is scaled by ``balance_param``.

    Args:
        use_weight_mask: apply the per-element ``weights`` tensor.
        use_focal_weights: apply the focal modulating factor.
        focus_param: exponent of the focal modulating factor (gamma).
        balance_param: scalar multiplier applied to the focal loss. It is a
            uniform scale, not a class-dependent alpha.
    """

    def __init__(self,
                 use_weight_mask=False,
                 use_focal_weights=False,
                 focus_param=2,
                 balance_param=0.25
                 ):
        super().__init__()
        self.use_weight_mask = use_weight_mask
        # Used for the plain (unweighted, non-focal) path.
        self.nll_loss = nn.BCEWithLogitsLoss()
        self.use_focal_weights = use_focal_weights
        self.focus_param = focus_param
        self.balance_param = balance_param

    def forward(self,
                outputs,
                targets,
                weights=None):
        """Compute the scalar loss.

        Args:
            outputs: raw logits, Batch x Classes.
            targets: targets with the same shape as ``outputs``.
            weights: per-element weights, Batch x Classes. Only required when
                ``use_weight_mask`` is enabled; when supplied it is always
                shape-checked.

        Returns:
            A zero-dimensional tensor holding the mean loss.

        Raises:
            AssertionError: if the shapes of the tensors disagree.
            ValueError: if ``use_weight_mask`` is enabled but ``weights`` is
                ``None``.
        """
        # outputs and targets are assumed to be BatchxClasses
        assert len(outputs.shape) == len(targets.shape), \
            "outputs and targets must have the same number of dimensions"
        assert outputs.size(0) == targets.size(0), \
            "outputs and targets must have the same batch size"
        assert outputs.size(1) == targets.size(1), \
            "outputs and targets must have the same number of classes"

        if weights is not None:
            # weights are assumed to be BatchxClasses
            assert outputs.size(0) == weights.size(0), \
                "outputs and weights must have the same batch size"
            assert outputs.size(1) == weights.size(1), \
                "outputs and weights must have the same number of classes"
        elif self.use_weight_mask:
            raise ValueError(
                "weights must be provided when use_weight_mask=True")

        if not self.use_focal_weights:
            if self.use_weight_mask:
                return F.binary_cross_entropy_with_logits(input=outputs,
                                                          target=targets,
                                                          weight=weights)
            return self.nll_loss(input=outputs,
                                 target=targets)

        # The focal factor has to be evaluated per element. Applying it to
        # the already-averaged loss would only rescale the whole batch and
        # could not down-weight easy elements relative to hard ones.
        bce_per_element = F.binary_cross_entropy_with_logits(
            input=outputs,
            target=targets,
            reduction="none")
        # exp(-bce) recovers the probability assigned to the true label
        # (for hard 0/1 targets).
        pt = torch.exp(-bce_per_element)
        focal_per_element = ((1 - pt) ** self.focus_param) * bce_per_element

        if self.use_weight_mask:
            # pt is derived from the unweighted loss so the mask only scales
            # each element's contribution, not its estimated confidence.
            focal_per_element = focal_per_element * weights

        return self.balance_param * focal_per_element.mean()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment