@Varal7
Created July 19, 2019 20:33
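import torch


def binary_focal_loss_with_logits(input, target, gamma, reduction="mean"):
    # Binary focal loss computed from raw logits; target holds 0/1 labels.
    t = target.float()
    # p is the predicted probability of the true class:
    # sigmoid(input) where t == 1, 1 - sigmoid(input) where t == 0.
    # Clamped away from zero so that log() stays finite.
    p = torch.clamp((2 * t - 1) * torch.sigmoid(input) + (1 - t), min=1e-20)
    # Down-weight well-classified examples by (1 - p) ** gamma.
    loss = -(1 - p) ** gamma * p.log()
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    else:
        raise ValueError("Unknown reduction {}".format(reduction))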
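
The function above takes the log of a clamped sigmoid. For hard 0/1 targets, the probability of the true class can also be written as sigmoid((2t - 1) * x), so the log term can be computed with `F.logsigmoid` instead. The sketch below is not part of the gist: `focal_loss_logsigmoid` is a hypothetical name, it assumes targets are exactly 0 or 1, and it returns the unreduced loss. It also checks that `gamma=0` reproduces PyTorch's binary cross-entropy with logits, which is what the focal loss should reduce to when the modulating factor is switched off.

```python
import torch
import torch.nn.functional as F


def focal_loss_logsigmoid(logits, target, gamma):
    # Hypothetical helper, not from the gist. Assumes hard targets in {0, 1}.
    sign = 2 * target.float() - 1                 # +1 for positives, -1 for negatives
    log_p = F.logsigmoid(sign * logits)           # log-probability of the true class
    one_minus_p = torch.sigmoid(-sign * logits)   # probability of the wrong class
    return -(one_minus_p ** gamma) * log_p        # per-element loss, no reduction


torch.manual_seed(0)
logits = torch.randn(8)
target = torch.randint(0, 2, (8,))

# Reference: plain binary cross-entropy on logits, unreduced.
bce = F.binary_cross_entropy_with_logits(logits, target.float(), reduction="none")
focal_0 = focal_loss_logsigmoid(logits, target, gamma=0.0)
focal_2 = focal_loss_logsigmoid(logits, target, gamma=2.0)

# With gamma = 0 the modulating factor is 1, so the loss should match BCE.
print(torch.allclose(focal_0, bce, atol=1e-6))
# With gamma > 0 every element is scaled by a factor in [0, 1].
print(bool((focal_2 <= bce + 1e-6).all()))
```

Whether this variant is preferable depends on the use case: it avoids the clamp, but it relies on the targets being binary, whereas the clamped expression in the gist is linear in `t`.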