Skip to content

Instantly share code, notes, and snippets.

@hsahovic
Created March 19, 2020 02:30
Show Gist options
  • Save hsahovic/321857a612036611937f07fd2ff6aa08 to your computer and use it in GitHub Desktop.
Save hsahovic/321857a612036611937f07fd2ff6aa08 to your computer and use it in GitHub Desktop.
This function takes a TensorFlow dataset and returns a corresponding dataset that applies CutMix augmentation to every element.
def tf_ds_cutmix(ds, shuffling=1024):
    """Return a dataset that applies CutMix to every element of *ds*.

    Each element of *ds* is paired (via ``tf.data.Dataset.zip``) with an
    element of a shuffled copy of *ds*. A random rectangle from the partner
    image is pasted into the first image, and the labels are blended in
    proportion to the area actually pasted.

    Args:
        ds: A ``tf.data.Dataset`` whose elements are ``(image, labels)``
            pairs. ``image`` is sliced as a rank-3 tensor (axis 0, axis 1,
            channels). Patches from two different images are concatenated,
            so all images are assumed to share one shape. ``labels`` must be
            indexable at 0, 1 and 2 (three-headed output). Each head is
            multiplied by a float32 scalar, so it is presumably a float32
            tensor (e.g. a one-hot vector) — confirm against the caller.
        shuffling: Buffer size passed to ``ds.shuffle`` when building the
            partner dataset.

    Returns:
        A ``tf.data.Dataset`` of ``(mixed_image, (head_0, head_1, head_2))``.
    """
    ds_shuffled = ds.shuffle(shuffling)

    def cutmix(p1, p2):
        """Paste a random box of ``p2``'s image into ``p1``'s image and mix labels."""
        img_1, label_1 = p1
        img_2, label_2 = p2

        # Use the dynamic shape: with the static shape, images whose
        # height/width is unknown (None) would make `maxval=None` fail.
        shape = tf.shape(img_1)
        height, width = shape[0], shape[1]
        height_f = tf.cast(height, tf.float32)
        width_f = tf.cast(width, tf.float32)

        # Sample the target mixing ratio and the box centre; each box side
        # scales with sqrt(1 - lambda) so the box area is ~(1 - lambda).
        lambda_ = tf.random.uniform(())
        rx = tf.random.uniform((), maxval=height_f, dtype=tf.float32)
        ry = tf.random.uniform((), maxval=width_f, dtype=tf.float32)
        rw = tf.math.sqrt(1 - lambda_) * height_f
        rh = tf.math.sqrt(1 - lambda_) * width_f

        # Box corners, truncated to integers and clipped to the image bounds.
        x1 = tf.clip_by_value(tf.cast(rx - rw / 2, tf.int32), 0, height)
        x2 = tf.clip_by_value(tf.cast(rx + rw / 2, tf.int32), 0, height)
        y1 = tf.clip_by_value(tf.cast(ry - rh / 2, tf.int32), 0, width)
        y2 = tf.clip_by_value(tf.cast(ry + rh / 2, tf.int32), 0, width)

        # Recompute lambda as the exact fraction of pixels taken from img_2
        # after clipping, so the label weights match the pasted area.
        # divide_no_nan avoids NaN for a zero-sized image.
        box_area = tf.cast((x2 - x1) * (y2 - y1), tf.float32)
        lambda_ = tf.math.divide_no_nan(box_area, height_f * width_f)

        # Assemble the mixed image column-band by column-band:
        # [img_1 columns < y1 | box band | img_1 columns >= y2].
        # In the box band, rows [x1, x2) come from img_2, the rest from img_1.
        left = img_1[:, :y1, :]
        middle = tf.concat(
            [img_1[:x1, y1:y2, :], img_2[x1:x2, y1:y2, :], img_1[x2:, y1:y2, :]],
            axis=0,
        )
        right = img_1[:, y2:, :]
        result = tf.concat([left, middle, right], axis=1)

        # This corresponds to a three-headed output: each head is blended
        # separately. If your model has a single output, return instead:
        #     result, lambda_ * label_2 + (1 - lambda_) * label_1
        return result, (
            lambda_ * label_2[0] + (1 - lambda_) * label_1[0],
            lambda_ * label_2[1] + (1 - lambda_) * label_1[1],
            lambda_ * label_2[2] + (1 - lambda_) * label_1[2],
        )

    ds_zipped = tf.data.Dataset.zip((ds, ds_shuffled))
    return ds_zipped.map(cutmix)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment