Skip to content

Instantly share code, notes, and snippets.

@doleron
Created April 8, 2023 21:35
Show Gist options
Save doleron/2f399862a37ce34f336f3d97744c7cc6 to your computer and use it in GitHub Desktop.
def build_training_dataset(*, shuffle_buffer_size=100, batch_size=None):
    """Build the infinite, shuffled, batched and augmented training pipeline.

    Zips the module-level ``training_pairs`` with ``training_pairs_labels``,
    shuffles them, turns each pair into loaded data via ``load_images``,
    repeats forever, batches, applies ``data_augmentation`` separately to
    both members of every pair, and prefetches.

    Args:
        shuffle_buffer_size: Number of elements held in the shuffle buffer.
            Defaults to 100, the value that was previously hard-coded.
        batch_size: Number of pairs per batch. ``None`` (the default) uses
            the module-level ``TRAINING_BATCH_SIZE``.

    Returns:
        A ``tf.data.Dataset`` yielding
        ``((augmented pair[0] batch, augmented pair[1] batch), labels)``.
        Because of ``repeat()`` the dataset never ends, so callers must
        bound iteration themselves (e.g. with ``steps_per_epoch``).
    """
    if batch_size is None:
        batch_size = TRAINING_BATCH_SIZE

    pairs_tensor = tf.convert_to_tensor(training_pairs)
    labels_tensor = tf.convert_to_tensor(training_pairs_labels)
    result = tf.data.Dataset.from_tensor_slices((pairs_tensor, labels_tensor))

    # Shuffle before load_images. The buffer then holds the small raw pair
    # tensors (presumably image paths -- TODO confirm) instead of decoded
    # images. Because load_images is applied element by element, shuffling
    # before or after it yields the same order distribution.
    result = result.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)

    # load_images is the expensive step, so run it in parallel. tf.data keeps
    # the output order deterministic by default.
    result = result.map(
        lambda pair, label: (load_images(pair), label),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    # shuffle() precedes repeat(), so every pass over the data is reshuffled
    # and passes do not bleed into each other.
    result = result.repeat()
    result = result.batch(batch_size)

    # Augment each side of the pair independently. training=True is passed
    # explicitly (presumably data_augmentation is a Keras model/layer whose
    # random transforms are active only in training mode -- confirm).
    result = result.map(
        lambda pair, y: (
            (
                data_augmentation(pair[0], training=True),
                data_augmentation(pair[1], training=True),
            ),
            y,
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    result = result.prefetch(tf.data.AUTOTUNE)
    return result


train_ds = build_training_dataset()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment