Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
def __init__(self, image_size, batch_size):
    """Load cats_vs_dogs, split it into train/validation/test, and prepare it.

    Args:
        image_size: Stored on ``self.image_size``; presumably consumed by
            ``_prepare_data`` (defined elsewhere) — TODO confirm.
        batch_size: Stored on ``self.batch_size``; presumably consumed by
            ``_prepare_batches`` (defined elsewhere) — TODO confirm.

    Sets ``train_data_raw``, ``validation_data_raw``, ``test_data_raw``,
    ``metadata``, ``num_train_examples`` (an int) and ``get_label_name``,
    then calls ``_prepare_data()`` and ``_prepare_batches()``.
    """
    self.image_size = image_size
    self.batch_size = batch_size

    # 80% train data, 10% validation data, 10% test data.
    split_weights = (8, 1, 1)
    total_weight = sum(split_weights)

    # The legacy `tfds.Split.TRAIN.subsplit(weighted=...)` API no longer
    # exists in tensorflow_datasets; percent slices of the 'train' split
    # express the same partition. Integer percent boundaries are computed
    # cumulatively so the slices are contiguous and cover 0%..100%.
    split_specs = []
    percent_bounds = []
    cumulative = 0
    for weight in split_weights:
        start_pct = 100 * cumulative // total_weight
        cumulative += weight
        end_pct = 100 * cumulative // total_weight
        percent_bounds.append((start_pct, end_pct))
        split_specs.append('train[{}%:{}%]'.format(start_pct, end_pct))

    (self.train_data_raw, self.validation_data_raw, self.test_data_raw), self.metadata = tfds.load(
        'cats_vs_dogs', split=split_specs,
        with_info=True, as_supervised=True)

    # Number of examples in the train slice. TFDS resolves percent
    # boundaries with "closest" rounding by default, so this uses round()
    # per boundary. The result is an int, unlike the previous `*80/100`,
    # which produced a float.
    total_examples = self.metadata.splits['train'].num_examples
    train_start_pct, train_end_pct = percent_bounds[0]
    self.num_train_examples = (
        int(round(train_end_pct * total_examples / 100.0))
        - int(round(train_start_pct * total_examples / 100.0))
    )

    # Maps integer class labels to their string names.
    self.get_label_name = self.metadata.features['label'].int2str

    # Pre-process data
    self._prepare_data()
    self._prepare_batches()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment