Skip to content

Instantly share code, notes, and snippets.

@thorstenwagner
Created October 3, 2018 14:41
Show Gist options
  • Save thorstenwagner/8033f43b99d1d3a1a6a31b054d91e7fc to your computer and use it in GitHub Desktop.
Save thorstenwagner/8033f43b99d1d3a1a6a31b054d91e7fc to your computer and use it in GitHub Desktop.
PatchwiseBatchGenerator
from . import imagereader
from .augmentation import Augmentation
import cv2
import copy
from .utils import BoundBox, bbox_iou
import numpy as np
from keras.utils import Sequence
import threading
import multiprocessing
from multiprocessing import sharedctypes
import time
class PatchwiseBatchGenerator(Sequence):
    """Keras ``Sequence`` that assembles training batches from image patches.

    Every image is divided into a grid of patches. ``patch_assigments`` is a
    2D array whose non-zero entries mark the patches that may be used. The
    generator keeps one copy of that array per image (stacked along the third
    axis) and zeroes an entry each time the corresponding patch has been
    drawn, so patches are visited without repetition until all are consumed.

    Each item is ``([x_batch, b_batch], y_batch)``:

    * ``x_batch`` -- input images,
      ``(BATCH_SIZE, IMAGE_H, IMAGE_W, imgdepth)``
    * ``b_batch`` -- ground-truth boxes,
      ``(BATCH_SIZE, 1, 1, 1, TRUE_BOX_BUFFER, 4)``
    * ``y_batch`` -- target tensor,
      ``(BATCH_SIZE, GRID_H, GRID_W, BOX, 4 + 1 + CLASS)``
    """

    def __init__(self, images,
                 config,
                 patch_assigments,
                 shuffle=True,
                 jitter=True,
                 norm=None,
                 overlap=100,
                 name=None,
                 train_times=1):
        """Set up the generator.

        Args:
            images: list of dicts; each entry is read through its
                ``'filename'`` and ``'object'`` keys. The list is shuffled
                in place when ``shuffle`` is true.
            config: dict providing ``ANCHORS``, ``BATCH_SIZE``, ``IMAGE_W``,
                ``IMAGE_H``, ``GRID_W``, ``GRID_H``, ``BOX``, ``CLASS``,
                ``LABELS`` and ``TRUE_BOX_BUFFER``.
            patch_assigments: 2D numpy array; non-zero entries mark the
                patches that are drawn from every image.
            shuffle: shuffle ``images`` now and at the end of every epoch.
            jitter: apply random scale/translation/flip and the
                ``Augmentation`` pipeline to every patch.
            norm: optional callable applied to each image before it is stored
                in the batch. If ``None`` the image is stored unchanged.
            overlap: forwarded to ``imagereader.get_tile_coordinates``.
            name: free-form identifier stored in ``self.myname``.
            train_times: multiplier applied to the reported length.
        """
        self.myname = name
        self.images = images
        self.imgdepth = 3
        self.config = config
        self.shuffle = shuffle
        self.jitter = jitter
        self.norm = norm
        self.counter = 0
        self.patch_assigments = patch_assigments
        self.overlap = overlap
        # ANCHORS is a flat list [w0, h0, w1, h1, ...].
        self.anchors = [
            BoundBox(0, 0, config['ANCHORS'][2 * i], config['ANCHORS'][2 * i + 1])
            for i in range(int(len(config['ANCHORS']) // 2))
        ]
        self.train_times = train_times

        # Per-image bookkeeping of not-yet-used patches. The data is copied
        # into a multiprocessing RawArray and viewed as a numpy array;
        # presumably so that worker processes share one state -- TODO confirm.
        per_image = self._tiled_assigments()
        result = np.ctypeslib.as_ctypes(per_image)
        shared = sharedctypes.RawArray(result._type_, result)
        self.patch_assigments_3d = np.ctypeslib.as_array(shared)
        self.lock = multiprocessing.Lock()

        # Evaluate image depth: a 2D image, or a 3-channel image whose
        # channels are all identical, is treated as single-channel.
        img_first = imagereader.image_read(self.images[0]['filename'])
        if len(img_first.shape) == 2:
            self.imgdepth = 1
        elif np.all(img_first[:, :, 0] == img_first[:, :, 1]) and \
                np.all(img_first[:, :, 0] == img_first[:, :, 2]):
            self.imgdepth = 1

        if shuffle:
            np.random.shuffle(self.images)

    def __len__(self):
        """Return the number of batches per epoch.

        This is ``ceil(sum(patch_assigments) * len(images) / BATCH_SIZE)``
        multiplied by ``train_times``.
        """
        num_patches = float(np.sum(self.patch_assigments) * len(self.images))
        batches = int(np.ceil(num_patches / self.config['BATCH_SIZE']))
        return batches * self.train_times

    def __getitem__(self, idx):
        """Build one batch.

        ``idx`` is not used: patches are drawn at random from the set of
        patches that have not been consumed yet.

        Returns:
            ``([x_batch, b_batch], y_batch)`` as described on the class.
        """
        # Re-seed from OS entropy on every call; presumably so that forked
        # workers do not replay the same random stream -- TODO confirm.
        np.random.seed()

        batch_size = self.config['BATCH_SIZE']
        # input images (rows = IMAGE_H, columns = IMAGE_W)
        x_batch = np.zeros(
            (batch_size, self.config['IMAGE_H'], self.config['IMAGE_W'], self.imgdepth))
        # list of TRUE_BOX_BUFFER ground-truth boxes
        b_batch = np.zeros((batch_size, 1, 1, 1, self.config['TRUE_BOX_BUFFER'], 4))
        # desired network output
        y_batch = np.zeros((batch_size, self.config['GRID_H'], self.config['GRID_W'],
                            self.config['BOX'], 4 + 1 + self.config['CLASS']))

        for instance_count in range(batch_size):
            patch_x, patch_y, current_image = self._draw_patch()

            imgw, imgh = imagereader.read_width_height(self.images[current_image]['filename'])
            tile = imagereader.get_tile_coordinates(
                imgw=imgw, imgh=imgh, num_patches=self.patch_assigments.shape[0],
                patchxy=(patch_x, patch_y), overlap=self.overlap)

            # Apply augmentation
            img, all_objs = self.aug_image(
                self.images[current_image], jitter=self.jitter, region=tile)
            if self.imgdepth == 1:
                img = img[:, :, np.newaxis]

            # Add to training batch
            self._encode_objects(all_objs, instance_count, y_batch, b_batch)

            # Assign input image to x_batch. Without a normalisation callable
            # the image is stored unchanged instead of leaving zeros behind.
            if self.norm is not None:
                x_batch[instance_count] = self.norm(img)
            else:
                x_batch[instance_count] = img

        self.counter += 1
        return [x_batch, b_batch], y_batch

    def _tiled_assigments(self):
        """Return ``patch_assigments`` repeated once per image along axis 2."""
        return np.repeat(self.patch_assigments[:, :, np.newaxis], len(self.images), axis=2)

    def _draw_patch(self):
        """Pick a random unused patch, mark it used, return ``(x, y, image)``.

        If no unused patch is left, the bookkeeping array is reset first.
        Selection and marking happen while holding ``self.lock``.
        """
        with self.lock:
            if np.sum(self.patch_assigments_3d) < 1:
                self.reset_assigments(self.patch_assigments_3d)
            rows, cols, image_ids = np.nonzero(self.patch_assigments_3d)
            pick = np.random.randint(len(rows))
            patch_x = cols[pick]
            patch_y = rows[pick]
            current_image = image_ids[pick]
            self.patch_assigments_3d[patch_y, patch_x, current_image] = 0
        return patch_x, patch_y, current_image

    def _best_anchor(self, box_w, box_h):
        """Return the index of the anchor with the highest IoU to the box.

        Both the box and the anchors are compared at the origin, so only
        width and height matter. Ties are resolved in favour of the first
        anchor.
        """
        shifted_box = BoundBox(0, 0, box_w, box_h)
        best_anchor = -1
        max_iou = -1
        for i, anchor in enumerate(self.anchors):
            iou = bbox_iou(shifted_box, anchor)
            if max_iou < iou:
                best_anchor = i
                max_iou = iou
        return best_anchor

    def _encode_objects(self, all_objs, instance_count, y_batch, b_batch):
        """Write the boxes of one patch into ``y_batch`` and ``b_batch``.

        Objects with a non-positive extent, an unknown label, or a centre
        outside the grid are skipped.
        """
        labels = self.config['LABELS']
        grid_w = self.config['GRID_W']
        grid_h = self.config['GRID_H']
        # size of one grid cell in pixels
        cell_w = float(self.config['IMAGE_W']) / grid_w
        cell_h = float(self.config['IMAGE_H']) / grid_h

        true_box_index = 0
        for obj in all_objs:
            valid = obj['xmax'] > obj['xmin'] and \
                obj['ymax'] > obj['ymin'] and \
                obj['name'] in labels
            if not valid:
                continue

            # box centre in grid-cell units
            center_x = .5 * (obj['xmin'] + obj['xmax']) / cell_w
            center_y = .5 * (obj['ymin'] + obj['ymax']) / cell_h
            grid_x = int(np.floor(center_x))
            grid_y = int(np.floor(center_y))
            if not (grid_x < grid_w and grid_y < grid_h):
                continue

            obj_indx = labels.index(obj['name'])
            center_w = (obj['xmax'] - obj['xmin']) / cell_w  # unit: grid cell
            center_h = (obj['ymax'] - obj['ymin']) / cell_h  # unit: grid cell
            box = [center_x, center_y, center_w, center_h]

            # find the anchor that best predicts this box
            best_anchor = self._best_anchor(center_w, center_h)

            # assign ground truth x, y, w, h, confidence and class probs
            y_batch[instance_count, grid_y, grid_x, best_anchor, 0:4] = box
            y_batch[instance_count, grid_y, grid_x, best_anchor, 4] = 1.
            y_batch[instance_count, grid_y, grid_x, best_anchor, 5 + obj_indx] = 1

            # assign the true box to b_batch; the buffer wraps around when
            # a patch holds more than TRUE_BOX_BUFFER boxes
            b_batch[instance_count, 0, 0, 0, true_box_index] = box
            true_box_index = (true_box_index + 1) % self.config['TRUE_BOX_BUFFER']

    def reset_assigments(self, array):
        """Overwrite ``array`` in place with a fresh per-image patch mask.

        The caller is responsible for holding ``self.lock`` if needed.
        """
        fresh = self._tiled_assigments()
        np.copyto(array, fresh)

    def on_epoch_end(self):
        """Reshuffle the images (if enabled) and mark all patches unused."""
        with self.lock:
            if self.shuffle:
                np.random.shuffle(self.images)
            self.reset_assigments(self.patch_assigments_3d)
            self.counter = 0

    def _in_region(self, obj, region):
        """Tell whether the centre of ``obj`` lies inside the shrunken region.

        The region is shrunk on every side by 90% of half the box extent.
        ``region[0]`` is used as the x range and ``region[1]`` as the y
        range; both are accessed through ``.start`` / ``.stop``.
        """
        if region is None:
            return True
        bwidth = (obj['xmax'] - obj['xmin'])
        bheight = (obj['ymax'] - obj['ymin'])
        obj_center_x = int(obj['xmax'] - bwidth / 2)
        obj_center_y = int(obj['ymax'] - bheight / 2)
        region_x_start = int(region[0].start + (bwidth / 2) * 0.9)
        region_x_stop = int(region[0].stop - (bwidth / 2) * 0.9)
        region_y_start = int(region[1].start + (bheight / 2) * 0.9)
        region_y_stop = int(region[1].stop - (bheight / 2) * 0.9)
        return region_x_start <= obj_center_x < region_x_stop and \
            region_y_start <= obj_center_y < region_y_stop

    def aug_image(self, train_instance, jitter, region=None):
        """Read one image (or region of it), augment it and adapt its boxes.

        Args:
            train_instance: dict with ``'filename'`` and ``'object'``; the
                latter is a list of dicts with ``xmin``, ``xmax``, ``ymin``,
                ``ymax`` (and ``name``, which is used by the caller).
            jitter: if true, apply random up-scaling (factor in [1.0, 1.1)),
                random translation, a random flip and ``Augmentation``.
            region: optional pair of x/y ranges passed to
                ``imagereader.image_read``; ``None`` reads the full image.

        Returns:
            ``(image, objects)`` where ``image`` is resized to
            ``IMAGE_W`` x ``IMAGE_H`` and ``objects`` are deep copies of the
            boxes inside the region, expressed in resized-image pixels.
        """
        image = imagereader.image_read(train_instance['filename'], region)
        h = image.shape[0]
        w = image.shape[1]
        all_objs = copy.deepcopy(train_instance['object'])

        mirror_x = False  # x coordinates are mirrored (cv2 flip codes 1, -1)
        mirror_y = False  # y coordinates are mirrored (cv2 flip codes 0, -1)
        if jitter:
            # scale the image
            scale = np.random.uniform() / 10. + 1.
            image = cv2.resize(image, (0, 0), fx=scale, fy=scale)

            # translate the image: crop an h x w window out of the upscaled one
            max_offx = (scale - 1.) * w
            max_offy = (scale - 1.) * h
            offx = int(np.random.uniform() * max_offx)
            offy = int(np.random.uniform() * max_offy)
            image = image[offy: (offy + h), offx: (offx + w)]

            # flip the image: 0 = none, 1 = around y axis, 2 = around x axis,
            # 3 = both. The upper bound of randint is exclusive, hence 4.
            flip_selection = np.random.randint(0, 4)
            mirror_x = flip_selection in (1, 3)
            mirror_y = flip_selection in (2, 3)
            if flip_selection == 1:
                image = cv2.flip(image, 1)
            elif flip_selection == 2:
                image = cv2.flip(image, 0)
            elif flip_selection == 3:
                image = cv2.flip(image, -1)

            is_grey = (np.issubdtype(image.dtype, np.int8) or
                       np.issubdtype(image.dtype, np.uint8))
            aug = Augmentation(is_grey)
            image = aug.image_augmentation(image)

        # resize the image to standard size; cv2 expects (width, height)
        image = cv2.resize(image, (self.config['IMAGE_W'], self.config['IMAGE_H']))
        if self.imgdepth == 3:
            # reverse the channel order
            image = image[:, :, ::-1]

        # fix objects' position and size and check region
        region_off_x = 0
        region_off_y = 0
        if region is not None:
            region_off_x = region[0].start
            region_off_y = region[1].start

        obj_is_region = []
        for obj in all_objs:
            if not self._in_region(obj, region):
                continue

            for attr in ['xmin', 'xmax']:
                obj[attr] = obj[attr] - region_off_x
                if jitter:
                    obj[attr] = int(obj[attr] * scale - offx)
                obj[attr] = int(obj[attr] * float(self.config['IMAGE_W']) / w)
                obj[attr] = max(min(obj[attr], self.config['IMAGE_W']), 0)

            for attr in ['ymin', 'ymax']:
                obj[attr] = obj[attr] - region_off_y
                if jitter:
                    obj[attr] = int(obj[attr] * scale - offy)
                obj[attr] = int(obj[attr] * float(self.config['IMAGE_H']) / h)
                obj[attr] = max(min(obj[attr], self.config['IMAGE_H']), 0)

            if mirror_x:
                xmin = obj['xmin']
                obj['xmin'] = self.config['IMAGE_W'] - obj['xmax']
                obj['xmax'] = self.config['IMAGE_W'] - xmin

            if mirror_y:
                ymin = obj['ymin']
                obj['ymin'] = self.config['IMAGE_H'] - obj['ymax']
                obj['ymax'] = self.config['IMAGE_H'] - ymin

            obj_is_region.append(obj)

        return image, obj_is_region
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment