Skip to content

Instantly share code, notes, and snippets.

@maxrohleder
Last active June 14, 2023 13:54
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maxrohleder/61b0e01c5dc8b46e30e9b08013a22bba to your computer and use it in GitHub Desktop.
Losses for biomedical image segmentation
import tensorflow as tf
class DiceBCELoss(tf.keras.losses.Loss):
    """Soft Dice loss plus binary cross-entropy, both restricted to a region of interest.

    A more stable surrogate for the Dice metric: the smoothing term appears only
    in the Dice denominator (see [1]) and a BCE term is added (see [2]). Voxels
    where ``roi`` is False (e.g. a truncated area) are excluded from both terms.

    Sources:
        [1] https://github.com/Project-MONAI/MONAI/issues/807
        [2] https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch
    """

    def __init__(self, roi, smooth=1e-6, eps=1e-8, name='DiceBCE'):
        """Create the loss.

        Args:
            roi: boolean mask the same size as ``y_pred`` and ``y_true`` (a mask
                that ``tf.where`` can broadcast against them also works). True
                marks voxels that contribute to the loss.
            smooth: stability term added to the Dice denominator.
            eps: prevents division by zero; also the lower clip bound of the
                Dice score.
            name: name of the operation.

        Raises:
            TypeError: if ``roi`` is not a boolean mask.
        """
        super().__init__(name=name)
        roi = tf.convert_to_tensor(roi)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if roi.dtype != tf.bool:
            raise TypeError(f"roi must be a boolean mask, got dtype {roi.dtype.name}")
        self.roi = roi
        self.smooth, self.eps = smooth, eps
        # Public attribute kept for backward compatibility; call() computes an
        # ROI-normalised BCE itself instead of using this object.
        self.bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)  # no logits -> probabilities

    def call(self, y_true, y_pred):
        """Return the scalar ``dice_loss + bce`` for one batch.

        Both inputs are cast to float32; ``y_pred`` is expected to hold
        probabilities (not logits).
        """
        # Zero everything outside the ROI so it adds nothing to the Dice sums.
        y_true = tf.where(self.roi, tf.cast(y_true, tf.float32), 0.0)
        y_pred = tf.where(self.roi, tf.cast(y_pred, tf.float32), 0.0)

        # Dice over the whole batch; smooth term deliberately omitted from the numerator [1].
        intersection = tf.reduce_sum(y_pred * y_true)  # |A n B|
        union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)  # |A| + |B|
        dice = (2.0 * intersection) / (union + self.smooth + self.eps)
        dice_loss = 1.0 - tf.clip_by_value(dice, self.eps, 1.0 - self.eps)

        # BCE averaged over ROI voxels only. A full-tensor mean would count every
        # masked voxel as a ~0 loss term and thereby shrink the BCE by the
        # fraction of the volume that lies inside the ROI.
        roi_weight = tf.cast(tf.broadcast_to(self.roi, tf.shape(y_true)), tf.float32)
        k_eps = tf.keras.backend.epsilon()  # same clipping bound Keras' BCE uses
        p = tf.clip_by_value(y_pred, k_eps, 1.0 - k_eps)
        bce_map = -(y_true * tf.math.log(p) + (1.0 - y_true) * tf.math.log(1.0 - p))
        # divide_no_nan yields 0 for an empty ROI instead of NaN.
        bce = tf.math.divide_no_nan(tf.reduce_sum(bce_map * roi_weight), tf.reduce_sum(roi_weight))
        return dice_loss + bce
import tensorflow as tf
class SoftDiceLossRoi(tf.keras.losses.Loss):
    """Soft Dice loss restricted to a region of interest.

    Voxels outside ``roi`` are zeroed in both ``y_true`` and ``y_pred`` before
    the Dice score is computed, so they contribute nothing to the overlap or to
    the summed volumes. The smoothing term is added to both numerator and
    denominator.
    """

    def __init__(self, roi, smooth=1e-4, eps=1e-8, name='SoftDiceRoi'):
        """Create the loss.

        Args:
            roi: boolean mask the same size as ``y_pred`` and ``y_true`` (a mask
                that ``tf.where`` can broadcast against them also works). True
                marks voxels that contribute to the loss.
            smooth: smoothing term added to the Dice numerator and denominator.
            eps: prevents division by zero; also the lower clip bound of the
                Dice score.
            name: name of the operation.

        Raises:
            TypeError: if ``roi`` is not a boolean mask.
        """
        super().__init__(name=name)
        roi = tf.convert_to_tensor(roi)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if roi.dtype != tf.bool:
            raise TypeError(f"roi must be a boolean mask, got dtype {roi.dtype.name}")
        self.roi = roi
        self.smooth, self.eps = smooth, eps

    def call(self, y_true, y_pred):
        """Return the scalar soft Dice loss ``1 - dice`` over the ROI for one batch.

        Both inputs are cast to float32 before masking.
        """
        y_true = tf.where(self.roi, tf.cast(y_true, tf.float32), 0.0)  # crop gt to roi
        y_pred = tf.where(self.roi, tf.cast(y_pred, tf.float32), 0.0)
        intersection = tf.reduce_sum(y_pred * y_true)  # |A n B|
        union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)  # |A| + |B|
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth + self.eps)
        # NOTE: in float32, 1 - 1e-8 rounds to 1.0, so with the default eps the
        # upper clip bound is effectively 1.
        return 1.0 - tf.clip_by_value(dice, self.eps, 1.0 - self.eps)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment