@etienne87
Last active November 1, 2022 04:15
Data augmentation in PyTorch: batched random affine / homography warping of images and bounding boxes.
# pylint: disable-all
from __future__ import print_function
import random
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# from box import box_iou, box_clamp  # unused here; would require a local `box` module
def tnchw_to_ntchw(batch, boxes):
    """Placeholder: convert a time-major (T, N, C, H, W) batch and its boxes to batch-major (N, T, C, H, W)."""
    pass
def ntchw_to_tnchw(batch, boxes):
    """Placeholder: convert a batch-major (N, T, C, H, W) batch and its boxes back to time-major (T, N, C, H, W)."""
    pass
def random_rotate(rotation_range):
degree = random.uniform(-rotation_range, rotation_range)
theta = math.pi / 180 * degree
rotation_matrix = np.array([[math.cos(theta), -math.sin(theta), 0],
[math.sin(theta), math.cos(theta), 0],
[0, 0, 1]])
return rotation_matrix
def random_translate(height_range, width_range):
    tx = random.uniform(-width_range, width_range)
    ty = random.uniform(-height_range, height_range)
translation_matrix = np.array([[1, 0, tx],
[0, 1, ty],
[0, 0, 1]])
return translation_matrix
def random_shear(shear_range):
shear = random.uniform(-shear_range, shear_range)
shear_matrix = np.array([[1, -math.sin(shear), 0],
[0, math.cos(shear), 0],
[0, 0, 1]])
    return shear_matrix
def random_zoom(zoom_range):
zx = random.uniform(zoom_range[0], zoom_range[1])
zy = random.uniform(zoom_range[0], zoom_range[1])
zoom_matrix = np.array([[zx, 0, 0],
[0, zy, 0],
[0, 0, 1]])
return zoom_matrix
def random_horizontal_flip():
if np.random.randint(2):
mat = np.array([[-1, 0, 0],
[0, 1, 0],
[0, 0, 1]], dtype=np.float32)
else:
mat = np.eye(3)
return mat
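# All helpers above return 3x3 matrices acting on homogeneous (x, y, 1) points in the
# normalized [-1, 1] grid used by F.grid_sample, so they can be chained by plain matrix
# multiplication (see affine_compose below).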
def affine_compose(tforms):
    # chain 3x3 transforms into a single matrix; matrices later in the list are applied to points first
    tform_matrix = tforms[0]
    for tform in tforms[1:]:
        tform_matrix = np.dot(tform_matrix, tform)
    return tform_matrix
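# Illustrative sketch (not in the original gist): composition order matters, and matrices
# later in the list are applied to a point first, since np.dot(Z, T) @ p == Z @ (T @ p)
# for a column vector p = (x, y, 1). For example:
#   Z = random_zoom((2.0, 2.0))       # deterministic 2x zoom, just for the example
#   T = random_translate(0.1, 0.1)
#   M = affine_compose([Z, T])        # equivalent to np.dot(Z, T): translate, then zoom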
def get_affine_matrix():
tforms = []
tforms.append(random_rotate(10))
tforms.append(random_translate(0.1, 0.1))
    tforms.append(random_shear(1.0))  # shear angle range, in radians
tforms.append(random_zoom((0.5, 2)))
tforms.append(random_horizontal_flip())
return affine_compose(tforms)
def get_random_homography():
mat = np.eye(3) + np.random.randn(3, 3) * 0.01 #subtle deformation
mat = np.dot(mat, get_affine_matrix())
return mat
class Affine(nn.Module):
"""
One Transform to rule them all!!!
Applies Warping on the images by selecting a random transformation.
"""
def __init__(self, batchsize=32, height=240, width=304, use_homography=False):
super(Affine, self).__init__()
self.batchsize, self.height, self.width = batchsize, height, width
self.use_homography = use_homography
self.reset_params()
def reset_params(self):
thetas = []
invthetas = []
for i in range(self.batchsize):
if self.use_homography:
theta = get_random_homography()
else:
                theta = get_affine_matrix()
theta2 = np.linalg.inv(theta)
theta = torch.from_numpy(theta).float()
theta2 = torch.from_numpy(theta2).float()
thetas.append(theta.unsqueeze(0))
invthetas.append(theta2.unsqueeze(0))
invthetas = torch.cat(invthetas)
thetas = torch.cat(thetas)
grid_h, grid_w = torch.meshgrid([torch.linspace(-1., 1., self.height),
torch.linspace(-1., 1., self.width)])
grid = torch.cat((grid_w[None, :, :, None],
grid_h[None, :, :, None]), 3)
grid = grid.repeat(self.batchsize, 1, 1, 1)
        for i in range(self.batchsize):
            xy = grid[i].view(-1, 2)
            ones = torch.ones(xy.size(0), 1)
            xy1 = torch.cat([xy, ones], dim=1)             # homogeneous coordinates, (H*W, 3)
            warped = torch.mm(xy1, invthetas[i].t())       # apply the inverse 3x3 transform to row vectors
            if self.use_homography:
                warped = warped / (warped[:, 2:3] + 1e-8)  # perspective divide
            grid[i] = warped[:, :2].reshape(self.height, self.width, 2)
if hasattr(self, "grid"):
self.grid[...] = grid
else:
self.register_buffer("grid", grid)
if hasattr(self, "theta"):
self.theta[...] = thetas
else:
self.register_buffer("theta", thetas)
    def warp_xy(self, xy, theta):
        # xy: (N, 2) points in the normalized [-1, 1] grid coordinates
        ones = torch.ones(xy.size(0), 1, dtype=xy.dtype, device=xy.device)
        xy1 = torch.cat([xy, ones], dim=1)  # homogeneous coordinates, (N, 3)
        tmp = torch.mm(xy1, theta.t())      # N, 3
        if self.use_homography:
            tmp = tmp / (tmp[:, 2:3] + 1e-8)
        return tmp[:, :2]
def warp_boxes(self, bboxes):
for i in range(len(bboxes)):
boxes = bboxes[i]
thetai = self.theta[i]
tl = boxes[:, :2]
br = boxes[:, 2:]
wh = br - tl
tr = tl.clone()
bl = br.clone()
            tr[:, 0] += wh[:, 0]  # top-right corner: shift x by the box width
            bl[:, 0] -= wh[:, 0]  # bottom-left corner: shift x back by the box width
            tl, tr, bl, br = self.warp_xy(tl, thetai), self.warp_xy(tr, thetai), \
                             self.warp_xy(bl, thetai), self.warp_xy(br, thetai)
tcat = torch.cat([tl.unsqueeze(2),
tr.unsqueeze(2),
bl.unsqueeze(2),
br.unsqueeze(2)], dim=2)
boxes[:, :2] = torch.min(tcat, dim=2)[0]
boxes[:, 2:] = torch.max(tcat, dim=2)[0]
bboxes[i] = boxes
def warp_images(self, x):
grid = self.grid[:x.size(0)]
y = F.grid_sample(x, grid)
return y
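    # Note: F.grid_sample defaults to bilinear sampling with zero padding; on PyTorch >= 1.3
    # you may want to pass align_corners=True explicitly to keep the pre-1.3 behaviour this
    # grid was written against and to silence the warning about the changed default.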
    def forward(self, x, boxes):
        y = self.warp_images(x)
        if isinstance(boxes[0], list):
            # nested list of box tensors (e.g. one list per time step)
            for box in boxes:
                self.warp_boxes(box)
        else:
            self.warp_boxes(boxes)
        return y, boxes
class SequenceBatchAugmentation(object):
    """
    Wraps a dataloader and applies a different, but fixed over the sequence,
    transform to every sample of the batch.
    """
    def __init__(self, dataloader, format="TNCHW"):
        self.dataloader = dataloader
        self.format = format
        self.homographer = Affine(dataloader.batchsize, dataloader.height, dataloader.width)
    def __iter__(self):
        for batch, boxes in self.dataloader:
            # put batch in the right format first, e.g. with tnchw_to_ntchw if self.format == "TNCHW"
            yield self.homographer(batch, boxes)
def make_grid(im, thumbsize=80):
    im2 = im.reshape(im.shape[0] // thumbsize, thumbsize, im.shape[1] // thumbsize, thumbsize, 3)
im2 = im2.swapaxes(1, 2).reshape(-1, thumbsize, thumbsize, 3)
return im2
def unmake_grid(batch):
batchsize = batch.shape[0]
thumbsize = batch.shape[1]
channels = batch.shape[-1]
nrows = 2 ** ((batchsize.bit_length() - 1) // 2)
    ncols = batchsize // nrows
im = batch.reshape(nrows, ncols, thumbsize, thumbsize, channels)
im = im.swapaxes(1, 2)
im = im.reshape(nrows * thumbsize, ncols * thumbsize, channels)
return im
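# Note (not in the original gist): unmake_grid tiles the thumbnails into a roughly square
# mosaic whose row count is a power of two, so it is a display helper rather than a strict
# inverse of make_grid.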
def test_pokemon():
#path = https://www.google.fr/url?sa=i&source=images&cd=&cad=rja&uact=8&ved=2ahUKEwjDyovHooLhAhWQlRQKHThtBJMQjRx6BAgBEAU&url=http%3A%2F%2Fsecret-world-of-pokemon.blog.cz%2F1110%2Fdalsi-veci-ke-stazeni&psig=AOvVaw3UcS6NTNQGr8ogHmQxFPD9&ust=1552674792972448
#import urllib
#urllib.urlretrieve(path, "pokemon3.png")
#im = cv2.imread("/localhome/eperot/pokemon3.png")
im = cv2.imread("/localhome/eperot/pokemon3.png")
batchsize = 32
cuda = 0
# stupid dataset maker
thumbsize = 80
    im2 = im.reshape(im.shape[0] // thumbsize, thumbsize, im.shape[1] // thumbsize, thumbsize, 3)
im2 = im2.swapaxes(1, 2).reshape(-1, thumbsize, thumbsize, 3)
im2[np.where(im2 == 0)] = 127
# Create Fake Boxes
default = np.array([
[-0.75, -0.75, 0.75, 0.75]
], dtype=np.float32)
default = torch.from_numpy(default)
batch_boxes = [default.clone() for i in range(batchsize)]
module = Affine(batchsize=batchsize, height=thumbsize, width=thumbsize, use_homography=True)
if cuda:
module = module.cuda()
print('Start')
    num_batches = im2.shape[0] // batchsize
    for i in range(num_batches * 100):
        st, en = (i * batchsize) % im2.shape[0], min(im2.shape[0], (i + 1) * batchsize)
        en = min(st + batchsize, en)
batch = im2[st:en]
if batch.shape[0] == 0:
continue
original = unmake_grid(batch)
batch = torch.from_numpy(batch)
batch = batch.permute([0, 3, 1, 2])
tmp_batch_boxes = [item.clone() for item in batch_boxes]
if cuda:
batch = batch.cuda()
tmp_batch_boxes = [item.cuda() for item in tmp_batch_boxes]
batch = batch.float()
        if cuda:
            torch.cuda.synchronize()
        start = time.time()
        y, out_boxes = module(batch, tmp_batch_boxes)
        if cuda:
            torch.cuda.synchronize()
        end = time.time()
        print(end - start, ' s')
y = y.permute([0, 2, 3, 1])
ynp = y.cpu().numpy().astype(np.uint8)
# draw on ynp the transformed boxes
        for k in range(ynp.shape[0]):
            outbox = out_boxes[k].cpu().numpy()
            cpy = ynp[k].copy()
            for j in range(outbox.shape[0]):
                box = (outbox[j] + 1) / 2
                box *= thumbsize  # warning! TODO: multiply by (width, height) separately for non-square images
                pt1 = (int(box[0]), int(box[1]))
                pt2 = (int(box[2]), int(box[3]))
                cv2.rectangle(cpy, pt1, pt2, (0, 0, 255), 1)
            ynp[k] = cpy
image = unmake_grid(ynp)
cv2.imshow('original', original)
cv2.imshow('transformed', image)
key = cv2.waitKey(0)
if key == 27:
break
# module.reset_params()
if __name__ == '__main__':
# load a bunch of images, apply Affine
import cv2
import time
test_pokemon()
#test_dataset()