Skip to content

Instantly share code, notes, and snippets.

@burrussmp
Created May 5, 2020 20:49
Show Gist options
  • Save burrussmp/baef15f13affff1cd78883bae7ecf923 to your computer and use it in GitHub Desktop.
Save burrussmp/baef15f13affff1cd78883bae7ecf923 to your computer and use it in GitHub Desktop.
Test a segmentation model (U-Net) using the Dice coefficient and report the average, median, and standard deviation of the per-sample Dice scores. The closer the average Dice score is to 1, the better the model.
def dice_loss(input, target, smooth=1.0):
    """Return the smoothed Dice loss (1 - Dice coefficient) between two tensors.

    Both tensors are flattened before comparison, so they only need the same
    number of elements. Presumably they hold mask values/probabilities in
    [0, 1] — TODO confirm with callers.

    Args:
        input: Predicted tensor.
        target: Ground-truth tensor.
        smooth: Additive smoothing term applied to numerator and denominator.
            It avoids division by zero and makes two empty masks score a
            loss of 0. Defaults to 1.0, the value previously hard-coded.

    Returns:
        A 0-d tensor equal to
        ``1 - (2 * |input * target| + smooth) / (|input| + |target| + smooth)``.
    """
    # reshape (unlike view) also works on non-contiguous tensors, e.g. after a
    # transpose or slice, copying only when it has to.
    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    return 1 - ((2. * intersection + smooth) /
                (iflat.sum() + tflat.sum() + smooth))
# Evaluate the trained U-Net on the 'test' split and report Dice statistics.
testing_list = get_list('test')
test_dataset = ListDataset(testing_list, load_data)
# batch_size=1, so each entry of dice_score below is one sample's score.
test_loader = DataLoader(dataset=test_dataset, batch_size=1, num_workers=1)

# load model
pathToModel = os.path.join(BASEDIR, 'weights2.pt')
print('Loading model')
model = UNet2D()
# Put the model on the same device the inputs are moved to in the loop below
# (the original forced .cuda() while inputs used `device`). map_location lets
# a checkpoint saved on another device load cleanly.
model.to(device)
model.load_state_dict(torch.load(pathToModel, map_location=device))
model.eval()

# test
num_batches = len(test_loader)
dice_score = np.zeros(num_batches)
with torch.no_grad():
    for batch_idx, loaded in enumerate(test_loader):
        data = loaded['src'].to(device)
        target = loaded['target'].to(device)
        output = model(data.float())  # collect the outputs
        # NOTE(review): this is a "soft" Dice computed on the raw network
        # output. If a hard-label Dice was intended, compare
        # output.argmax(1) against target.argmax(1) instead — confirm.
        dice_score[batch_idx] = 1.0 - dice_loss(output, target).item()
        if batch_idx % 50 == 0:
            # Scale the fraction by 100 so the '%' suffix is accurate.
            print('Progress: {:.2f}%'.format(100.0 * batch_idx / num_batches))
print('Average Dice Score', np.mean(dice_score))
print('Std Dice Score', np.std(dice_score))
print('Median Dice Score', np.median(dice_score))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment