@wassname
Created September 26, 2016 08:32
dice_loss_for_keras
"""
Here is a dice loss for keras which is smoothed to approximate a linear (L1) loss.
It ranges from 1 to 0 (no error), and returns results similar to binary crossentropy
"""
# define custom loss and metric functions
import numpy as np
import keras
from keras import backend as K


def dice_coef(y_true, y_pred, smooth=1):
    """
    Dice = (2*|X & Y|)/ (|X|+ |Y|)
         =  2*sum(|A*B|)/(sum(A^2)+sum(B^2))
    ref: https://arxiv.org/pdf/1606.04797v1.pdf
    """
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    return (2. * intersection + smooth) / (K.sum(K.square(y_true), -1) + K.sum(K.square(y_pred), -1) + smooth)


def dice_coef_loss(y_true, y_pred):
    return 1 - dice_coef(y_true, y_pred)
# Test (assumes the Theano backend, since K.theano.shared is used to build symbolic variables)
y_true = np.array([[0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 1., 0.]])
y_pred = np.array([[0, 0, 0.9, 0], [0, 0, 0.1, 0], [1, 1, 0.1, 1.]])

r = dice_coef_loss(
    K.theano.shared(y_true),
    K.theano.shared(y_pred),
).eval()
print('dice_coef_loss', r)

r = keras.objectives.binary_crossentropy(
    K.theano.shared(y_true),
    K.theano.shared(y_pred),
).eval()
print('binary_crossentropy', r)
print('binary_crossentropy_scaled', r / r.max())

# TYPE                 | almost right | half right | all wrong
# dice_coef_loss       [ 0.00355872     0.40298507   0.76047904 ]
# binary_crossentropy  [ 0.0263402      0.57564635  12.53243514 ]
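A typical way to plug these into training, as a hypothetical sketch (the tiny model below is an assumption for illustration, not part of the gist):

# Hypothetical usage sketch: any Keras model can take the custom loss and metric
# directly in compile(); the model definition here is made up for illustration.
from keras.models import Sequential
from keras.layers import Dense

model = Sequential([Dense(4, activation='sigmoid', input_dim=4)])
model.compile(optimizer='adam', loss=dice_coef_loss, metrics=[dice_coef])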
@wassname
Author

I recommend people use jaccard_coef_loss instead. It's very similar but it provides a loss gradient even near 0, leading to better accuracy.
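For reference, a minimal sketch of what such a smoothed Jaccard (IoU) loss can look like in the same backend style; this is an illustration of the idea, not the exact code from that gist:

# Illustrative sketch only (not the exact jaccard_coef_loss gist): a smoothed
# Jaccard/IoU coefficient and loss, written in the same Keras-backend style.
def jaccard_coef(y_true, y_pred, smooth=1):
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    union = K.sum(K.abs(y_true), axis=-1) + K.sum(K.abs(y_pred), axis=-1) - intersection
    # the smoothing term keeps the ratio (and its gradient) defined when both masks are nearly empty
    return (intersection + smooth) / (union + smooth)


def jaccard_coef_loss(y_true, y_pred):
    return 1 - jaccard_coef(y_true, y_pred)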

@mrgloom

mrgloom commented Dec 13, 2018

Why (2*|X & Y|)/ (|X|+ |Y|) = 2*sum(|A*B|)/(sum(A^2)+sum(B^2))?

@asciidiego

asciidiego commented Feb 3, 2019

The definition of the union of X and Y is the sum of X and Y (as you correctly said) MINUS the intersection!

That is, line 17 should be:

return (2. * intersection + smooth) / (K.sum(K.square(y_true),-1) + K.sum(K.square(y_pred),-1) - intersection + smooth)

@MathisMohand

I'm sorry to reply to such an old topic, but I have to correct you here @diegovincent. You don't need to subtract the intersection: the numerator is already multiplied by 2, and since the labels (y) are ones and zeros, |X| + |Y| is exactly twice the intersection for a perfect match, so the coefficient still reaches 1 without subtracting anything.
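A quick numeric check of that point (a hypothetical example with hard 0/1 masks, for illustration only):

# Hypothetical check with hard 0/1 masks: for a perfect match the Dice formula
# already reaches 1 without subtracting the intersection, and so does IoU.
import numpy as np

y_true = np.array([0., 1., 1., 0.])
y_pred = np.array([0., 1., 1., 0.])

intersection = np.sum(y_true * y_pred)                                 # 2
dice = 2. * intersection / (np.sum(y_true**2) + np.sum(y_pred**2))     # 4 / 4 = 1.0
iou = intersection / (np.sum(y_true) + np.sum(y_pred) - intersection)  # 2 / 2 = 1.0
print(dice, iou)  # 1.0 1.0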

@amartya-k

Hi @wassname, can you please explain why you're taking the sum of the squares in the denominator? By that I mean this:

(sum(A^2) + sum(B^2))

The other implementations take only the sum, not the sum of squares. You can check this out for reference: https://lars76.github.io/neural-networks/object-detection/losses-for-segmentation/

@wassname
Author

wassname commented Nov 3, 2019

@amartya-k it's based on section 3 of the paper I referenced in the gist, which uses squares on the bottom.

Why the difference? Those vertical bars around the terms are vector norms. It's possible that the implementations you are comparing use different vector norms; I went with the one in the paper since it had demonstrated performance. The paper also gives the equation for the dice loss, not the dice coefficient itself, so the whole thing may be squared for greater stability. I guess you will have to dig deeper for the answer.

I now use Jaccard loss, or IoU loss, or Focal Loss, or generalised dice loss instead of this gist.
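As a quick sanity check of the norm point (an illustrative snippet, not from the paper or the gist): for hard 0/1 masks squaring changes nothing, so the two denominators only differ for soft predictions.

# Illustrative check: for 0/1 masks x**2 == x, so sum(A^2) + sum(B^2) equals
# sum(A) + sum(B); the two denominators only diverge for soft predictions.
import numpy as np

y_true = np.array([0., 0., 1., 0.])
y_pred = np.array([0., 0., 0.9, 0.])   # soft prediction

squared = np.sum(y_true**2) + np.sum(y_pred**2)   # 1 + 0.81 = 1.81
plain = np.sum(y_true) + np.sum(y_pred)           # 1 + 0.9  = 1.90
print(squared, plain)  # identical if y_pred were also hard 0/1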

@amartya-k


Thanks :)

@Tombery1

Tombery1 commented Jun 2, 2022

Please, what is the correct implementation of the dice coefficient?

def dice_coef1(y_true, y_pred, smooth=1):
  intersection = K.sum(y_true * y_pred, axis=[1,2,3])
  union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
  dice = K.mean((2. * intersection + smooth)/(union + smooth), axis=0)
  return dice

This gives me the following result: 0.85

or

def dice_coef2(target, prediction, smooth=1):
    numerator = 2.0 * K.sum(target * prediction) + smooth
    denominator = K.sum(target) + K.sum(prediction) + smooth
    coef = numerator / denominator

    return coef

This gives me the following result: 0.94
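For context, the two versions reduce over different axes: dice_coef1 computes a dice score per sample (summing over axes [1,2,3]) and then averages over the batch, while dice_coef2 pools every element into one global dice score, so they will generally give different numbers on the same data (and the gist's version above additionally squares the denominator terms). A sketch of the distinction on made-up data:

# Illustrative sketch on made-up 2D data (not the poster's data): the mean of
# per-sample dice scores differs from one global dice over the whole batch.
import numpy as np

y_true = np.array([[1., 1., 0., 0.],   # sample 1: predicted well
                   [1., 0., 0., 0.]])  # sample 2: predicted badly
y_pred = np.array([[1., 1., 0., 0.],
                   [0., 0., 1., 0.]])

smooth = 1.
# dice_coef1 style: per-sample dice, then mean over the batch
inter = np.sum(y_true * y_pred, axis=1)
sums = np.sum(y_true, axis=1) + np.sum(y_pred, axis=1)
per_sample_mean = np.mean((2. * inter + smooth) / (sums + smooth))

# dice_coef2 style: one global dice over all elements
global_dice = (2. * np.sum(y_true * y_pred) + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)

print(per_sample_mean, global_dice)  # ~0.67 vs ~0.71 -- different reductions, different numbers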
