Skip to content

Instantly share code, notes, and snippets.

@noahtren
Last active June 4, 2020 19:17
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save noahtren/1fa899d1df7da78dd4ad1557dc279e7b to your computer and use it in GitHub Desktop.
TPU-Compatible Differentiable Affine Transformations
"""Originally from https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96#Data-Augmentation,
modified to be in pure TensorFlow and to work on a batch of images rather than a single image.
(For a tf.data pipeline, you may want to look at the original code at the link above.)
"""
import math
import tensorflow as tf
def transform_batch(images,
                    max_rot_deg,
                    max_shear_deg,
                    max_zoom_diff_pct,
                    max_shift_pct,
                    experimental_tpu_efficiency=True):
    """Apply one random affine transformation to a batch of square images.

    A single random rotation/shear/zoom/shift is sampled and applied to
    every image in the batch (the batch shares one transform per call).

    Args:
        images: float tensor of shape [batch, DIM, DIM, channels]; height
            and width must be equal.
        max_rot_deg: max absolute rotation magnitude, in degrees.
        max_shear_deg: max absolute shear magnitude, in degrees.
        max_zoom_diff_pct: max relative zoom difference (e.g. 0.1 -> +/-10%).
        max_shift_pct: max shift, as a fraction of the image side length DIM.
        experimental_tpu_efficiency: if True, use a flat ``tf.gather`` over
            flattened pixels instead of ``tf.gather_nd`` (the latter compiles
            with excessive padding on TPU).

    Returns:
        Transformed images with the same shape as ``images``.
    """

    def clipped_random():
        # Clipped-normal sample in [-1, 1]: N(0, 1) truncated at +/-2, halved.
        rand = tf.random.normal([1], dtype=tf.float32)
        return tf.clip_by_value(rand, -2., 2.) / 2.

    batch_size = images.shape[0]
    tf.debugging.assert_equal(
        images.shape[1],
        images.shape[2],
        "Images should be square")
    DIM = images.shape[1]
    # Generalized from the original hard-coded 3 channels: any channel
    # count now works (RGB, grayscale, etc.).
    channels = images.shape[3]
    XDIM = DIM % 2  # odd-size correction term used when clipping indices

    # Random transform magnitudes, each scaled by its caller-provided max.
    rot = max_rot_deg * clipped_random()
    shr = max_shear_deg * clipped_random()
    h_zoom = 1.0 + clipped_random() * max_zoom_diff_pct
    w_zoom = 1.0 + clipped_random() * max_zoom_diff_pct
    h_shift = clipped_random() * (DIM * max_shift_pct)
    w_shift = clipped_random() * (DIM * max_shift_pct)

    # GET TRANSFORMATION MATRIX (get_mat is defined elsewhere in this file).
    m = get_mat(rot, shr, h_zoom, w_zoom, h_shift, w_shift)

    # LIST DESTINATION PIXEL INDICES in homogeneous coordinates [x; y; 1].
    x = tf.repeat(tf.range(DIM // 2, -DIM // 2, -1), DIM)
    y = tf.tile(tf.range(-DIM // 2, DIM // 2), [DIM])
    z = tf.ones([DIM * DIM], tf.int32)
    idx = tf.stack([x, y, z])

    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS (inverse mapping), then
    # clamp so every index stays inside the image.
    idx2 = tf.matmul(m, tf.cast(idx, tf.float32))
    idx2 = tf.cast(idx2, tf.int32)
    idx2 = tf.clip_by_value(idx2, -DIM // 2 + XDIM + 1, DIM // 2)

    # FIND ORIGIN PIXEL VALUES: convert centered coords back to row/col.
    idx3 = tf.stack([DIM // 2 - idx2[0, ], DIM // 2 - 1 + idx2[1, ]])
    idx3 = tf.transpose(idx3)

    if experimental_tpu_efficiency:
        # Flatten the 2-D (row, col) gather into a 1-D gather over
        # row * DIM + col; this avoids the padding tf.gather_nd incurs
        # when compiled for TPU.
        idx4 = idx3[:, 0] * DIM + idx3[:, 1]
        images = tf.reshape(images, [batch_size, DIM * DIM, channels])
        d = tf.gather(images, idx4, axis=1)
        return tf.reshape(d, [batch_size, DIM, DIM, channels])
    else:
        # Only this branch needs the per-batch index tensor, so it is
        # built here rather than unconditionally (as the original did).
        batched_idx3 = tf.tile(idx3[tf.newaxis], [batch_size, 1, 1])
        d = tf.gather_nd(images, batched_idx3, batch_dims=1)
        return tf.reshape(d, [batch_size, DIM, DIM, channels])
if __name__ == "__main__":
    # Visual smoke test: show four random images beside their augmented
    # versions. Fix: the original called the nonexistent `transform(x)`;
    # the function is named `transform_batch` and requires the four
    # transform-magnitude arguments.
    import matplotlib.pyplot as plt

    x = tf.random.normal((4, 100, 100, 3))
    # Normalize to [0, 1] so imshow treats it as a valid float image.
    x = x - tf.math.reduce_min(x)
    x = x / tf.math.reduce_max(x)
    x_aug = transform_batch(x,
                            max_rot_deg=15.0,
                            max_shear_deg=5.0,
                            max_zoom_diff_pct=0.1,
                            max_shift_pct=0.1)
    fig, axes = plt.subplots(4, 2)
    for b in range(4):
        axes[b][0].imshow(x[b])
        axes[b][1].imshow(x_aug[b])
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment