Created
October 3, 2018 14:41
-
-
Save thorstenwagner/8033f43b99d1d3a1a6a31b054d91e7fc to your computer and use it in GitHub Desktop.
PatchwiseBatchGenerator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from . import imagereader | |
from .augmentation import Augmentation | |
import cv2 | |
import copy | |
from .utils import BoundBox, bbox_iou | |
import numpy as np | |
from keras.utils import Sequence | |
import threading | |
import multiprocessing | |
from multiprocessing import sharedctypes | |
import time | |
class PatchwiseBatchGenerator(Sequence):
    """Batch generator that draws training samples patch-wise from images.

    Every image is split into a grid of patches (tiles). ``patch_assigments``
    is a 2D array whose non-zero entries mark the patches that may be drawn.
    It is replicated once per image into ``patch_assigments_3d``; each drawn
    (patch, image) combination is set to zero there so it is not drawn again
    until all combinations are used up (or the epoch ends), at which point
    the array is restored from ``patch_assigments``.

    Each training instance in ``images`` is a dict with the keys
    ``'filename'`` and ``'object'``; every object is a dict with the keys
    ``'xmin'``, ``'xmax'``, ``'ymin'``, ``'ymax'`` and ``'name'``.

    ``config`` is a dict providing ``'ANCHORS'``, ``'BATCH_SIZE'``,
    ``'IMAGE_W'``, ``'IMAGE_H'``, ``'GRID_W'``, ``'GRID_H'``, ``'BOX'``,
    ``'CLASS'``, ``'LABELS'`` and ``'TRUE_BOX_BUFFER'``.
    """

    def __init__(self, images,
                 config,
                 patch_assigments,
                 shuffle=True,
                 jitter=True,
                 norm=None,
                 overlap=100,
                 name=None,
                 train_times=1):
        """Set up the generator.

        :param images: List of training instances (see class docstring).
        :param config: Configuration dict (see class docstring).
        :param patch_assigments: 2D numpy array; non-zero entries mark the
            patches that are used for training.
        :param shuffle: If True, ``images`` is shuffled in place now and at
            the end of every epoch.
        :param jitter: If True, random scaling, translation, flipping and
            the ``Augmentation`` pipeline are applied to every patch.
        :param norm: Optional callable applied to each image before it is
            written into the batch. If None, the image is used as is.
        :param overlap: Passed to ``imagereader.get_tile_coordinates``.
        :param name: Optional name, stored as ``self.myname``.
        :param train_times: Multiplier for the reported number of batches.
        """
        self.myname = name
        self.images = images
        self.imgdepth = 3
        self.config = config
        self.shuffle = shuffle
        self.jitter = jitter
        self.norm = norm
        self.counter = 0
        self.patch_assigments = patch_assigments
        self.overlap = overlap
        # ANCHORS is a flat list of (width, height) pairs.
        anchor_sizes = config['ANCHORS']
        self.anchors = [BoundBox(0, 0, anchor_sizes[2 * i], anchor_sizes[2 * i + 1])
                        for i in range(len(anchor_sizes) // 2)]
        self.train_times = train_times

        # Keep the per-image bookkeeping in a multiprocessing RawArray and
        # access it through a numpy view.
        # NOTE(review): presumably so that worker processes see each others'
        # updates - confirm for the start method in use.
        assigments_ctypes = np.ctypeslib.as_ctypes(self._tiled_assigments())
        raw_array = sharedctypes.RawArray(assigments_ctypes._type_, assigments_ctypes)
        self.patch_assigments_3d = np.ctypeslib.as_array(raw_array)
        self.lock = multiprocessing.Lock()

        # Evaluate image depth: an image without a channel axis, or one whose
        # three channels are identical, is treated as single channel.
        img_first = imagereader.image_read(self.images[0]['filename'])
        if len(img_first.shape) == 2:
            self.imgdepth = 1
        elif np.all(img_first[:, :, 0] == img_first[:, :, 1]) and \
                np.all(img_first[:, :, 0] == img_first[:, :, 2]):
            self.imgdepth = 1

        if shuffle:
            np.random.shuffle(self.images)

    def __len__(self):
        """Return the number of batches per epoch.

        This is the number of (patch, image) combinations divided by the
        batch size (rounded up), multiplied by ``train_times``.
        """
        num_patches = float(np.sum(self.patch_assigments) * len(self.images))
        num_batches = int(np.ceil(num_patches / self.config['BATCH_SIZE']))
        return num_batches * self.train_times

    def __getitem__(self, idx):
        """Assemble one training batch.

        :param idx: Batch index requested by the caller. It is not used:
            patches are drawn at random from those not yet used.
        :return: ``([x_batch, b_batch], y_batch)`` where ``x_batch`` holds
            the input images, ``b_batch`` the ground-truth boxes and
            ``y_batch`` the desired network output.
        """
        # Re-seed from fresh entropy on every call.
        # NOTE(review): presumably to avoid identical random streams in
        # forked worker processes - confirm.
        np.random.seed()

        batch_size = self.config['BATCH_SIZE']
        # Input images, laid out as (batch, height, width, channels).
        x_batch = np.zeros(
            (batch_size, self.config['IMAGE_H'], self.config['IMAGE_W'], self.imgdepth))
        # Up to TRUE_BOX_BUFFER ground-truth boxes per instance.
        b_batch = np.zeros((batch_size, 1, 1, 1, self.config['TRUE_BOX_BUFFER'], 4))
        # Desired network output.
        y_batch = np.zeros((batch_size, self.config['GRID_H'], self.config['GRID_W'],
                            self.config['BOX'], 4 + 1 + self.config['CLASS']))

        for instance_count in range(batch_size):
            patch_x, patch_y, current_image = self._draw_patch()

            imgw, imgh = imagereader.read_width_height(self.images[current_image]['filename'])
            tile = imagereader.get_tile_coordinates(imgw=imgw, imgh=imgh,
                                                    num_patches=self.patch_assigments.shape[0],
                                                    patchxy=(patch_x, patch_y),
                                                    overlap=self.overlap)

            # Apply augmentation
            img, all_objs = self.aug_image(self.images[current_image],
                                           jitter=self.jitter, region=tile)
            if self.imgdepth == 1:
                img = img[:, :, np.newaxis]

            # Add to training batch
            self._encode_objects(all_objs, instance_count, b_batch, y_batch)

            # Assign input image to x_batch. Without a normalisation callable
            # the image is stored unchanged instead of leaving the slot zero.
            if self.norm is not None:
                x_batch[instance_count] = self.norm(img)
            else:
                x_batch[instance_count] = img

        self.counter += 1
        return [x_batch, b_batch], y_batch

    def _draw_patch(self):
        """Randomly pick an unused (patch, image) combination and mark it used.

        :return: Tuple ``(patch_x, patch_y, image_index)``.
        :raises ValueError: If no patch is available even after a reset,
            i.e. ``patch_assigments`` has no non-zero entry.
        """
        with self.lock:
            # If no unselected patches are available but instances are still
            # needed, start over with the full set of assignments.
            if np.sum(self.patch_assigments_3d) < 1:
                self.reset_assigments(self.patch_assigments_3d)

            rows, cols, image_indices = np.nonzero(self.patch_assigments_3d)
            if len(rows) == 0:
                raise ValueError("patch_assigments does not select any patch")

            choice = np.random.randint(len(rows))
            patch_y = rows[choice]
            patch_x = cols[choice]
            current_image = image_indices[choice]
            self.patch_assigments_3d[patch_y, patch_x, current_image] = 0
        return patch_x, patch_y, current_image

    def _encode_objects(self, all_objs, instance_count, b_batch, y_batch):
        """Write the boxes of one training instance into ``b_batch``/``y_batch``.

        Boxes with non-positive width or height, or with a label that is not
        in ``config['LABELS']``, are skipped.
        """
        # Size of one grid cell in pixels.
        cell_w = float(self.config['IMAGE_W']) / self.config['GRID_W']
        cell_h = float(self.config['IMAGE_H']) / self.config['GRID_H']

        true_box_index = 0
        for obj in all_objs:
            is_valid = obj['xmax'] > obj['xmin'] and \
                obj['ymax'] > obj['ymin'] and \
                obj['name'] in self.config['LABELS']
            if not is_valid:
                continue

            # Box centre in units of grid cells.
            center_x = .5 * (obj['xmin'] + obj['xmax']) / cell_w
            center_y = .5 * (obj['ymin'] + obj['ymax']) / cell_h
            grid_x = int(np.floor(center_x))
            grid_y = int(np.floor(center_y))
            if not (grid_x < self.config['GRID_W'] and grid_y < self.config['GRID_H']):
                continue

            obj_indx = self.config['LABELS'].index(obj['name'])
            center_w = (obj['xmax'] - obj['xmin']) / cell_w  # unit: grid cell
            center_h = (obj['ymax'] - obj['ymin']) / cell_h  # unit: grid cell
            box = [center_x, center_y, center_w, center_h]

            # Find the anchor that best predicts this box: compare shapes
            # only, with both boxes anchored at the origin.
            shifted_box = BoundBox(0, 0, center_w, center_h)
            best_anchor = -1
            max_iou = -1
            for anchor_index, anchor in enumerate(self.anchors):
                iou = bbox_iou(shifted_box, anchor)
                if max_iou < iou:
                    best_anchor = anchor_index
                    max_iou = iou

            # Assign ground truth x, y, w, h, confidence and class probs.
            y_batch[instance_count, grid_y, grid_x, best_anchor, 0:4] = box
            y_batch[instance_count, grid_y, grid_x, best_anchor, 4] = 1.
            y_batch[instance_count, grid_y, grid_x, best_anchor, 5 + obj_indx] = 1

            # Assign the true box to b_batch; the buffer wraps around when
            # there are more boxes than TRUE_BOX_BUFFER.
            b_batch[instance_count, 0, 0, 0, true_box_index] = box
            true_box_index = (true_box_index + 1) % self.config['TRUE_BOX_BUFFER']

    def _tiled_assigments(self):
        """Return ``patch_assigments`` repeated once per image along axis 2."""
        return np.repeat(self.patch_assigments[:, :, np.newaxis], len(self.images), axis=2)

    def reset_assigments(self, array):
        """Overwrite ``array`` in place with the initial patch assignments.

        :param array: Array of shape ``patch_assigments.shape + (len(images),)``.
        """
        np.copyto(array, self._tiled_assigments())

    def on_epoch_end(self):
        """Reshuffle the images (if enabled) and make all patches available."""
        with self.lock:
            if self.shuffle:
                np.random.shuffle(self.images)
            self.reset_assigments(self.patch_assigments_3d)
            self.counter = 0

    def aug_image(self, train_instance, jitter, region=None):
        """Read an image (or a region of it), augment it and adapt its boxes.

        :param train_instance: Dict with the keys ``'filename'`` and
            ``'object'``.
        :param jitter: If True, apply random scaling, translation, flipping
            and the ``Augmentation`` pipeline.
        :param region: Optional pair ``(x_range, y_range)`` whose elements
            provide ``start`` and ``stop``. If given, only this region is
            read and only objects whose centre lies sufficiently inside it
            are kept.
        :return: Tuple ``(image, objects)``. The image is resized to
            ``IMAGE_W`` x ``IMAGE_H``; the objects are deep copies with
            coordinates transformed into the resized image.
        """
        image = imagereader.image_read(train_instance['filename'], region)
        h = image.shape[0]
        w = image.shape[1]
        all_objs = copy.deepcopy(train_instance['object'])

        mirror_x = False  # mirror left <-> right
        mirror_y = False  # mirror top <-> bottom
        if jitter:
            # Scale the image by a random factor in [1.0, 1.1).
            scale = np.random.uniform() / 10. + 1.
            image = cv2.resize(image, (0, 0), fx=scale, fy=scale)

            # Translate the image: crop a window of the original size at a
            # random offset inside the enlarged image.
            max_offx = (scale - 1.) * w
            max_offy = (scale - 1.) * h
            offx = int(np.random.uniform() * max_offx)
            offy = int(np.random.uniform() * max_offy)
            image = image[offy: (offy + h), offx: (offx + w)]

            # Flip the image. 0: no flip, 1: mirror x, 2: mirror y, 3: both.
            # The upper bound of randint is exclusive, so it has to be 4 to
            # make all four cases reachable.
            flip_selection = np.random.randint(0, 4)
            mirror_x = flip_selection in (1, 3)
            mirror_y = flip_selection in (2, 3)
            if mirror_x and mirror_y:
                image = cv2.flip(image, -1)
            elif mirror_x:
                image = cv2.flip(image, 1)
            elif mirror_y:
                image = cv2.flip(image, 0)

            is_grey = (np.issubdtype(image.dtype, np.int8) or
                       np.issubdtype(image.dtype, np.uint8))
            aug = Augmentation(is_grey)
            image = aug.image_augmentation(image)

        # Resize the image to standard size. cv2.resize expects the target
        # size as (width, height).
        image = cv2.resize(image, (self.config['IMAGE_W'], self.config['IMAGE_H']))
        if self.imgdepth == 3:
            # Reverse the channel order.
            image = image[:, :, ::-1]

        region_off_x = 0
        region_off_y = 0
        if region is not None:
            region_off_x = region[0].start
            region_off_y = region[1].start

        # Fix objects' position and size and check region.
        obj_is_region = []
        for obj in all_objs:
            if region is not None:
                bwidth = (obj['xmax'] - obj['xmin'])
                bheight = (obj['ymax'] - obj['ymin'])
                obj_center_x = int(obj['xmax'] - bwidth / 2)
                obj_center_y = int(obj['ymax'] - bheight / 2)
                # Shrink the region by 90% of half the box size on each side
                # before testing whether the box centre lies inside it.
                region_x_start = int(region[0].start + (bwidth / 2) * 0.9)
                region_x_stop = int(region[0].stop - (bwidth / 2) * 0.9)
                region_y_start = int(region[1].start + (bheight / 2) * 0.9)
                region_y_stop = int(region[1].stop - (bheight / 2) * 0.9)
                is_in_region = region_x_start <= obj_center_x < region_x_stop and \
                    region_y_start <= obj_center_y < region_y_stop
                if not is_in_region:
                    continue

            for attr in ['xmin', 'xmax']:
                obj[attr] = obj[attr] - region_off_x
                if jitter:
                    obj[attr] = int(obj[attr] * scale - offx)
                obj[attr] = int(obj[attr] * float(self.config['IMAGE_W']) / w)
                obj[attr] = max(min(obj[attr], self.config['IMAGE_W']), 0)

            for attr in ['ymin', 'ymax']:
                obj[attr] = obj[attr] - region_off_y
                if jitter:
                    obj[attr] = int(obj[attr] * scale - offy)
                obj[attr] = int(obj[attr] * float(self.config['IMAGE_H']) / h)
                obj[attr] = max(min(obj[attr], self.config['IMAGE_H']), 0)

            if mirror_x:
                xmin = obj['xmin']
                obj['xmin'] = self.config['IMAGE_W'] - obj['xmax']
                obj['xmax'] = self.config['IMAGE_W'] - xmin

            if mirror_y:
                ymin = obj['ymin']
                obj['ymin'] = self.config['IMAGE_H'] - obj['ymax']
                obj['ymax'] = self.config['IMAGE_H'] - ymin

            obj_is_region.append(obj)

        return image, obj_is_region
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment