Skip to content

Instantly share code, notes, and snippets.

Created August 1, 2019 16:59
Show Gist options
  • Save gilr00y/a263e447dedeee6f9a108a09b28e07f4 to your computer and use it in GitHub Desktop.
Save gilr00y/a263e447dedeee6f9a108a09b28e07f4 to your computer and use it in GitHub Desktop.
Torch - Sliding Window transform for Object Detection
import numpy as np
import torch
class WindowUtils:
    """Helpers for cropping an image and its box annotations to a window.

    A window is ``(min_x, min_y, max_x, max_y)`` in full-image pixel
    coordinates.  ``chip_img`` slices with numpy's half-open convention, so
    ``max_x``/``max_y`` are treated as exclusive bounds throughout.
    An annotation is ``(min_x, min_y, max_x, max_y, cat)``.
    """

    @staticmethod
    def is_overlapping(window, annot, crop_tolerance=0.25):
        """Return True if enough of ``annot`` lies inside ``window``.

        Args:
            window: ``(min_x, min_y, max_x, max_y)`` of the window.
            annot: ``(min_x, min_y, max_x, max_y, cat)`` annotation box.
            crop_tolerance: minimum fraction of the annotation's own area that
                must fall inside the window (default 0.25).

        Returns:
            bool: True when the boxes strictly intersect and the intersection
            area is at least ``crop_tolerance`` times the annotation area.
        """
        win_min_x, win_min_y, win_max_x, win_max_y = window
        ann_min_x, ann_min_y, ann_max_x, ann_max_y, _cat = annot

        # Strict separating-axis test: reject boxes that do not intersect.
        if not (ann_min_x < win_max_x and ann_max_x > win_min_x
                and ann_min_y < win_max_y and ann_max_y > win_min_y):
            return False

        orig_ann_area = (ann_max_x - ann_min_x) * (ann_max_y - ann_min_y)
        overlap_w = min(ann_max_x, win_max_x) - max(ann_min_x, win_min_x)
        overlap_h = min(ann_max_y, win_max_y) - max(ann_min_y, win_min_y)
        # Drop boxes that are mostly cropped away by the window edge.
        return bool(overlap_w * overlap_h >= crop_tolerance * orig_ann_area)

    @staticmethod
    def calculate_sub_img_annot(window, ann):
        """Re-express ``ann`` in the coordinate frame of the window's chip.

        Coordinates are shifted by the window origin and clamped to the chip:
        minima to 0, maxima to the last pixel index (size - 1).

        Returns:
            list: ``[min_x, min_y, max_x, max_y, cat]`` with float coordinates.
        """
        win_min_x, win_min_y, win_max_x, win_max_y = window
        ann_min_x, ann_min_y, ann_max_x, ann_max_y, cat = ann
        sub_img_width = win_max_x - win_min_x
        sub_img_height = win_max_y - win_min_y
        sub_min_x = float(max(ann_min_x - win_min_x, 0.0))
        sub_max_x = float(min(ann_max_x - win_min_x, sub_img_width - 1))
        sub_min_y = float(max(ann_min_y - win_min_y, 0.0))
        sub_max_y = float(min(ann_max_y - win_min_y, sub_img_height - 1))
        return [sub_min_x, sub_min_y, sub_max_x, sub_max_y, cat]

    @staticmethod
    def get_annotations_for_window(window, annots, crop_tolerance=0.25):
        """Return the annotations that sufficiently overlap ``window``.

        Each kept annotation is converted to chip-relative coordinates with
        ``calculate_sub_img_annot``; ``crop_tolerance`` is forwarded to
        ``is_overlapping``.
        """
        return [
            WindowUtils.calculate_sub_img_annot(window, ann)
            for ann in annots
            if WindowUtils.is_overlapping(window, ann, crop_tolerance)
        ]

    @staticmethod
    def chip_img(img, window):
        """Crop ``img`` (indexed ``[row, col, ...]``) to ``window``.

        Uses numpy half-open slicing, so ``max_x``/``max_y`` are exclusive.
        The result is a view into ``img``, not a copy.
        """
        min_x, min_y, max_x, max_y = window
        return img[min_y:max_y, min_x:max_x]
class Windower(object):
    """Returns array of samples according to number of windows.

    Calling an instance splits one ``{'img', 'annot'}`` sample into a list of
    per-window samples whose annotations are relative to each crop.
    """

    @staticmethod
    def _window_starts(length, win_size, min_overlap):
        """Return window start offsets covering ``[0, length)`` along one axis."""
        stride = win_size - min_overlap
        if stride <= 0:
            raise ValueError(
                "min_overlap ({}) must be smaller than the window size ({})".format(
                    min_overlap, win_size))
        if length <= 0:
            return []
        if length <= win_size:
            # One window anchored at 0; it is clamped to the image later.
            return [0]
        # First window plus enough strides to reach the far edge; ceil-divide
        # without floats.  e.g. 900px with 800/200 -> starts [0, 100].
        num_windows = 1 + (length - win_size + stride - 1) // stride
        starts = [idx * stride for idx in range(num_windows - 1)]
        # Last window hugs the far edge so nothing is left uncovered.
        starts.append(length - win_size)
        return starts

    def __call__(self, sample, win_height=800, win_width=800, min_overlap=200):
        """Chip ``sample`` into overlapping windows covering the whole image.

        Args:
            sample: dict with ``'img'`` (numpy array indexed ``[row, col, ...]``)
                and ``'annot'`` (iterable of ``(min_x, min_y, max_x, max_y, cat)``).
            win_height: window height in pixels.
            win_width: window width in pixels.
            min_overlap: minimum overlap in pixels between neighbouring windows.

        Returns:
            list of dicts, one per window, ordered column by column (x outer,
            y inner), with keys:
              - ``'img'``: torch tensor of the crop (shares memory with the image);
              - ``'annot'``: torch tensor of shape (k, 5), annotations
                relative to the crop;
              - ``'sub_img_px_coords'``: ``(min_x, min_y, max_x, max_y)`` of the
                crop in the full image, max exclusive (numpy slicing), used to
                re-anchor detections.

        Raises:
            ValueError: if ``min_overlap`` is not smaller than both window sizes.
        """
        image, annots = sample['img'], sample['annot']
        rows, cols = image.shape[:2]

        x_starts = self._window_starts(cols, win_width, min_overlap)
        y_starts = self._window_starts(rows, win_height, min_overlap)

        # Half-open windows clamped to the image, matching WindowUtils.chip_img.
        windows = [
            (min_x, min_y, min(min_x + win_width, cols), min(min_y + win_height, rows))
            for min_x in x_starts
            for min_y in y_starts
        ]

        ret = []
        for window in windows:
            win_annots = WindowUtils.get_annotations_for_window(window, annots)
            ret.append({
                'img': torch.from_numpy(WindowUtils.chip_img(image, window)),
                # reshape keeps an empty result as (0, 5) instead of (0,).
                'annot': torch.from_numpy(np.array(win_annots).reshape(-1, 5)),
                'sub_img_px_coords': window  # To re-anchor detections in large image.
            })
        return ret
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment