Created
March 19, 2020 02:30
-
-
Save hsahovic/321857a612036611937f07fd2ff6aa08 to your computer and use it in GitHub Desktop.
This function takes a TensorFlow dataset and returns a corresponding dataset that applies CutMix augmentation.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def tf_ds_cutmix(ds, shuffling=1024, label_heads=3):
    """Return a dataset that applies CutMix to every ``(image, label)`` pair of ``ds``.

    Each element of ``ds`` is zipped with an element of ``ds.shuffle(shuffling)``.
    A random rectangle of the partner image is pasted onto the first image, and
    the labels are mixed in proportion to the pasted area::

        label = lam * label_2 + (1 - lam) * label_1

    where ``lam`` is the fraction of pixels taken from the partner image. Because
    the partner comes from a shuffled copy of the same dataset, an element can
    occasionally be paired with itself.

    Args:
        ds: ``tf.data.Dataset`` of ``(image, label)`` pairs. Images are indexed
            as ``image[row, col, channel]``. All images must share one shape and
            dtype, because pixels of the partner are written into the first image.
        shuffling: buffer size passed to ``Dataset.shuffle`` when drawing the
            partner element.
        label_heads: number of leading entries of ``label`` to mix (``label[0]``,
            ``label[1]``, ...). They are returned as a tuple, which suits a
            multi-headed model. The default of 3 keeps the original behaviour.
            Pass ``None`` to mix the whole label structure element-wise instead,
            e.g. a single one-hot vector or a dict of heads.

    Returns:
        ``tf.data.Dataset`` of ``(mixed_image, mixed_label)`` pairs.

    Note:
        Assumes TensorFlow is available as the module-level name ``tf``
        (this snippet does not import it itself).
    """

    def _mix(a, b, lam):
        # Mix one label tensor. Integer labels are promoted to float32, and
        # float labels keep their own dtype. Without this, multiplying by the
        # float32 area ratio would raise a dtype error.
        a = tf.convert_to_tensor(a)
        if not a.dtype.is_floating:
            a = tf.cast(a, tf.float32)
        b = tf.cast(b, a.dtype)
        w = tf.cast(lam, a.dtype)
        return w * b + (1 - w) * a

    def _mix_labels(label_1, label_2, lam):
        # Mix either the first `label_heads` entries (returned as a tuple) or
        # the full label structure when `label_heads` is None.
        if label_heads is None:
            return tf.nest.map_structure(
                lambda a, b: _mix(a, b, lam), label_1, label_2
            )
        return tuple(_mix(label_1[i], label_2[i], lam) for i in range(label_heads))

    def cutmix(p1, p2):
        img_1, label_1 = p1
        img_2, label_2 = p2

        # Dynamic shape, so datasets with unknown image sizes also work.
        shape = tf.shape(img_1)
        height, width = shape[0], shape[1]
        height_f = tf.cast(height, tf.float32)
        width_f = tf.cast(width, tf.float32)

        # Draw the target area ratio, then a box centre uniformly over the image.
        # Box sides are sqrt(1 - ratio) times the image sides; the box may be
        # clipped at the border, so the real ratio is recomputed below.
        side_scale = tf.math.sqrt(1.0 - tf.random.uniform(()))
        center_r = tf.random.uniform((), maxval=height_f, dtype=tf.float32)
        center_c = tf.random.uniform((), maxval=width_f, dtype=tf.float32)
        half_h = side_scale * height_f / 2
        half_w = side_scale * width_f / 2

        # Half-open box [r1, r2) x [c1, c2), clipped to the image.
        r1 = tf.clip_by_value(tf.cast(center_r - half_h, tf.int32), 0, height)
        r2 = tf.clip_by_value(tf.cast(center_r + half_h, tf.int32), 0, height)
        c1 = tf.clip_by_value(tf.cast(center_c - half_w, tf.int32), 0, width)
        c2 = tf.clip_by_value(tf.cast(center_c + half_w, tf.int32), 0, width)

        # Fraction of pixels actually taken from img_2 (accounts for clipping).
        lam = tf.cast((r2 - r1) * (c2 - c1), tf.float32) / (height_f * width_f)

        # Pixels inside the box come from img_2, everything else from img_1.
        rows = tf.range(height)[:, tf.newaxis]
        cols = tf.range(width)[tf.newaxis, :]
        in_box = (rows >= r1) & (rows < r2) & (cols >= c1) & (cols < c2)
        result = tf.where(in_box[:, :, tf.newaxis], img_2, img_1)
        # The output has exactly img_1's shape; keep that static info for
        # downstream consumers.
        result.set_shape(img_1.shape)

        return result, _mix_labels(label_1, label_2, lam)

    ds_shuffled = ds.shuffle(shuffling)
    return tf.data.Dataset.zip((ds, ds_shuffled)).map(cutmix)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment