Skip to content

Instantly share code, notes, and snippets.

@a-maumau
Last active October 2, 2018 10:07
Show Gist options
Save a-maumau/0a3b799790ee6a179105e949f545efd5 to your computer and use it in GitHub Desktop.
# Evaluate the model on `_trainval_loader` with and without DenseCRF
# post-processing: accumulate pixel accuracy plus per-class precision and
# Jaccard index, dump the prediction / CRF prediction / input / ground-truth
# images for every sample into `args.save_dir`, and print the averaged results.
#
# NOTE(review): the original indentation of this snippet was lost; the nesting
# below is reconstructed. It relies on names defined outside this chunk
# (`model`, `to_var`, `dense_crf`, `metric`, `CLASS_NUM`, `MUL_PIXEL`, `args`,
# `epoch`, `load_num`, and the running totals `pix_acc`, `crf_pix_acc`,
# `data_count`, `precision_class`, `jaccard_class`, `crf_precision_class`,
# `crf_jaccard_class`, `data_count_precision`, `data_count_jaccard`).


def _safe_mean(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero."""
    # A class that never yields a score (or an empty loader) would otherwise
    # raise ZeroDivisionError only after the whole evaluation pass has run.
    return total / count if count else float("nan")


# `metric.precision` / `metric.jaccard_index` return per-class lists whose
# length is not fixed (it is counted with len() below), so the CRF scores need
# their own counters instead of being divided by the non-CRF ones.
# NOTE(review): reset here once per pass; keep in sync with wherever the
# other running totals listed above are reset.
crf_data_count_precision = [0] * CLASS_NUM
crf_data_count_jaccard = [0] * CLASS_NUM

for img, mask, original_img in _trainval_loader:
    batch_size = img.shape[0]

    # NOTE(review): `volatile` is ignored since PyTorch 0.4; if `to_var`
    # forwards it to Variable, wrapping inference in torch.no_grad() would
    # avoid building an autograd graph here.
    images = to_var(img, volatile=True)
    # Softmax over dim 1 (presumably the class axis).
    outputs = F.softmax(model(images), dim=1)
    # Probabilities as a NumPy array, computed once and shared by the CRF
    # step and the plain argmax prediction below.
    prob_maps = outputs.detach().cpu().numpy()

    # CRF ################################################################
    # assumes `original_img` holds the un-normalized input images in the
    # layout `dense_crf` expects -- TODO confirm
    raw_images = original_img.detach().numpy().astype(np.uint8)
    crf_probs = np.zeros(prob_maps.shape)
    for i, (image, prob_map) in enumerate(zip(raw_images, prob_maps)):
        crf_probs[i] = dense_crf(image, prob_map)
    # Hard label maps via argmax over axis 1.
    crf_outputs = torch.LongTensor(np.argmax(crf_probs, axis=1))
    outputs = torch.LongTensor(np.argmax(prob_maps, axis=1))
    ######################################################################

    # Both label tensors were built from NumPy, so they already live on CPU.
    pred_labels = outputs.squeeze(1)
    crf_labels = crf_outputs.squeeze(1)

    pix_acc += sum(metric.pixel_accuracy(pred_labels, mask, size_average=False))
    crf_pix_acc += sum(metric.pixel_accuracy(crf_labels, mask, size_average=False))

    precision = metric.precision(
        pred_labels, mask, class_num=CLASS_NUM, size_average=False)
    jaccard_index = metric.jaccard_index(
        pred_labels, mask, class_num=CLASS_NUM, size_average=False)
    crf_precision = metric.precision(
        crf_labels, mask, class_num=CLASS_NUM, size_average=False)
    crf_jaccard_index = metric.jaccard_index(
        crf_labels, mask, class_num=CLASS_NUM, size_average=False)

    for class_id in range(CLASS_NUM):
        key = "class_{}".format(class_id)
        precision_class[class_id] += sum(precision[key])
        jaccard_class[class_id] += sum(jaccard_index[key])
        data_count_precision[class_id] += len(precision[key])
        data_count_jaccard[class_id] += len(jaccard_index[key])

        crf_precision_class[class_id] += sum(crf_precision[key])
        crf_jaccard_class[class_id] += sum(crf_jaccard_index[key])
        crf_data_count_precision[class_id] += len(crf_precision[key])
        crf_data_count_jaccard[class_id] += len(crf_jaccard_index[key])

    # Number of samples seen; the pixel-accuracy sums are divided by this.
    data_count += batch_size

    # Save four images per sample, named "<load_num>_<n>_<suffix>.png".
    for n in range(batch_size):
        prefix = os.path.join(args.save_dir, "{}_{}".format(load_num, n))
        Image.fromarray(np.uint8(outputs[n].squeeze(0).numpy() * MUL_PIXEL)).save(
            "{}_predict.png".format(prefix))
        Image.fromarray(np.uint8(crf_outputs[n].squeeze(0).numpy() * MUL_PIXEL)).save(
            "{}_predict_with_crf.png".format(prefix))
        Image.fromarray(np.uint8(original_img[n].squeeze(0).numpy())).save(
            "{}_input_original.png".format(prefix))
        Image.fromarray(np.uint8(mask[n].squeeze(0).numpy())).save(
            "{}_ground_truth.png".format(prefix))
    load_num += 1

# One value per line: the one-line style seemed to garble the terminal output.
tqdm.write("################")
tqdm.write("[#{}] trainval result".format(epoch+1))
tqdm.write("mean pix acc. : {:1.5f}".format(_safe_mean(pix_acc, data_count)))
for i in range(CLASS_NUM):
    tqdm.write("mean precision : {:1.5f}".format(
        _safe_mean(precision_class[i], data_count_precision[i])))
for i in range(CLASS_NUM):
    tqdm.write("mean jaccard index : {:1.5f}".format(
        _safe_mean(jaccard_class[i], data_count_jaccard[i])))
tqdm.write("crf mean pix acc. : {:1.5f}".format(_safe_mean(crf_pix_acc, data_count)))
for i in range(CLASS_NUM):
    tqdm.write("crf mean precision : {:1.5f}".format(
        _safe_mean(crf_precision_class[i], crf_data_count_precision[i])))
for i in range(CLASS_NUM):
    tqdm.write("crf mean jaccard index : {:1.5f}".format(
        _safe_mean(crf_jaccard_class[i], crf_data_count_jaccard[i])))
tqdm.write("################")
class PairRandomVerticalFlip(object):
    """Vertically flip an image and its paired target together, with probability ``p``.

    A single random draw decides for both images, so the pair always stays
    spatially aligned.
    """

    def __init__(self, p=0.5):
        # Flip probability; compared against random.random() on each call.
        self.p = p

    def __call__(self, img, target_img):
        """Randomly flip ``img`` and ``target_img`` as a pair.

        Args:
            img (PIL Image): Input image to be flipped.
            target_img (PIL Image): Target image paired with ``img``.

        Returns:
            tuple: ``(img, target_img)``, either both flipped or both unchanged.
        """
        should_flip = random.random() < self.p
        if not should_flip:
            return img, target_img
        return F.vflip(img), F.vflip(target_img)

    def __repr__(self):
        return "{}(p={})".format(type(self).__name__, self.p)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment