Last active
November 1, 2022 04:15
-
-
Save etienne87/001d4dbe03e26896c7a87339fc2f8f0d to your computer and use it in GitHub Desktop.
Data augmentation in PyTorch.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pylint: disable-all | |
from __future__ import print_function | |
import torch | |
import random | |
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from box import box_iou, box_clamp | |
import numpy as np | |
def _swap_leading_axes(batch):
    """Swap the first two axes of a torch tensor or numpy array (no copy for tensors)."""
    if torch.is_tensor(batch):
        return batch.transpose(0, 1)
    return np.swapaxes(batch, 0, 1)


def _transpose_nested(boxes):
    """Transpose a list of equal-length lists so that result[j][i] is boxes[i][j].

    ``None`` is passed through unchanged. Ragged inner lists are truncated to
    the shortest one (``zip`` semantics).
    """
    if boxes is None:
        return None
    return [list(column) for column in zip(*boxes)]


def tnchw_to_ntchw(batch, boxes):
    """Convert a time-major batch (T, N, C, H, W) to batch-major (N, T, C, H, W).

    Args:
        batch: torch tensor or numpy array whose first two axes are (T, N).
        boxes: None, or a nested list indexed ``boxes[t][n]``.
            NOTE(review): this box layout is inferred from the function name;
            confirm against the data loader.

    Returns:
        ``(batch, boxes)`` with the first two axes / list levels swapped.
        Tensors come back as transposed views, not copies.
    """
    return _swap_leading_axes(batch), _transpose_nested(boxes)


def ntchw_to_tnchw(batch, boxes):
    """Convert a batch-major batch (N, T, C, H, W) back to time-major (T, N, C, H, W).

    Inverse of ``tnchw_to_ntchw``: swaps the first two axes of ``batch`` and
    the two levels of the nested ``boxes`` list (``None`` passes through).

    Returns:
        ``(batch, boxes)`` in time-major layout.
    """
    return _swap_leading_axes(batch), _transpose_nested(boxes)
def random_rotate(rotation_range):
    """Return a 3x3 rotation matrix for a random angle.

    The angle is drawn uniformly from ``[-rotation_range, rotation_range]``
    degrees with ``random.uniform`` and converted to radians. The result is
    ``[[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]]`` as a float64 array.
    """
    angle = math.radians(random.uniform(-rotation_range, rotation_range))
    cos_a = math.cos(angle)
    sin_a = math.sin(angle)
    return np.array([
        [cos_a, -sin_a, 0],
        [sin_a, cos_a, 0],
        [0, 0, 1],
    ])
def random_translate(height_range, width_range):
    """Return a 3x3 translation matrix with a random shift.

    The matrix acts on column vectors ``[x, y, 1]`` (translation in the last
    column). ``height_range`` bounds the vertical shift added to y and
    ``width_range`` bounds the horizontal shift added to x; each shift is
    drawn uniformly from ``[-range, range]``.

    The draw order (height first, then width) matches the previous
    implementation, which wrongly used the height draw for x.
    """
    ty = random.uniform(-height_range, height_range)
    tx = random.uniform(-width_range, width_range)
    return np.array([[1, 0, tx],
                     [0, 1, ty],
                     [0, 0, 1]])
def random_shear(shear_range):
    """Return a 3x3 shear matrix for a random angle.

    The angle is drawn uniformly from ``[-shear_range, shear_range]`` and is
    in radians (it goes straight into ``math.sin``/``math.cos``). The result
    is ``[[1, -sin(a), 0], [0, cos(a), 0], [0, 0, 1]]``.

    Previously the function built this matrix but returned ``shear_range``
    itself, so no shear was ever applied.
    """
    shear = random.uniform(-shear_range, shear_range)
    return np.array([[1, -math.sin(shear), 0],
                     [0, math.cos(shear), 0],
                     [0, 0, 1]])
def random_zoom(zoom_range):
    """Return a 3x3 matrix scaling x and y by independent random factors.

    Both factors are drawn with ``random.uniform(zoom_range[0], zoom_range[1])``,
    the x factor first and then the y factor.
    """
    low, high = zoom_range[0], zoom_range[1]
    scale_x = random.uniform(low, high)
    scale_y = random.uniform(low, high)
    return np.diag([scale_x, scale_y, 1])
def random_horizontal_flip():
    """Return a matrix that mirrors x with probability 1/2.

    One value is drawn with ``np.random.randint(2)``. When it is non-zero the
    result is ``diag(-1, 1, 1)`` as float32; otherwise it is the float64
    3x3 identity.
    """
    do_flip = np.random.randint(2)
    if not do_flip:
        return np.eye(3)
    return np.diag(np.array([-1, 1, 1], dtype=np.float32))
def affine_compose(tforms):
    """Return the matrix product ``tforms[0] @ tforms[1] @ ... @ tforms[-1]``.

    Applied to a column vector, the result applies the last matrix first.
    Each matrix appears exactly once in the product. (The previous version
    started from ``tforms[0]`` and then multiplied by it again.)

    Args:
        tforms: sequence of square matrices of equal size.

    Returns:
        The composed matrix, or the 3x3 identity for an empty sequence.
    """
    if len(tforms) == 0:
        return np.eye(3)
    tform_matrix = tforms[0]
    for tform in tforms[1:]:
        tform_matrix = np.dot(tform_matrix, tform)
    return tform_matrix
def get_affine_matrix(shape=None, rotation_range=10, translate_range=(0.1, 0.1),
                      shear_range=1.0, zoom_range=(0.5, 2)):
    """Draw a random 3x3 augmentation matrix.

    A rotation, translation, shear, zoom and horizontal flip are drawn (in
    that order) and combined with ``affine_compose`` in the same order.
    The defaults reproduce the previously hard-coded ranges.

    Args:
        shape: ignored. It is accepted so callers may pass ``(height, width)``,
            as ``Affine.reset_params`` does; without it that call raised
            TypeError. ``Affine`` applies the matrices to coordinates
            normalized to [-1, 1], so the image size is not needed.
        rotation_range: maximum rotation, in degrees.
        translate_range: ``(height_range, width_range)`` passed to
            ``random_translate``.
        shear_range: maximum shear angle in radians (1.0 rad is about 57 deg).
        zoom_range: ``(min_scale, max_scale)`` passed to ``random_zoom``.

    Returns:
        The composed 3x3 numpy matrix.
    """
    tforms = [
        random_rotate(rotation_range),
        random_translate(translate_range[0], translate_range[1]),
        random_shear(shear_range),
        random_zoom(zoom_range),
        random_horizontal_flip(),
    ]
    return affine_compose(tforms)
def get_random_homography():
    """Return a random affine matrix left-multiplied by a slightly perturbed identity.

    The perturbation is ``np.random.randn(3, 3) * 0.01``. It is drawn before
    ``get_affine_matrix()`` is called and touches every entry, including the
    last row, so the result is a (subtle) projective transform.
    """
    jitter = np.random.randn(3, 3) * 0.01
    perturbed_identity = np.eye(3) + jitter
    affine = get_affine_matrix()
    return perturbed_identity.dot(affine)
class Affine(nn.Module):
    """Warp images and their boxes with one random matrix per batch slot.

    ``reset_params`` draws ``batchsize`` 3x3 matrices: affine ones from
    ``get_affine_matrix()``, or projective ones from
    ``get_random_homography()`` when ``use_homography`` is True. They are
    reused on every call until ``reset_params`` runs again, so sample ``i``
    of every batch gets the same warp.

    Matrices act on column vectors ``[x, y, 1]`` expressed in the normalized
    [-1, 1] coordinates used by ``F.grid_sample`` (x along the width, y along
    the height).

    Args:
        batchsize: number of per-sample matrices (maximum batch size).
        height, width: spatial size of the sampling grid.
        use_homography: draw projective matrices and divide by the
            homogeneous coordinate when warping.
    """

    def __init__(self, batchsize=32, height=240, width=304, use_homography=False):
        super(Affine, self).__init__()
        self.batchsize, self.height, self.width = batchsize, height, width
        self.use_homography = use_homography
        self.reset_params()

    def reset_params(self):
        """Draw new matrices and rebuild the sampling grid.

        ``theta[i]`` maps input to output coordinates and is used for boxes.
        ``grid[i]`` holds ``inverse(theta[i])`` applied to every output
        location, because ``F.grid_sample`` reads, for each output pixel, the
        input location given by the grid.
        """
        thetas, invthetas = [], []
        for _ in range(self.batchsize):
            if self.use_homography:
                theta = get_random_homography()
            else:
                # Coordinates are normalized, so no image shape is passed.
                theta = get_affine_matrix()
            thetas.append(torch.from_numpy(theta).float())
            invthetas.append(torch.from_numpy(np.linalg.inv(theta)).float())
        thetas = torch.stack(thetas)
        invthetas = torch.stack(invthetas)

        # Evenly spaced normalized (x, y) coordinates of the output image,
        # row-major over (height, width), shape (height * width, 2).
        ys = torch.linspace(-1., 1., self.height)
        xs = torch.linspace(-1., 1., self.width)
        grid_y = ys.view(-1, 1).expand(self.height, self.width)
        grid_x = xs.view(1, -1).expand(self.height, self.width)
        base = torch.stack((grid_x, grid_y), dim=-1).reshape(-1, 2)

        grid = torch.stack([
            self.warp_xy(base, invthetas[i]).view(self.height, self.width, 2)
            for i in range(self.batchsize)
        ])

        # Copy into existing buffers so they keep their device after .cuda().
        if hasattr(self, "grid"):
            self.grid[...] = grid
        else:
            self.register_buffer("grid", grid)
        if hasattr(self, "theta"):
            self.theta[...] = thetas
        else:
            self.register_buffer("theta", thetas)

    def warp_xy(self, xy, theta):
        """Apply the 3x3 matrix ``theta`` to (N, 2) points; returns (N, 2) points."""
        # [x', y', w']^T = theta @ [x, y, 1]^T for every row at once.
        warped = torch.mm(xy, theta[:, :2].t()) + theta[:, 2]
        if self.use_homography:
            # Perspective division; the epsilon guards against w' == 0.
            warped = warped / (warped[:, 2:3] + 1e-8)
        return warped[:, :2]

    def warp_boxes(self, bboxes):
        """Warp each sample's boxes in place with its matrix ``self.theta[i]``.

        ``bboxes[i]`` is an (N_i, 4) tensor of ``[x1, y1, x2, y2]`` boxes in
        normalized coordinates. Each box is replaced by the axis-aligned box
        that encloses its four warped corners.
        """
        for i, boxes in enumerate(bboxes):
            theta = self.theta[i]
            x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
            corners = (
                torch.stack((x1, y1), dim=1),  # top-left
                torch.stack((x2, y1), dim=1),  # top-right
                torch.stack((x1, y2), dim=1),  # bottom-left
                torch.stack((x2, y2), dim=1),  # bottom-right
            )
            # (N_i, 2, 4): coordinate axis by corner.
            warped = torch.stack([self.warp_xy(c, theta) for c in corners], dim=2)
            boxes[:, :2] = torch.min(warped, dim=2)[0]
            boxes[:, 2:] = torch.max(warped, dim=2)[0]
            bboxes[i] = boxes

    def warp_images(self, x):
        """Resample the first ``x.size(0)`` grids from ``x`` with ``F.grid_sample``."""
        grid = self.grid[:x.size(0)]
        y = F.grid_sample(x, grid)
        return y

    def forward(self, x, boxes):
        """Warp images ``x`` and their boxes; returns ``(warped_images, boxes)``.

        ``boxes`` is either a list with one box tensor per sample, or a list
        of such lists (each inner list is warped with the same per-sample
        matrices). Box tensors are modified in place. An empty list is
        returned unchanged.
        """
        y = self.warp_images(x)
        if len(boxes) > 0 and isinstance(boxes[0], list):
            for sample_boxes in boxes:
                self.warp_boxes(sample_boxes)
        else:
            self.warp_boxes(boxes)
        return y, boxes
class SequenceBatchAugmentation(object):
    """Apply a different but non-changing warp to every sample of a batch.

    Warping is delegated to an ``Affine`` module sized from the dataloader's
    ``batchsize``, ``height`` and ``width`` attributes. Sample ``i`` of every
    batch uses that module's ``i``-th matrix until its ``reset_params()`` is
    called.
    """

    def __init__(self, dataloader, format="TNCHW"):
        self.dataloader = dataloader
        # Layout tag of incoming batches (e.g. "TNCHW"); currently not used.
        self.format = format
        self.homographer = Affine(dataloader.batchsize, dataloader.height, dataloader.width)

    def __call__(self, batch, boxes):
        """Warp one batch and its boxes; returns ``(warped_batch, boxes)``."""
        return self.homographer(batch, boxes)

    def __iter__(self, batch=None, boxes=None):
        """Iterate over augmented batches, or augment a single batch.

        With no arguments (as ``iter()`` and ``for`` call it), returns a
        generator over ``self.dataloader`` that yields ``self(batch, boxes)``
        per item. With ``batch`` and ``boxes`` (the original calling
        convention), augments that one batch and returns the result, which
        was previously discarded.
        """
        if batch is None:
            return self._augmented_batches()
        return self(batch, boxes)

    def _augmented_batches(self):
        """Yield the augmented version of every dataloader item."""
        # NOTE(review): assumes the dataloader yields (batch, boxes) pairs.
        for batch, boxes in self.dataloader:
            yield self(batch, boxes)
def make_grid(im, thumbsize=80):
    """Cut an image into a batch of square ``thumbsize`` x ``thumbsize`` tiles.

    Args:
        im: array of shape (H, W) or (H, W, ...). H and W must be multiples
            of ``thumbsize``. Trailing (channel) axes are kept as-is.
        thumbsize: tile side length in pixels.

    Returns:
        Array of shape (num_tiles, thumbsize, thumbsize, ...) with tiles in
        row-major order.
    """
    # Integer division: '/' yields floats in Python 3 and breaks reshape.
    rows = im.shape[0] // thumbsize
    cols = im.shape[1] // thumbsize
    rest = tuple(im.shape[2:])
    tiles = im.reshape((rows, thumbsize, cols, thumbsize) + rest)
    return tiles.swapaxes(1, 2).reshape((-1, thumbsize, thumbsize) + rest)
def unmake_grid(batch):
    """Tile a batch of images of shape (B, h, w, ...) into one mosaic image.

    The mosaic has ``nrows = 2 ** ((B.bit_length() - 1) // 2)`` rows of tiles
    and ``ceil(B / nrows)`` columns, filled row-major. Cells left over when B
    does not fill the grid are zero-filled.

    Returns:
        Array of shape (nrows * h, ncols * w, ...).

    Raises:
        ValueError: if the batch is empty.
    """
    batchsize = batch.shape[0]
    if batchsize == 0:
        raise ValueError("unmake_grid needs at least one image")
    tile_h, tile_w = batch.shape[1], batch.shape[2]
    rest = tuple(batch.shape[3:])
    nrows = 2 ** ((batchsize.bit_length() - 1) // 2)
    ncols = -(-batchsize // nrows)  # ceil division, integer in Python 3
    missing = nrows * ncols - batchsize
    if missing:
        pad = np.zeros((missing,) + tuple(batch.shape[1:]), dtype=batch.dtype)
        batch = np.concatenate([batch, pad], axis=0)
    im = batch.reshape((nrows, ncols, tile_h, tile_w) + rest)
    im = im.swapaxes(1, 2)
    return im.reshape((nrows * tile_h, ncols * tile_w) + rest)
def test_pokemon(path="/localhome/eperot/pokemon3.png", batchsize=32,
                 thumbsize=80, cuda=False):
    """Interactive demo: warp tiles of an image with ``Affine`` and display them.

    The image at ``path`` is cut into ``thumbsize`` x ``thumbsize`` tiles.
    Rows and columns that do not fill a whole tile are cropped away. Tiles are
    warped in batches of ``batchsize``, each with one fixed box. The original
    and the warped batch are shown with OpenCV, with the warped boxes drawn
    in red. Any key advances to the next batch; ESC quits.

    Args:
        path: image file readable by ``cv2.imread``. The author originally
            downloaded a sprite sheet to this path by hand.
        batchsize: tiles per batch (and number of per-sample transforms).
        thumbsize: side length of a square tile, in pixels.
        cuda: run the warp on the GPU and synchronize around the timing.

    Raises:
        IOError: if the image cannot be read.
        ValueError: if the image is smaller than one tile.
    """
    import time

    import cv2

    im = cv2.imread(path)
    if im is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise IOError("could not read image: %s" % path)

    # Cut the image into tiles of shape (num_tiles, thumbsize, thumbsize, 3).
    rows, cols = im.shape[0] // thumbsize, im.shape[1] // thumbsize
    tiles = im[:rows * thumbsize, :cols * thumbsize]
    tiles = tiles.reshape(rows, thumbsize, cols, thumbsize, 3)
    tiles = tiles.swapaxes(1, 2).reshape(-1, thumbsize, thumbsize, 3)
    tiles[tiles == 0] = 127  # pure-black pixels become mid-gray
    num_tiles = tiles.shape[0]
    if num_tiles == 0:
        raise ValueError("image is smaller than one %d-pixel tile" % thumbsize)

    # One fake box per batch slot, in normalized [-1, 1] coordinates.
    default_box = torch.tensor([[-0.75, -0.75, 0.75, 0.75]], dtype=torch.float32)
    batch_boxes = [default_box.clone() for _ in range(batchsize)]

    module = Affine(batchsize=batchsize, height=thumbsize, width=thumbsize,
                    use_homography=True)
    if cuda:
        module = module.cuda()

    print('Start')
    num_batches = -(-num_tiles // batchsize)  # ceil: keep a partial last batch
    for step in range(num_batches * 100):
        first = (step * batchsize) % num_tiles
        last = min(first + batchsize, num_tiles)
        batch = tiles[first:last]
        original = unmake_grid(batch)

        x = torch.from_numpy(batch).permute(0, 3, 1, 2)
        boxes = [item.clone() for item in batch_boxes]
        if cuda:
            x = x.cuda()
            boxes = [item.cuda() for item in boxes]
        x = x.float()

        # CUDA runs asynchronously, so synchronize around the timed call.
        # torch.cuda.synchronize() must not be called when CUDA is unused.
        if cuda:
            torch.cuda.synchronize()
        start = time.time()
        y, out_boxes = module(x, boxes)
        if cuda:
            torch.cuda.synchronize()
        end = time.time()
        print(end - start, ' s')

        warped = y.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
        for k in range(warped.shape[0]):
            canvas = warped[k].copy()
            for box in out_boxes[k].cpu().numpy():
                # Map [-1, 1] to pixels. TODO: scale x by width and y by height.
                box = (box + 1) / 2 * thumbsize
                pt1 = (int(box[0]), int(box[1]))
                pt2 = (int(box[2]), int(box[3]))
                cv2.rectangle(canvas, pt1, pt2, (0, 0, 255), 1)
            warped[k] = canvas

        cv2.imshow('original', original)
        cv2.imshow('transformed', unmake_grid(warped))
        if cv2.waitKey(0) == 27:  # ESC
            break
        # Call module.reset_params() here to draw new transforms per batch.
if __name__ == '__main__':
    # Script entry point. cv2 and time are imported here at module scope
    # because test_pokemon() may look them up as module globals.
    import time
    import cv2

    test_pokemon()
    # test_dataset()  # NOTE(review): test_dataset is not defined in the visible source
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment