Created
February 14, 2019 07:42
-
-
Save vashineyu/1e18d65bd71a9e23d81ea0712634faf1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Data generator for model training.
GetDataset: yields a single (patch, label) sample per call to next().
Customized_dataloader: composes multiple dataset objects together and feeds them through a multi-processing pipeline.
ver1. All patches are taken from a single slide until N patches have been taken from it.
"""
import cv2 | |
import os | |
import json | |
import sys | |
import numpy as np | |
sys.path.append("lab_mldl_tools") | |
from lab_mldl_tools.ndpwrapper import Slide_ndpread as Slide_reader | |
from matplotlib.path import Path | |
import matplotlib.patches as patches | |
import random | |
from random import shuffle | |
import multiprocessing as mp | |
from multiprocessing import Event | |
import tensorflow as tf | |
class GetDataset():
    """Endless per-class sampler of (patch, label) pairs cropped from annotated slides."""

    def __init__(self,
                 label_reference, label_list, class_id, n_classes,
                 f_inputs_preproc, image_size, boundary_criteria,
                 num_slides_hold=4, num_get_per_slide=1024, patch_stride=64, onehot=True, augmentation=None):
        """Claim a dataset object.

        Args:
            - label_reference: label definition, indexed by slide index; each entry must provide
              'ndpi_file' (slide path) and 'labels'[target_index]['data'] (annotated polygons)
            - label_list (list): index pairs [[slide_index1, target_index], [slide_index2, target_index], ...];
              it is copied, so the caller's list is not reordered
            - class_id (int): y label for this dataset
            - n_classes (int): total number of classes
            - f_inputs_preproc (func or None): preprocessing applied to each float32 patch array
              (skipped when falsy)
            - image_size (int): side length in pixels of the square patch (used for both width and height)
            - boundary_criteria (str): patch acceptance rule, 'soft', 'normal' or 'hard'
              (see compute_croplist)
            - num_slides_hold (int): number of slides held open at once; every patch is drawn from
              one of them chosen at random
            - num_get_per_slide (int): max patches cropped from a slide before it is replaced
            - patch_stride (int): crop grid stride in pixels
            - onehot (bool): if True the label is one-hot encoded (length n_classes), otherwise class_id
            - augmentation (object or None): imgaug augmenter; its augment_image() is applied to each patch

        Usage: call next(dataset) (or iterate); the dataset never raises StopIteration.
        """
        self.label_reference = label_reference
        # Copy so shuffling below does not reorder the caller's list as a side effect.
        self.label_list = list(label_list)
        self.class_id = class_id
        self.n_classes = n_classes
        self.preproc = f_inputs_preproc
        self.image_size = image_size
        self.boundary_criteria = boundary_criteria
        self.num_slides_hold = num_slides_hold
        self.num_get_per_slide = num_get_per_slide
        self.patch_stride = patch_stride
        self.onehot = onehot
        self.aug = augmentation
        ## Init action ##
        shuffle(self.label_list)
        # One bookkeeping slot per held slide; slots are filled lazily in __getitem__.
        self.holder = [wrap_default_values(nested_dict()) for _ in range(num_slides_hold)]

    def __len__(self):
        """Nominal epoch size: number of (slide, target) entries times patches taken per slide."""
        return len(self.label_list) * self.num_get_per_slide

    def _load_slot(self, slot_index):
        """Open a random slide that has at least one acceptable crop into holder[slot_index].

        Candidates are tried in random order, so the choice is uniform over the entries that
        yield a non-empty crop list. Entries whose crop list is empty are skipped, because
        indexing their crop list would raise IndexError.

        Raises:
            ValueError: if label_list is empty.
            RuntimeError: if no entry of label_list yields any acceptable crop.
        """
        if len(self.label_list) == 0:
            raise ValueError("label_list is empty; nothing to sample from")
        for get_index in np.random.permutation(len(self.label_list)):
            ind, txd = self.label_list[get_index]
            croplist = self.compute_croplist(self.label_reference, ind, txd,
                                             stride=self.patch_stride,
                                             image_size=self.image_size,
                                             boundary_crit=self.boundary_criteria)
            if not croplist:
                continue
            # Rewrites every "nas" substring of the stored path to "dataset"
            # (presumably a storage mount remap - TODO confirm).
            slide_name = self.label_reference[ind]['ndpi_file'].replace("nas", "dataset")
            slot = self.holder[slot_index]
            slot["slide"] = Slide_reader(slide_name=slide_name,
                                         rle_file=None,
                                         show_info=False)
            slot["croplist"] = croplist
            slot["loaded"] = True
            return
        raise RuntimeError("no entry in label_list yields a patch with boundary_criteria=%r"
                           % (self.boundary_criteria,))

    def __getitem__(self):
        """Return the next (patch, label) pair.

        Takes no index: it draws a sample. Empty slots are loaded first. Then one held slide is
        picked at random and its next crop is read at level 0. After augmentation, the float32
        cast and preprocessing, the patch is returned with its label. A slide is closed and its
        slot cleared once min(num_get_per_slide, len(croplist)) patches have been taken from it.
        """
        # Only fill empty slots. A loaded slot that simply has not been picked yet keeps its
        # slide; re-filling it would drop the open slide without closing it.
        for i in range(self.num_slides_hold):
            if not self.holder[i].get("loaded", False):
                self._load_slot(i)

        """Start pickup images"""
        pickup_index = np.random.choice(self.num_slides_hold)
        slot = self.holder[pickup_index]
        idx = slot["counter"]
        crop = slot["croplist"][idx]
        this_patch = np.array(slot["slide"].get_patch_at_level(level=0,
                                                               coord=(int(crop['x']), int(crop['y'])),
                                                               sz=(crop['h'], crop['w'])))
        if self.aug is not None:
            this_patch = self.aug.augment_image(this_patch)
        this_patch = this_patch.astype(np.float32)
        if self.preproc:
            this_patch = self.preproc(this_patch)
        if self.onehot:
            this_class = tf.keras.utils.to_categorical(self.class_id, num_classes=self.n_classes)
        else:
            this_class = self.class_id

        slot["counter"] += 1
        n_patch_to_iteration = min(self.num_get_per_slide, len(slot["croplist"]))
        if slot["counter"] >= n_patch_to_iteration:
            # Slide exhausted: release it and reset the slot so it is refilled on the next call.
            slot["slide"].close()
            self.holder[pickup_index] = wrap_default_values(nested_dict())
        return this_patch, this_class

    def __iter__(self):
        return self

    def __next__(self):
        return self.__getitem__()

    @staticmethod
    def _accept_patch(polygon, p_center, image_size, boundary_crit):
        """Return True if the square patch centred at p_center passes boundary_crit.

        'soft' tests only the centre. The other criteria probe points offset by
        3/8 * image_size from the centre (see compute_croplist).
        """
        if boundary_crit == "soft":
            return polygon.contains_point(p_center, radius=0.)
        px, py = p_center
        margin = (image_size / 2) / 4 * 3
        corners = [[px - margin, py - margin],
                   [px - margin, py + margin],
                   [px + margin, py - margin],
                   [px + margin, py + margin]]
        if boundary_crit == "normal":
            edges = [[px - margin, py],
                     [px + margin, py],
                     [px, py - margin],
                     [px, py + margin]]
            probes = edges + corners
            # All 8 probes must fall inside the polygon.
            return polygon.contains_points(probes).sum() == len(probes)
        # Any other value ('hard'): at least one of the 4 corner probes inside (loose criteria).
        return polygon.contains_points(corners).sum() >= 1

    @staticmethod
    def compute_croplist(label_reference, slide_index, target_index, stride, image_size, boundary_crit):
        """Return a shuffled list of crop boxes {'x', 'y', 'w', 'h'} inside the annotated region.

        For every polygon in label_reference[slide_index]['labels'][target_index]['data'], a grid
        of top-left corners with spacing `stride` is laid over the polygon's bounding box. The
        square patch of side `image_size` at each grid point is kept according to boundary_crit:
            - 'soft':   the patch centre lies inside the polygon
            - 'normal': all 8 probe points (4 axis offsets + 4 diagonal offsets, each at
                        3/8 * image_size from the centre) lie inside the polygon
            - otherwise (e.g. 'hard'): at least one of the 4 diagonal probe points lies inside
        NOTE(review): as written, 'hard' is looser than 'normal' (any-of-4 vs all-of-8); the
        original file has the stricter variants commented out - confirm which is intended.

        Assumes each ['data'][segment_key][1]['segments'] is a sequence of (x, y) polygon
        vertices - TODO confirm against the label file format.
        """
        output = []
        segments = label_reference[slide_index]['labels'][target_index]['data']
        half = image_size / 2
        for segment_key in segments:
            polygon = Path(np.array(segments[segment_key][1]['segments']))
            bounding_box = polygon.get_extents()
            x_list = np.arange(bounding_box.x0, bounding_box.x1, stride)
            y_list = np.arange(bounding_box.y0, bounding_box.y1, stride)
            for x_left in x_list:
                for y_top in y_list:
                    p_center = (int(x_left + half), int(y_top + half))
                    if GetDataset._accept_patch(polygon, p_center, image_size, boundary_crit):
                        output.append({'x': x_left, 'y': y_top, 'w': image_size, 'h': image_size})
        shuffle(output)
        return output
class Customized_dataloader():
    """Compose several per-class datasets into batches, optionally prefetched by worker processes.

    1. Compose multiple generators together (one sample from each dataset per step).
    2. Run the composed generator in `num_workers` background processes that feed a shared queue.
       With num_workers <= 0, batches are produced synchronously in the calling process.
    """
    def __init__(self, list_dataset, batch_size_per_dataset=16, queue_size=128, num_workers=0):
        """
        Args:
            - list_dataset: iterators as a list [gen1, gen2, ...]; each next() must return (image, label)
            - batch_size_per_dataset: samples drawn from each dataset per batch
              (total batch size = batch_size_per_dataset * len(list_dataset))
            - queue_size: max number of prefetched batches held in the queue
            - num_workers: number of worker processes; 0 means batches are built on demand in the caller
        Usage: call next(loader) or iterate over the loader.
        """
        self.list_dataset = list_dataset
        self.batch_size_per_dataset = batch_size_per_dataset
        self.sample_queue = mp.Queue(maxsize=queue_size)
        self.jobs = num_workers
        self.events = list()
        self.workers = list()
        # Without workers nothing would ever fill the queue and __next__ would block forever,
        # so fall back to an in-process generator. It is only created when there are no workers,
        # which keeps it out of what gets sent to worker processes.
        self._local_gen = self.compose_data() if num_workers <= 0 else None
        for _ in range(num_workers):
            event = Event()
            work = mp.Process(target=enqueue, args=(self.sample_queue, event, self.compose_data))
            work.daemon = True
            work.start()
            self.events.append(event)
            self.workers.append(work)
        print("workers ready")

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next (images, labels) batch."""
        if self._local_gen is not None:
            return next(self._local_gen)
        return self.sample_queue.get()

    def compose_data(self):
        """Infinite generator of (images, labels) batches.

        Each step takes one sample from every dataset, and this is repeated
        batch_size_per_dataset times. Samples are therefore interleaved as
        [ds0, ds1, ..., ds0, ds1, ...] along the first axis. All samples must share a shape
        so they can be stacked.
        """
        while True:
            imgs, labels = [], []
            for _ in range(self.batch_size_per_dataset):
                data = [next(i) for i in self.list_dataset]
                img, label = zip(*data)
                imgs.append(np.array(img))
                labels.append(np.array(label))
            yield np.concatenate(imgs), np.concatenate(labels)

    def stop_worker(self):
        """Signal every worker to stop, wait up to 1 s for each, and terminate any still alive."""
        for t in self.events:
            t.set()
        for t in self.workers:
            t.join(timeout=1)
            if t.is_alive():
                # A worker can still be busy (e.g. inside a dataset or blocked on the queue);
                # terminate it so it does not linger after shutdown.
                t.terminate()
        print("all_worker_stop")
# ----- # | |
def enqueue(queue, stop, gen_func, poll_interval=0.5):
    """Worker loop: push items produced by gen_func() into queue until stop is set.

    Args:
        queue: queue whose put(item, timeout=...) raises queue.Full on timeout
               (multiprocessing.Queue or queue.Queue)
        stop: event-like object with is_set(); the loop exits once it is set
        gen_func: zero-argument callable returning an iterator of items
        poll_interval: seconds to wait on a full queue before re-checking stop

    Returns once stop is set or the iterator is exhausted.
    """
    # Local import: the parameter named `queue` shadows the stdlib module name in this scope.
    from queue import Full
    gen = gen_func()
    while not stop.is_set():
        try:
            item = next(gen)
        except StopIteration:
            return
        # Bounded waits, so a full queue cannot keep the worker from noticing `stop`.
        while not stop.is_set():
            try:
                queue.put(item, timeout=poll_interval)
                break
            except Full:
                pass
def enqueue_nowait(queue, stop, gen_func, retry_interval=0.01):
    """Worker loop using non-blocking puts: push items from gen_func() into queue until stop is set.

    When the queue is full, put_nowait raises queue.Full. The pending item is kept and retried
    after sleeping retry_interval seconds, and stop is re-checked between attempts, so the worker
    neither dies nor drops data.

    Args:
        queue: queue with put_nowait(item) raising queue.Full when full
        stop: event-like object with is_set()
        gen_func: zero-argument callable returning an iterator of items
        retry_interval: seconds to sleep before retrying a put on a full queue

    Returns once stop is set or the iterator is exhausted.
    """
    # Local imports: the parameter named `queue` shadows the stdlib module name in this scope.
    import time
    from queue import Full
    gen = gen_func()
    while not stop.is_set():
        try:
            item = next(gen)
        except StopIteration:
            return
        while not stop.is_set():
            try:
                queue.put_nowait(item)
                break
            except Full:
                time.sleep(retry_interval)
from collections import defaultdict | |
def nested_dict():
    """Build an auto-vivifying dictionary.

    Looking up a missing key inserts and returns another nested_dict, so deep paths such as
    d["a"]["b"] can be assigned without creating the intermediate levels first. The factory is
    the named function itself (not a lambda), which keeps the result picklable.
    """
    factory = nested_dict
    tree = defaultdict(factory)
    return tree
def wrap_default_values(dict_object):
    """Reset a slot mapping's bookkeeping fields in place and return the same mapping.

    "counter" is set to 0. The type objects str, object and list are stored as placeholders
    under "name", "slide" and "croplist".
    """
    defaults = (("counter", 0), ("name", str), ("slide", object), ("croplist", list))
    for key, value in defaults:
        dict_object[key] = value
    return dict_object
def counter_checker(dict_object):
    """Return the value stored under the "counter" key of a slot mapping."""
    current_count = dict_object["counter"]
    return current_count
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment