@weiliu620
Last active July 19, 2023 10:30
Dice coefficient loss function in PyTorch
import torch


def dice_loss(pred, target):
    """Soft Dice loss. This definition generalizes to real-valued pred and
    target tensors, and it is differentiable.

    pred: tensor with first dimension as batch
    target: tensor with first dimension as batch
    """
    smooth = 1.
    # Call contiguous() first: view(-1) can fail on non-contiguous tensors.
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()
    # Squared sums of each input (A_sum previously repeated the intersection).
    A_sum = torch.sum(iflat * iflat)
    B_sum = torch.sum(tflat * tflat)
    return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth))
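
A minimal usage sketch, assuming the squared-sum form of the denominator. The loss is built only from differentiable torch operations, so autograd tracks it whenever `pred` is computed from tensors that have `requires_grad=True`. The `logits` tensor and its shape are hypothetical stand-ins for a model's raw output; they are not part of the gist.

```python
import torch


def dice_loss(pred, target):
    # Soft Dice loss over the flattened tensors, with squared sums
    # in the denominator.
    smooth = 1.0
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()
    denom = (iflat * iflat).sum() + (tflat * tflat).sum()
    return 1 - (2.0 * intersection + smooth) / (denom + smooth)


torch.manual_seed(0)

# Hypothetical model output: batch of 2, one channel, 4x4 pixels.
logits = torch.randn(2, 1, 4, 4, requires_grad=True)
# Binary ground-truth mask of the same shape.
target = (torch.rand(2, 1, 4, 4) > 0.5).float()

# Map logits to (0, 1) before computing the loss.
pred = torch.sigmoid(logits)
loss = dice_loss(pred, target)

# Backpropagate; the gradient lands in logits.grad.
loss.backward()
grad_shape = tuple(logits.grad.shape)

# A prediction identical to the target gives a loss of zero.
perfect_loss = dice_loss(target, target).item()
```

No extra step is needed to "enable" autograd here: calling `loss.backward()` is enough as long as the prediction comes from parameters or inputs that require gradients.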
@ucalyptus2

@weiliu620 how to enable autograd for this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment