@sayakpaul
Last active April 30, 2023 07:03
Example of incorporating RandAugment in a tf.data pipeline for image classification.
import imgaug as ia
from imgaug import augmenters as iaa
import numpy as np
import tensorflow as tf

# Fix the seeds for reproducibility.
ia.seed(4)
tf.random.set_seed(666)

AUTO = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 224

# RandAugment: apply n=2 randomly chosen operations, each with magnitude m=9.
aug = iaa.RandAugment(n=2, m=9)

def augment(images):
    # imgaug works on NumPy arrays, so convert the incoming TensorFlow
    # tensor first. Cast the augmented batch to float32 so it matches
    # the `Tout` declared in the `tf.py_function` call below.
    return aug(images=images.numpy()).astype(np.float32)

# Function to read the TFRecords and segregate the images and labels.
def read_tfrecord(example):
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "class": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(example, features)
    image = tf.image.decode_jpeg(example["image"], channels=3)
    class_label = tf.cast(example["class"], tf.int32)
    return (image, class_label)

# Load the TFRecords and create a tf.data.Dataset.
def load_dataset(filenames):
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTO)
    return dataset

# Shuffle, batch, and augment the dataset, and pre-fetch it
# well before the current epoch ends.
def batch_dataset(filenames, batch_size=BATCH_SIZE, train=True):
    opt = tf.data.Options()
    opt.experimental_deterministic = False
    dataset = load_dataset(filenames)
    if train:
        dataset = dataset.repeat()
        dataset = dataset.shuffle(BATCH_SIZE * 10)
        dataset = dataset.batch(batch_size)
        # `augment()` runs arbitrary Python, so it has to be wrapped in
        # `tf.py_function`. Augmenting after batching lets imgaug process
        # a whole batch of images in one call.
        dataset = dataset.map(
            lambda x, y: (tf.py_function(augment, [x], [tf.float32]), y),
            num_parallel_calls=AUTO,
        )
        # `tf.py_function` returns its outputs as a list, which adds a
        # leading dimension of size one; squeeze it back out.
        dataset = dataset.map(
            lambda x, y: (tf.squeeze(x), y),
            num_parallel_calls=AUTO,
        )
    else:
        dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(AUTO)
    dataset = dataset.with_options(opt)
    return dataset

train_pattern = "train_tfr_224/*.tfrec"
train_filenames = tf.io.gfile.glob(train_pattern)
val_pattern = "val_tfr_224/*.tfrec"
val_filenames = tf.io.gfile.glob(val_pattern)

training_ds = batch_dataset(train_filenames)
validation_ds = batch_dataset(val_filenames, train=False)
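
To exercise these pipelines end to end, a minimal training sketch could look like the following. Everything below is illustrative and not part of the original gist: the toy model, NUM_CLASSES, and NUM_TRAIN_IMAGES are assumptions to replace with your own; the 224x224 image size is inferred from the "*_tfr_224" TFRecord naming. Since the training dataset repeats indefinitely, `steps_per_epoch` must be passed to `fit()` explicitly.

IMG_SIZE = 224        # assumed from the "*_tfr_224" naming; adjust if different
NUM_CLASSES = 5       # hypothetical; set to your dataset's number of classes

# `tf.py_function` erases static shape information, so reattach it
# before handing the training dataset to Keras.
train_for_fit = training_ds.map(
    lambda x, y: (tf.ensure_shape(x, [None, IMG_SIZE, IMG_SIZE, 3]), y)
)

# A deliberately tiny model, just to show the datasets plugging into fit().
model = tf.keras.Sequential([
    tf.keras.layers.Rescaling(1.0 / 255, input_shape=(IMG_SIZE, IMG_SIZE, 3)),
    tf.keras.layers.Conv2D(32, 3, activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

NUM_TRAIN_IMAGES = 10000  # hypothetical dataset size
model.fit(
    train_for_fit,
    validation_data=validation_ds,
    steps_per_epoch=NUM_TRAIN_IMAGES // BATCH_SIZE,
    epochs=5,
)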
@sayakpaul (Author)
Thanks to @DarshanDeshpande for catching an off-by-one bug.
