Skip to content

Instantly share code, notes, and snippets.

@agvdndor
Last active March 27, 2022 14:20
Show Gist options
  • Save agvdndor/eb92a5fd8874b5bd587e56cb9ac8ec84 to your computer and use it in GitHub Desktop.
Save agvdndor/eb92a5fd8874b5bd587e56cb9ac8ec84 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import tensorflow_datasets as tfds
from aug.autoaugment import distort_image_with_autoaugment
def aug_fn(sample: dict, policy: str = "v0") -> dict:
    """Apply an AutoAugment policy to a single dataset sample.

    Intended for use with ``tf.data.Dataset.map``. Make sure to check the
    format of your dataset so the distort function receives the appropriate
    arguments: bboxes should have format ``[y_min, x_min, y_max, x_max]``
    with values normalized to ``[0, 1]``.

    Args:
        sample: Feature dict with an ``"image"`` entry and an ``"objects"``
            sub-dict holding ``"bbox"`` (the layout this function reads;
            presumably as produced by ``tfds.load("voc/2007")`` — TODO
            confirm for other datasets).
        policy: AutoAugment policy name, forwarded to
            ``distort_image_with_autoaugment`` as ``augmentation_name``.

    Returns:
        A new feature dict equal to ``sample`` except that ``"image"`` and
        ``["objects"]["bbox"]`` are replaced by their augmented versions.
        All other entries are passed through unchanged.
    """
    aug_img, aug_bboxes = distort_image_with_autoaugment(
        image=sample["image"],
        bboxes=sample["objects"]["bbox"],
        augmentation_name=policy,
    )
    # Build fresh dicts instead of mutating the caller's sample in place.
    objects = dict(sample["objects"])
    objects["bbox"] = aug_bboxes
    # Dataset.map uses the returned value as the new element, so the
    # augmented sample must be returned (not just assigned).
    return {**sample, "image": aug_img, "objects": objects}
# Fetch the training split of PASCAL VOC 2007 from TensorFlow Datasets.
train_ds = tfds.load("voc/2007", split="train")

# Lazily augment every training sample with aug_fn's default policy.
aug_train_ds = train_ds.map(map_func=aug_fn)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment