Skip to content

Instantly share code, notes, and snippets.

@vashineyu
Created February 14, 2019 07:42
Show Gist options
  • Save vashineyu/1e18d65bd71a9e23d81ea0712634faf1 to your computer and use it in GitHub Desktop.
Save vashineyu/1e18d65bd71a9e23d81ea0712634faf1 to your computer and use it in GitHub Desktop.
"""Data generator for model
GetDataset: Get single data with next
Customized dataloader: Compose multiple dataset object together and put them into multi-processing flow
ver1. All patches are taken from a single slide until N patches have been taken.
"""
import cv2
import os
import json
import sys
import numpy as np
sys.path.append("lab_mldl_tools")
from lab_mldl_tools.ndpwrapper import Slide_ndpread as Slide_reader
from matplotlib.path import Path
import matplotlib.patches as patches
import random
from random import shuffle
import multiprocessing as mp
from multiprocessing import Event
import tensorflow as tf
class GetDataset():
def __init__(self,
label_reference, label_list, class_id, n_classes,
f_inputs_preproc, image_size, boundary_criteria,
num_slides_hold=4, num_get_per_slide=1024, patch_stride=64, onehot=True, augmentation=None):
"""Claim a dataset object
Args:
- label_reference (dict): label definition file
- label_list (list): list of index [[slide_index1, target_index], [slide_index2, target_index], ...]
- class_id (int): y_label for this Dataset
- n_classes (int): total numbers of classes
- f_inputs_preproc (func): preprocessing function to the input array [w, h, c]
- image_size (tuple): input size
- boundary_criteria: criteria to check patch, str = ['soft', 'normal', 'hard']
- num_slides_hold (int): numbers of slide hold and random draw from
- num_get_per_slide (int): n patch crop from a slide before switch to next slide
- patch_stride (int): crop stride
- onehot (bool): do onehot encoding to label
- augmentation (func): imgaug aug function
Initalize: Generator of single object
Action: Call the next
"""
self.label_reference = label_reference
self.label_list = label_list
self.class_id = class_id
self.n_classes = n_classes
self.preproc = f_inputs_preproc
self.image_size = image_size
self.boundary_criteria = boundary_criteria
self.num_slides_hold = num_slides_hold
self.num_get_per_slide = num_get_per_slide
self.patch_stride = patch_stride
self.onehot = onehot
self.aug = augmentation
## Init action ##
shuffle(self.label_list)
self.holder = [wrap_default_values(nested_dict()) for i in range(num_slides_hold)]
def __len__(self):
return len(self.label_list) * self.num_get_per_slide
def __getitem__(self):
"""
Initalize object
"""
for i in range(self.num_slides_hold):
if counter_checker(self.holder[i]) == 0:
get_index = np.random.choice(len(self.label_list))
ind, txd = self.label_list[get_index]
slide_name = self.label_reference[ind]['ndpi_file'].replace("nas", "dataset")
self.holder[i]["slide"] = Slide_reader(slide_name=slide_name,
rle_file=None,
show_info=False)
self.holder[i]["croplist"] = self.compute_croplist(self.label_reference, ind, txd,
stride=self.patch_stride,
image_size=self.image_size,
boundary_crit=self.boundary_criteria)
"""Start pickup images"""
pickup_index = np.random.choice(self.num_slides_hold)
idx = self.holder[pickup_index]["counter"]
this_patch = np.array(self.holder[pickup_index]["slide"].get_patch_at_level(level = 0,
coord=(int(self.holder[pickup_index]["croplist"][idx]['x']),
int(self.holder[pickup_index]["croplist"][idx]['y'])),
sz = (self.holder[pickup_index]["croplist"][idx]['h'],
self.holder[pickup_index]["croplist"][idx]['w'])))
if self.aug is not None:
this_patch = self.aug.augment_image(this_patch)
this_patch = this_patch.astype(np.float32)
if self.preproc:
this_patch = self.preproc(this_patch)
this_class = tf.keras.utils.to_categorical(self.class_id, num_classes=self.n_classes)
self.holder[pickup_index]["counter"] += 1
n_patch_to_iteration = min(self.num_get_per_slide, len(self.holder[pickup_index]["croplist"]))
if self.holder[pickup_index]["counter"] == n_patch_to_iteration:
self.holder[pickup_index]["slide"].close()
self.holder[pickup_index] = wrap_default_values(nested_dict())
return this_patch, this_class
def __iter__(self):
return self
def __next__(self):
return self.__getitem__()
@staticmethod
def compute_croplist(label_reference, slide_index, target_index, stride, image_size, boundary_crit):
output = []
for _, segment_key in enumerate(label_reference[slide_index]['labels'][target_index]['data']):
polygon = Path(np.array(label_reference[slide_index]['labels'][target_index]['data'][segment_key][1]['segments']))
expand_radius = 0.
bounding_box = polygon.get_extents()
x_list = np.arange(bounding_box.x0, bounding_box.x1, stride)
y_list = np.arange(bounding_box.y0, bounding_box.y1, stride)
for x_left in x_list:
for y_top in y_list:
p_center = (int(x_left + (image_size/2)),
int(y_top + (image_size/2))
)
if boundary_crit is "soft":
if polygon.contains_point(p_center, radius= 0.):
output.append({'x':x_left, 'y':y_top, 'w':image_size, 'h':image_size})
else:
px, py = p_center
if boundary_crit is "normal":
margin = (image_size / 2) / 4 * 3
p_consider = [[px-margin, py],
[px+margin, py],
[px, py-margin],
[px, py+margin],
[px-margin,py-margin],
[px-margin, py+margin],
[px+margin, py-margin],
[px+margin, py+margin]]
#take_in_decision = polygon.contains_points(p_consider).sum() >= 1
take_in_decision = polygon.contains_points(p_consider).sum() == len(p_consider)
else:
margin = (image_size / 2) / 4 * 3
p_consider = [
[px-margin,py-margin],
[px-margin, py+margin],
[px+margin, py-margin],
[px+margin, py+margin]]
#take_in_decision = polygon.contains_points(p_consider).sum() == len(p_consider)
take_in_decision = polygon.contains_points(p_consider).sum() >= 1
if take_in_decision: # loose criteria
output.append({'x':x_left, 'y':y_top, 'w':image_size, 'h':image_size})
shuffle(output)
return output
class Customized_dataloader():
"""
1. Compose multiple generators together
2. Make this composed generator into multi-processing function
"""
def __init__(self, list_dataset, batch_size_per_dataset=16, queue_size=128, num_workers=0):
"""
Args:
- list_dataset: put generator object as list [gen1, gen2, ...]
- batch_size_per_dataset: bz for each generator (total_batch_size/n_class)
- queue_size: queue size
- num_workers: start n workers to get data
Action: Call with next
"""
self.list_dataset = list_dataset
self.batch_size_per_dataset = batch_size_per_dataset
self.sample_queue = mp.Queue(maxsize = queue_size)
self.jobs = num_workers
self.events = list()
self.workers = list()
for i in range(num_workers):
event = Event()
work = mp.Process(target = enqueue, args = (self.sample_queue, event, self.compose_data))
work.daemon = True
work.start()
self.events.append(event)
self.workers.append(work)
print("workers ready")
def __next__(self):
return self.sample_queue.get()
def compose_data(self):
while True:
imgs, labels = [], []
for z in range(self.batch_size_per_dataset):
data = [next(i) for i in self.list_dataset]
img, label = zip(*data)
imgs.append(np.array(img))
labels.append(np.array(label))
yield np.concatenate(imgs), np.concatenate(labels)
def stop_worker(self):
for t in self.events:
t.set()
for i, t in enumerate(self.workers):
t.join(timeout = 1)
print("all_worker_stop")
# ----- #
def enqueue(queue, stop, gen_func):
gen = gen_func()
while True:
if stop.is_set():
return
queue.put(next(gen))
def enqueue_nowait(queue, stop, gen_func):
gen = gen_func()
while True:
if stop.is_set():
return
#queue.put(next(gen))
queue.put_nowait(next(gen))
from collections import defaultdict
def nested_dict():
return defaultdict(nested_dict)
def wrap_default_values(dict_object):
dict_object["counter"] = 0
dict_object["name"] = str
dict_object["slide"] = object
dict_object["croplist"] = list
return dict_object
def counter_checker(dict_object):
return dict_object["counter"]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment