
@weiliu620
Last active July 19, 2023 10:30
Dice coefficient loss function in PyTorch
```python
import torch


def dice_loss(pred, target):
    """This definition generalizes to real-valued pred and target vectors.
    This should be differentiable.
    pred: tensor with first dimension as batch
    target: tensor with first dimension as batch
    """
    smooth = 1.
    # have to use contiguous since they may come from a torch.view op
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()
    # NOTE: disputed in the comments below; several commenters suggest torch.sum(iflat * iflat)
    A_sum = torch.sum(tflat * iflat)
    B_sum = torch.sum(tflat * tflat)
    return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth))
```
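Several commenters in the thread argue that the `A_sum` line should square the prediction instead. A minimal sketch of that variant follows, with squared sums of both prediction and target in the denominator (the form used in the V-Net paper). The name `dice_loss_squared` and the `smooth` parameter are hypothetical choices made here, not part of the gist.

```python
import torch


def dice_loss_squared(pred: torch.Tensor, target: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """Soft Dice loss with squared sums in the denominator.

    Hypothetical helper: the variant proposed in the comments, not the gist author's code.
    """
    # reshape() copies only when the tensor is not contiguous, so no explicit contiguous() call
    iflat = pred.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    a_sum = (iflat * iflat).sum()  # sum of squared predictions
    b_sum = (tflat * tflat).sum()  # sum of squared targets
    return 1 - (2.0 * intersection + smooth) / (a_sum + b_sum + smooth)


target = torch.tensor([1.0, 0.0, 0.0, 0.0])
all_ones = torch.ones(4)  # over-segmentation: foreground predicted everywhere
loss = dice_loss_squared(all_ones, target)  # 1 - (2*1 + 1) / (4 + 1 + 1) = 0.5
```

For comparison, the gist's `A_sum = torch.sum(tflat * iflat)` gives `A_sum = 1` for the same all-ones prediction, so its loss is 1 - (2 + 1) / (1 + 1 + 1) = 0: predicting foreground everywhere goes unpenalized, while the squared form charges for it.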
@youngfly11

I think there is something wrong in line 17 (the `A_sum` line). It should be `A_sum = torch.sum(iflat * iflat)`.
I also have a question: do you put any constraint on `pred`, such as `sigmoid(pred)`?
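The gist does not say whether `pred` must be constrained. A minimal sketch, assuming `pred` is meant to hold probabilities for binary segmentation, shows why squashing raw network outputs with `torch.sigmoid` matters: unbounded logits can push the soft Dice ratio outside [0, 1]. The helper `soft_dice_loss` (separate-sums denominator) and the logit values are made up for illustration.

```python
import torch


def soft_dice_loss(prob: torch.Tensor, target: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    # Hypothetical helper: soft Dice on flattened tensors, denominator = sum(prob) + sum(target).
    p = prob.reshape(-1)
    t = target.reshape(-1)
    intersection = (p * t).sum()
    return 1 - (2.0 * intersection + smooth) / (p.sum() + t.sum() + smooth)


target = torch.tensor([1.0, 0.0, 1.0, 0.0])
logits = torch.tensor([3.0, -4.0, 5.0, -2.0])  # made-up raw network outputs

# Raw logits are unbounded, so the loss is not confined to [0, 1].
raw_loss = soft_dice_loss(logits, target)

# Mapping to probabilities first keeps the loss in [0, 1] for a binary target.
prob_loss = soft_dice_loss(torch.sigmoid(logits), target)
```

With these logits the raw version works out to 1 - (2*8 + 1) / (2 + 2 + 1) = -2.4, a negative loss, because the negative logits shrink the denominator.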

@moulei007

> I think there is some wrong in line 17. It should be $A_sum = torch.sum(iflat*iflat)$.
> And i have a question, Do you have any constrain for pred, like$sigmoid(pred)$?

The author's code is correct, and there is a detailed discussion here: Dice Loss PR #1249

@JoHof

JoHof commented Feb 22, 2019

As mentioned by @youngfly11, line 17 seems to be wrong.

@maxmo2009

Yes, line 17 is wrong.

@sbelharbi

There is a typo at line 17. It should be something like this:

```python
A_sum = torch.sum(iflat * iflat)
```

@nabsabraham

Hey guys, I understand how this can be generalized to multiple classes that have been one-hot encoded. However, in PyTorch, ground-truth classes for segmentation don't have to be one-hot encoded, so how does everyone go about using this generalized Dice loss (GDL) for segmentation?
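One common approach is to keep integer class-index targets in the data pipeline and one-hot encode them inside the loss with `torch.nn.functional.one_hot`. A minimal sketch, assuming `(N, C, H, W)` logits and `(N, H, W)` int64 labels; `multiclass_dice_loss` is a hypothetical name, and it takes a plain mean over classes rather than the class-weighted generalized Dice.

```python
import torch
import torch.nn.functional as F


def multiclass_dice_loss(logits: torch.Tensor, labels: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """Mean soft Dice loss over classes for integer class-index targets.

    logits: (N, C, H, W) raw scores; labels: (N, H, W) int64 class indices.
    Hypothetical helper: unweighted mean over classes.
    """
    num_classes = logits.shape[1]
    probs = torch.softmax(logits, dim=1)
    # F.one_hot appends the class axis last: (N, H, W, C) -> permute to (N, C, H, W)
    one_hot = F.one_hot(labels, num_classes).permute(0, 3, 1, 2).to(probs.dtype)
    dims = (0, 2, 3)  # sum over batch and spatial axes, keep the class axis
    intersection = (probs * one_hot).sum(dims)
    denominator = probs.sum(dims) + one_hot.sum(dims)
    dice_per_class = (2.0 * intersection + smooth) / (denominator + smooth)
    return 1 - dice_per_class.mean()


logits = torch.randn(2, 3, 4, 4)         # (N, C, H, W) network scores
labels = torch.randint(0, 3, (2, 4, 4))  # (N, H, W) class indices, no one-hot needed upstream
loss = multiclass_dice_loss(logits, labels)
```

Sudre et al.'s generalized Dice additionally weights each class by the inverse square of its ground-truth volume; that weighting could be applied to `intersection` and `denominator` before the ratio.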

@oliverguhr

oliverguhr commented Sep 12, 2019

This didn't work for me. I ended up with this code:

```python
import torch


def dice_loss(pred, target):
    numerator = 2 * torch.sum(pred * target)
    denominator = torch.sum(pred + target)
    return 1 - (numerator + 1) / (denominator + 1)
```
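A quick sanity check of this version, copying the function so the block runs on its own: `torch.sum(pred + target)` equals `sum(pred) + sum(target)`, so the denominator uses separate sums, and the `+ 1` smoothing keeps the loss just below 1 for disjoint masks. The masks are made-up examples.

```python
import torch


def dice_loss(pred, target):
    # Copy of the function above.
    numerator = 2 * torch.sum(pred * target)
    denominator = torch.sum(pred + target)
    return 1 - (numerator + 1) / (denominator + 1)


mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]])
inverse = 1 - mask

# sum(pred + target) is the same as sum(pred) + sum(target)
assert torch.sum(mask + inverse) == mask.sum() + inverse.sum()

perfect = dice_loss(mask, mask)      # 1 - (6 + 1) / (6 + 1) = 0
disjoint = dice_loss(inverse, mask)  # 1 - (0 + 1) / (1 + 3 + 1) = 0.8
```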

@KhrystynaFaryna

This is wrong:

```python
A_sum = torch.sum(tflat * iflat)
B_sum = torch.sum(tflat * tflat)
```

For the denominator you want the sums of pixels/voxels in the target and in the prediction separately, not sums of elementwise products. It should be something like:

```python
A_sum = torch.sum(tflat)
B_sum = torch.sum(iflat)
```

But pay attention to the dimensions you are summing along.
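On the point about summation dimensions: the gist flattens the whole batch into one vector, so every sample is pooled into a single Dice score. A minimal sketch, assuming the batch is on dim 0, that sums over all non-batch axes per sample and then averages; `dice_loss_per_sample` is a hypothetical name.

```python
import torch


def dice_loss_per_sample(pred: torch.Tensor, target: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """Soft Dice loss computed per sample, then averaged over the batch.

    pred, target: (N, ...) tensors with the batch on dim 0. Hypothetical helper.
    """
    dims = tuple(range(1, pred.dim()))  # every axis except the batch axis
    intersection = (pred * target).sum(dim=dims)
    a_sum = target.sum(dim=dims)  # sum over the target only
    b_sum = pred.sum(dim=dims)    # sum over the prediction only
    dice = (2.0 * intersection + smooth) / (a_sum + b_sum + smooth)
    return 1 - dice.mean()


pred = torch.tensor([[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]])
target = torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]])
loss = dice_loss_per_sample(pred, target)  # per-sample Dice 1.0 and 0.2 -> 1 - 0.6 = 0.4
```

Flattening this pair into one vector with the same separate-sums denominator would instead give 1 - (2*2 + 1) / (4 + 4 + 1) = 1 - 5/9, about 0.44, so the choice of reduction changes the value.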

@ucalyptus2

@weiliu620 How do I enable autograd for this?
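No special switch should be needed: autograd records every torch operation on tensors that require gradients, and the loss above is built only from such operations. A minimal sketch, where a random leaf tensor with `requires_grad=True` stands in for a model's output (in training, the model's parameters play that role). Converting to NumPy, calling `.item()`, or hard-thresholding `pred` (e.g. `pred > 0.5`) before the loss would cut the gradient path.

```python
import torch


def dice_loss(pred, target, smooth=1.0):
    # Soft Dice with a separate-sums denominator, built only from differentiable torch ops.
    iflat = pred.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    return 1 - (2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)


torch.manual_seed(0)
logits = torch.randn(2, 1, 8, 8, requires_grad=True)  # stand-in for network output
target = (torch.rand(2, 1, 8, 8) > 0.5).float()       # random binary mask (no grad needed)

loss = dice_loss(torch.sigmoid(logits), target)
loss.backward()  # gradients flow back through sigmoid and the Dice ratio into logits.grad
```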
