Skip to content

Instantly share code, notes, and snippets.

@jimmy15923
Last active February 25, 2019 07:24
Show Gist options
  • Save jimmy15923/786d0ab6692cc229104ee9855b6cfb1a to your computer and use it in GitHub Desktop.
Save jimmy15923/786d0ab6692cc229104ee9855b6cfb1a to your computer and use it in GitHub Desktop.
import sys
sys.path.append("/mnt/deep-learning/usr/jimmy15923/keras_rcnn_family")
import json
import numpy as np
from collections import Counter
from scipy.ndimage.measurements import label
# NOTE: the original import order is kept on purpose -- the wildcard import
# below may rebind any name imported above it, so names that must win
# (glob, os, plt, utils, ...) are imported after it.
from slide_reader import *
import glob
import os
import matplotlib.pyplot as plt
from mrcnn import utils
from yacs_config import get_cfg_defaults
# Slide metadata: a list of records, each read below for its 'name' and
# 'ndpi_file' entries.
with open("/mnt/dataset/CGMHRENAL/CGMHRENAL_labels.json") as labels_file:
    labels = json.load(labels_file)

# Per-slide instance coordinates, keyed by slide name.
with open("/mnt/dataset/CGMHRENAL/coordinate_list.json") as coords_file:
    coord_list = json.load(coords_file)

# Slide name -> value stored under 'ndpi_file' for that slide.
image_path_dict = {entry['name']: entry['ndpi_file'] for entry in labels}
############################################################
# Train Test Split
############################################################
_MASK_DIR = "/mnt/dataset/CGMHRENAL/masks"
# Mask files are named "<slide id>_mask.json"; the slide id is the file name
# with this suffix removed.
_MASK_SUFFIX = "_mask.json"


def _select_slide_ids(mask_paths, year_tag, coords, stain_tag="HE"):
    """Return the slide ids among *mask_paths* that belong to one split.

    A slide is kept when its id contains both *year_tag* and *stain_tag* and
    it has at least one entry in *coords*.

    Args:
        mask_paths: paths of mask files; names that do not end with
            ``_mask.json`` are ignored.
        year_tag: substring selecting the split (e.g. "S17" or "S18").
        coords: mapping from slide id to its list of instance coordinates.
        stain_tag: substring the slide id must contain. "HE" also matches
            ids containing "HE1".

    Returns:
        List of slide ids, in the order of *mask_paths*.

    Raises:
        KeyError: if a slide matching both tags has no entry in *coords*.
    """
    selected = []
    for path in mask_paths:
        file_name = os.path.basename(path)
        if not file_name.endswith(_MASK_SUFFIX):
            continue
        slide_id = file_name[:-len(_MASK_SUFFIX)]
        # `and` short-circuits, so coords is only consulted for slides that
        # already match the split and stain tags.
        if year_tag in slide_id and stain_tag in slide_id and len(coords[slide_id]) > 0:
            selected.append(slide_id)
    return selected


# Scan the mask directory once; sorting makes the split order (and therefore
# the dataset's image indices) independent of filesystem listing order.
_mask_paths = sorted(glob.glob(os.path.join(_MASK_DIR, "*.json")))
TRAIN_SLIDE_IDS = _select_slide_ids(_mask_paths, "S17", coord_list)
VAL_SLIDE_IDS = _select_slide_ids(_mask_paths, "S18", coord_list)
print("Train:, ", len(TRAIN_SLIDE_IDS), ", Test: ", len(VAL_SLIDE_IDS))
############################################################
# Dataset
############################################################
class GloDataset(utils.Dataset):
    """Dataset class for Mask R-CNN training on glomerulus patches.

    Every registered "image" is a whole slide. Each call to ``load_image``
    cuts one patch around the next annotated instance of that slide and also
    stores the matching mask patch, which ``load_mask`` then splits into
    per-instance masks.

    How to use GloDataset:
        dataset_train = GloDataset(config=config)
        dataset_train.load_dataset("train")
        dataset_train.prepare()

    Args:
        config: yacs config; ``IMAGE_SHAPE`` and ``LEVEL`` are read from it.

    Attributes:
        patch_size: side length of the square patch (``config.IMAGE_SHAPE[0]``).
        slide_level: slide level the patches are fetched from.
        random_shift: whether patches are randomly shifted; switched on by
            ``load_dataset("train")``.
        id_instance_counter: per-slide call counter. A slide has many
            instances, so the counter makes successive ``load_image`` calls
            walk through them in order without repeating until it wraps.
    """

    def __init__(self, config):
        self.config = config
        self.patch_size = config.IMAGE_SHAPE[0]
        self.slide_level = config.LEVEL
        self.random_shift = False
        self.id_instance_counter = Counter()
        super(GloDataset, self).__init__()

    def load_dataset(self, subset):
        """Register the glomerulus class and the slides of one subset.

        Args:
            subset: "val" registers VAL_SLIDE_IDS; any other value registers
                TRAIN_SLIDE_IDS. Random shifting is enabled only for "train".
        """
        self.subset = subset
        if subset == "train":
            self.random_shift = True

        # Add classes. We have one class.
        # Naming the dataset glomerulus, and the class G
        self.add_class("glomerulus", 1, "G")

        # Add image infos (path and id)
        image_ids = VAL_SLIDE_IDS if subset == "val" else TRAIN_SLIDE_IDS
        for image_id in image_ids:
            self.add_image(
                "glomerulus",
                image_id=image_id,
                path=image_path_dict[image_id])

    def load_image(self, slide_id, center_shift=True):
        """Return the patch around the next instance of a slide.

        The matching mask patch is stored in ``self.mask`` for a following
        ``load_mask`` call, and the image patch in ``self.image``.

        Args:
            slide_id: integer index into ``self.image_info``.
            center_shift: if True, treat the stored coordinate as a box
                ``(x0, y0, x1, y1)`` and position the patch so that the box
                centre is the patch centre.

        Returns:
            The patch returned by the slide reader's ``get_patch_at_level``.
        """
        # Get slide name from image_info[int]['id']
        slide_name = self.image_info[slide_id]['id']
        instances = coord_list[slide_name]

        # The counter is incremented before use, so the first call on a slide
        # picks instance 1 (instance 0 for a single-instance slide) and the
        # index then wraps around the number of instances.
        self.id_instance_counter[slide_id] += 1
        index = self.id_instance_counter[slide_id] % len(instances)

        # Copy into a fresh array so the clipping below can never write back
        # into the shared coord_list entry.
        coord = np.array(instances[index][0])
        if center_shift:
            x_center = int((coord[0] + coord[2]) / 2)
            y_center = int((coord[1] + coord[3]) / 2)
            # Half a patch, scaled by 2**level -- presumably converting the
            # patch size to level-0 pixels; TODO confirm against the reader.
            half_patch = int(self.patch_size * (2 ** self.slide_level) / 2)
            coord = np.array((x_center, y_center)) - half_patch

        sli_reader = Slide_ndpread(self.image_info[slide_id]['path'],
                                   rle_file="/mnt/dataset/CGMHRENAL/masks/{}_mask.json".format(slide_name),
                                   show_info=False)
        try:
            if self.random_shift:
                # NOTE(review): a single scalar offset is drawn, so x and y
                # are shifted by the same amount.
                max_shift = int(self.patch_size / 2.1)
                coord = coord + np.random.randint(-max_shift, max_shift) * (2 ** self.slide_level)
            coord[0] = np.clip(coord[0], 0, sli_reader.slide_info['Width'])
            coord[1] = np.clip(coord[1], 0, sli_reader.slide_info['Height'])
            patch_shape = (self.patch_size, self.patch_size)
            self.image = sli_reader.get_patch_at_level(coord, sz=patch_shape,
                                                       level=self.slide_level)
            self.mask = sli_reader.get_mask_at_level(coord, sz=patch_shape,
                                                     level=self.slide_level)
        finally:
            # Release the slide handle even if patch extraction fails.
            sli_reader.close()
        return self.image

    def load_mask(self, slide_id):
        """Split the mask of the most recently loaded patch into instances.

        Must be called after ``load_image``: it reads ``self.mask`` and
        ignores *slide_id*.

        Returns:
            mask: bool array of shape ``(IMAGE_SHAPE[0], IMAGE_SHAPE[1], n)``
                with one channel per connected component.
            class_ids: int32 array of ``n`` ones (the single class "G").
        """
        # Use scipy.ndimage.label with 8-connectivity to check whether the
        # patch contains several instances.
        structure = np.ones((3, 3), dtype=int)
        labeled, ncomp = label(self.mask, structure)
        # Read the shape from the instance config rather than a module global,
        # so the class also works when this file is imported.
        height, width = self.config.IMAGE_SHAPE[0], self.config.IMAGE_SHAPE[1]
        mask = np.zeros(shape=(height, width, ncomp), dtype=bool)
        for i in range(ncomp):
            mask[:, :, i] = labeled == (i + 1)
        # Integer dtype so the ids can be used as indices downstream.
        return mask, np.ones(mask.shape[2], dtype=np.int32)
if __name__ == "__main__":
    # Demo: show the first few patches and masks of one training slide.
    # Get config
    config = get_cfg_defaults()
    # Prepare our dataset
    dataset_train = GloDataset(config)
    dataset_train.load_dataset("train")
    dataset_train.prepare()
    # Demo with slide_id=1
    slide_id = 1
    # Number of instances for this slide
    n_instance = len(coord_list[dataset_train.image_info[slide_id]['id']])
    for idx in range(min(n_instance, 5)):
        # load_image takes the slide index; each call advances to the next
        # instance of that slide, so the same slide_id is passed every time.
        img = dataset_train.load_image(slide_id)
        mask, ids = dataset_train.load_mask(slide_id)
        print(idx)
        plt.imshow(img)
        plt.show()
        # A patch may contain no labelled component, leaving zero channels.
        if mask.shape[2] > 0:
            plt.imshow(mask[:, :, 0])
            plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment