weiliu620/52d140b22685cf9552da4899e2160183
```python
import torch


def dice_loss(pred, target):
    """This definition generalizes to real-valued pred and target vectors.
    This should be differentiable.
    pred: tensor with first dimension as batch
    target: tensor with first dimension as batch
    """
    smooth = 1.
    # contiguous() is needed because view(-1) fails on non-contiguous inputs
    # (e.g. tensors produced by transpose or permute)
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()
    A_sum = torch.sum(tflat * iflat)
    B_sum = torch.sum(tflat * tflat)
    return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth))
```
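For reference, with $p_i$ the entries of the flattened `pred`, $t_i$ those of the flattened `target` and $s$ = `smooth`, the function above returns $L_{\text{gist}}$. The usual soft Dice form and the squared-denominator form (the one the comments propose, also used by V-Net) are written next to it for comparison:

$$
\begin{aligned}
L_{\text{gist}}    &= 1 - \frac{2\sum_i p_i t_i + s}{\sum_i p_i t_i + \sum_i t_i^2 + s} \\
L_{\text{soft}}    &= 1 - \frac{2\sum_i p_i t_i + s}{\sum_i p_i + \sum_i t_i + s} \\
L_{\text{squared}} &= 1 - \frac{2\sum_i p_i t_i + s}{\sum_i p_i^2 + \sum_i t_i^2 + s}
\end{aligned}
$$

Because `A_sum` is the intersection again, the denominator of $L_{\text{gist}}$ has no term in $p$ alone. For a binary target and $p = c\,t$ with $c > 1$, the ratio exceeds 1, so the loss goes negative and approaches $-1$ as $c$ grows. The other two forms stay in $[0, 1]$ whenever $p_i, t_i \in [0, 1]$.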
I think line 17 (`A_sum = torch.sum(tflat * iflat)`) is wrong. It should be `A_sum = torch.sum(iflat * iflat)`.
I also have a question: do you put any constraint on pred, such as `sigmoid(pred)`?
The author's code is correct, and there is a detailed discussion here: Dice Loss PR #1249
As @youngfly11 mentioned, line 17 seems to be wrong.
Yes, line 17 is wrong.
There is a typo at line 17. It should be something like this:
A_sum = torch.sum(iflat * iflat)
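Putting that one-line fix into the full function gives the sketch below. It keeps the gist's structure and only changes `A_sum`, which turns the denominator into sum(p^2) + sum(t^2); turning `smooth` into a keyword argument is an added convenience, not part of the original.

```python
import torch


def dice_loss(pred, target, smooth=1.0):
    """Soft Dice loss with squared terms in the denominator.

    Sketch of the fix proposed in this thread: A_sum uses pred * pred
    instead of target * pred, so the denominator is sum(p^2) + sum(t^2).
    pred, target: tensors of the same shape, first dimension = batch.
    """
    # contiguous() so view(-1) also works on non-contiguous inputs
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()
    A_sum = torch.sum(iflat * iflat)   # fixed: prediction squared
    B_sum = torch.sum(tflat * tflat)
    return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth))


if __name__ == "__main__":
    target = torch.tensor([[1., 0., 1., 0.]])
    print(dice_loss(target.clone(), target).item())  # perfect prediction -> 0.0
    print(dice_loss(2 * target, target).item())      # over-scaled pred: 2/11, about 0.18
```

With the fix, a prediction that overshoots the target is penalized instead of being rewarded with a negative loss.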
Hey guys, I understand how this can be generalized to multiple classes that have been one-hot encoded. However, in PyTorch the ground-truth classes for segmentation don't have to be one-hot encoded, so how does everyone go about using this GDL for segmentation?
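One common approach is to one-hot encode integer class-index targets on the fly with `torch.nn.functional.one_hot`, then move the class axis next to the batch axis so it lines up with a softmax over the logits. A minimal sketch, assuming logits of shape (N, C, H, W) and labels of shape (N, H, W); the name `multiclass_dice_loss` is hypothetical, and it averages per-class Dice scores with equal weight, so it omits the per-class weighting that generalized Dice loss formulations add.

```python
import torch
import torch.nn.functional as F


def multiclass_dice_loss(logits, labels, smooth=1.0):
    """Hypothetical helper: soft Dice for integer class-index targets.

    logits: (N, C, H, W) raw network outputs
    labels: (N, H, W) int64 class indices in [0, C)
    Returns 1 - (unweighted mean over classes of the per-class Dice).
    """
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)                       # (N, C, H, W)
    # one_hot appends the class axis last: (N, H, W, C) -> move it to dim 1
    one_hot = F.one_hot(labels, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)                                       # sum over batch and pixels, keep classes
    intersection = torch.sum(probs * one_hot, dims)
    denominator = torch.sum(probs, dims) + torch.sum(one_hot, dims)
    dice_per_class = (2. * intersection + smooth) / (denominator + smooth)
    return 1 - dice_per_class.mean()


if __name__ == "__main__":
    logits = torch.randn(2, 3, 4, 4, requires_grad=True)
    labels = torch.randint(0, 3, (2, 4, 4))
    loss = multiclass_dice_loss(logits, labels)
    loss.backward()
    print(loss.item(), logits.grad.shape)
```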
This didn't work for me. I ended up with this code:
```python
import torch


def dice_loss(pred, target):
    numerator = 2 * torch.sum(pred * target)
    denominator = torch.sum(pred + target)
    return 1 - (numerator + 1) / (denominator + 1)
```
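This version assumes `pred` already lies in [0, 1]; with raw, unbounded network outputs (for example negative values) the denominator can shrink toward zero and the loss is no longer bounded. A minimal sketch, assuming the model outputs raw logits for a binary mask, that squashes them with `torch.sigmoid` before the loss; the wrapper name `dice_loss_from_logits` is hypothetical.

```python
import torch


def dice_loss(pred, target):
    # soft Dice from the comment above; expects pred in [0, 1]
    numerator = 2 * torch.sum(pred * target)
    denominator = torch.sum(pred + target)
    return 1 - (numerator + 1) / (denominator + 1)


def dice_loss_from_logits(logits, target):
    """Hypothetical helper: map raw logits to probabilities, then apply Dice."""
    return dice_loss(torch.sigmoid(logits), target)


if __name__ == "__main__":
    target = torch.tensor([1., 0., 1., 1.])
    confident = torch.tensor([10., -10., 10., 10.])          # logits that agree with target
    print(dice_loss_from_logits(confident, target).item())   # close to 0
```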
This is wrong:
```python
A_sum = torch.sum(tflat * iflat)
B_sum = torch.sum(tflat * tflat)
```
For the denominator you want the sums of pixels/voxels in the target and in the prediction separately, not a union. It should be something like:
```python
A_sum = torch.sum(tflat)
B_sum = torch.sum(iflat)
```
But pay attention to the dimensions you are summing along.
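To make the dimension point concrete: flattening everything with `view(-1)` pools the whole batch into one Dice score, so a large, well-predicted mask in one image can hide a completely missed small mask in another. A sketch, assuming inputs of shape (N, ...) with values in [0, 1], that uses the separate sums suggested above but reduces over every dimension except the batch one and then averages the per-sample losses (one common choice, not the only one); `dice_loss_per_sample` is a hypothetical name.

```python
import torch


def dice_loss_per_sample(pred, target, smooth=1.0):
    """Hypothetical helper: soft Dice per sample, averaged over the batch.

    pred, target: tensors of shape (N, ...) with at least two dimensions,
    values in [0, 1].
    """
    dims = tuple(range(1, pred.dim()))                # every dim except batch
    intersection = torch.sum(pred * target, dim=dims)
    A_sum = torch.sum(pred, dim=dims)                 # sum over prediction
    B_sum = torch.sum(target, dim=dims)               # sum over target
    dice = (2. * intersection + smooth) / (A_sum + B_sum + smooth)  # shape (N,)
    return 1 - dice.mean()


if __name__ == "__main__":
    target = torch.zeros(2, 1, 8, 8)
    target[0] = 1.0                    # sample 0: mask covers the whole image
    target[1, :, 0, 0] = 1.0           # sample 1: one-pixel mask
    pred = target.clone()
    pred[1] = 0.0                      # sample 1 is missed entirely
    print(dice_loss_per_sample(pred, target).item())  # 0.25
    # pooling the whole batch into one "sample" hides the miss:
    print(dice_loss_per_sample(pred.reshape(1, -1), target.reshape(1, -1)).item())  # about 0.0077
```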
@weiliu620, how do I enable autograd for this?
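The function is built only from differentiable tensor operations (view, multiplication, sum, division), so autograd needs nothing extra: as long as `pred` comes out of the model (or is a tensor with `requires_grad=True`), calling `.backward()` on the loss fills in the gradients. A minimal sketch, using a toy `nn.Conv2d` and random masks as stand-ins for a real segmentation network and dataset (both assumptions), together with the `iflat * iflat` fix discussed in the thread:

```python
import torch
import torch.nn as nn


def dice_loss(pred, target, smooth=1.0):
    # soft Dice with the A_sum fix from this thread (pred * pred)
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()
    A_sum = torch.sum(iflat * iflat)
    B_sum = torch.sum(tflat * tflat)
    return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth))


torch.manual_seed(0)
model = nn.Conv2d(1, 1, kernel_size=3, padding=1)   # toy stand-in for a segmentation net
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

image = torch.rand(4, 1, 16, 16)
target = (image > 0.5).float()                      # toy binary masks

for step in range(5):
    optimizer.zero_grad()
    pred = torch.sigmoid(model(image))              # probabilities in (0, 1)
    loss = dice_loss(pred, target)
    loss.backward()                                 # autograd computes d(loss)/d(weights)
    optimizer.step()
    print(step, loss.item())
```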