Skip to content

Instantly share code, notes, and snippets.

@edraizen
Last active February 9, 2018 15:21
Show Gist options
  • Save edraizen/d29b2a3f46e40e0de945ba098604b942 to your computer and use it in GitHub Desktop.
Save edraizen/d29b2a3f46e40e0de945ba098604b942 to your computer and use it in GitHub Desktop.
from itertools import groupby
from torch.nn.modules.loss import _Loss
class DiceLoss(_Loss):
    """Negative (soft) Dice coefficient loss.

    With ``size_average=True`` the Dice coefficient is computed separately for
    every sample in the batch (samples are delimited by column 3 of
    ``locations``) and averaged. With ``size_average=False`` a single Dice
    coefficient is computed over all elements of the batch. The coefficient is
    negated so that minimizing the loss maximizes overlap.
    """

    def __init__(self, size_average=True, smooth=1.):
        """
        Args:
            size_average: if True, average per-sample Dice coefficients
                (``forward`` then needs ``locations``); otherwise compute one
                coefficient over the whole batch.
            smooth: term added to both numerator and denominator of the Dice
                ratio; keeps it defined when input and target are all zeros.
        """
        super(DiceLoss, self).__init__(size_average)
        # Keep the flag ourselves: newer PyTorch versions of _Loss translate
        # ``size_average`` into ``self.reduction`` and do not store
        # ``self.size_average``, which forward() reads.
        self.size_average = size_average
        self.smooth = smooth

    def forward(self, input, target, locations, weights=None):
        """Return the negative Dice coefficient of ``input`` against ``target``.

        Args:
            input: predictions; compared element-wise with ``target``.
            target: ground truth with the same number of elements as ``input``.
            locations: per-row coordinates; only column 3 (the sample index)
                is read, and only when ``size_average`` is True.
            weights: optional weighting, see ``dice_coef_samples`` and
                ``dice_coef_batch``.
        """
        if self.size_average:
            return -self.dice_coef_samples(input, target, locations, weights)
        return -self.dice_coef_batch(input, target, weights)

    def _dice_coef(self, input, target):
        """Return the smoothed Dice coefficient of two tensors (flattened)."""
        iflat = input.view(-1)
        tflat = target.view(-1)
        intersection = (iflat * tflat).sum()
        return (2. * intersection + self.smooth) / (iflat.sum() + tflat.sum() + self.smooth)

    def dice_coef_batch(self, input, target, weights=None):
        """Return one Dice coefficient over every element of the batch.

        Args:
            input: predictions.
            target: ground truth with the same number of elements as ``input``.
            weights: optional multiplier applied to the coefficient.
        """
        dice = self._dice_coef(input, target)
        if weights is not None:
            # Out-of-place so the autograd result is not mutated.
            dice = dice * weights
        return dice

    def dice_coef_samples(self, input, target, locations, weights=None):
        """Return the mean of per-sample Dice coefficients.

        Rows of ``input``/``target`` are assigned to samples by column 3 of
        ``locations``. Rows of one sample must be contiguous (``groupby`` only
        merges adjacent equal ids), and ids are assumed to run 0..N-1 in
        ascending order, since the sample count is taken as ``last id + 1``.

        Args:
            input: predictions, one row per entry of ``locations``.
            target: ground truth aligned row-wise with ``input``.
            locations: coordinates whose column 3 holds the sample id.
            weights: either a list/tuple with one weight per sample (applied
                to that sample's coefficient), or any other value (scalar or
                tensor) that multiplies the summed coefficients.

        Raises:
            ValueError: if ``locations`` is empty, or per-sample ``weights``
                does not have one entry per sample.
        """
        samples = locations[:, 3]
        if hasattr(samples, "tolist"):
            samples = samples.tolist()
        # Plain ints group cheaply and index Python sequences directly.
        samples = [int(s) for s in samples]
        if not samples:
            raise ValueError("locations is empty; cannot compute a per-sample Dice coefficient")
        num_samples = samples[-1] + 1

        use_sample_weights = isinstance(weights, (list, tuple))
        if use_sample_weights and len(weights) != num_samples:
            raise ValueError("expected %d per-sample weights, got %d" % (num_samples, len(weights)))

        dice = None
        start = 0
        for sample_id, run in groupby(samples):
            end = start + sum(1 for _ in run)
            dice_val = self._dice_coef(input[start:end], target[start:end])
            # The next sample begins right after this one's last row.
            start = end
            if use_sample_weights:
                dice_val = dice_val * weights[sample_id]
            dice = dice_val if dice is None else dice + dice_val

        if weights is not None and not use_sample_weights:
            dice = dice * weights
        return dice / float(num_samples)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment