Skip to content

Instantly share code, notes, and snippets.

Last active November 18, 2019 09:55
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
What would you like to do?
class DataLoader(object):
    """Load the TFDS ``cats_vs_dogs`` dataset and expose batched pipelines.

    After construction the instance provides:

    * ``train_data_raw`` / ``validation_data_raw`` / ``test_data_raw`` --
      the three datasets returned by ``tfds.load``, untouched.
    * ``train_data`` / ``validation_data`` / ``test_data`` -- the raw
      datasets mapped through :meth:`_resize_sample`.
    * ``train_batches`` / ``validation_batches`` / ``test_batches`` -- the
      processed datasets batched by ``batch_size`` (train is shuffled first).
    * ``metadata`` -- the info object returned by ``tfds.load``.
    * ``num_train_examples`` -- 80% of the size of the ``'train'`` split.
    * ``get_label_name`` -- ``int2str`` of the ``'label'`` feature.

    Relies on ``tf`` and ``tfds`` being imported at module level.
    """

    def __init__(self, image_size, batch_size):
        """Download/load the dataset and build the input pipelines.

        Args:
            image_size: Edge length, in pixels, that every image is resized
                to (images become ``image_size`` x ``image_size``).
            batch_size: Number of samples per batch in the ``*_batches``
                datasets.
        """
        self.image_size = image_size
        self.batch_size = batch_size

        # 80% train data, 10% validation data, 10% test data
        split_weights = (8, 1, 1)
        # NOTE(review): ``subsplit`` is a legacy TFDS API; newer TFDS
        # releases express this with slicing strings such as 'train[:80%]'.
        # Confirm against the installed tensorflow_datasets version.
        splits = tfds.Split.TRAIN.subsplit(weighted=split_weights)

        (
            (self.train_data_raw, self.validation_data_raw, self.test_data_raw),
            self.metadata,
        ) = tfds.load(
            'cats_vs_dogs',
            split=list(splits),
            with_info=True,
            as_supervised=True,
        )

        # Get the number of train examples: 80% of the full 'train' split.
        # True division is kept from the original, so this is a float.
        self.num_train_examples = (
            self.metadata.splits['train'].num_examples * 80 / 100
        )
        self.get_label_name = self.metadata.features['label'].int2str

        # Pre-process data. Without these calls the processed and batched
        # attributes would never be created, since both helpers are private.
        self._prepare_data()
        self._prepare_batches()

    def _prepare_data(self):
        """Resize all images to ``image_size`` x ``image_size``.

        Maps :meth:`_resize_sample` over each raw split and stores the
        results in ``train_data``, ``validation_data`` and ``test_data``.
        """
        self.train_data = self.train_data_raw.map(self._resize_sample)
        self.validation_data = self.validation_data_raw.map(self._resize_sample)
        self.test_data = self.test_data_raw.map(self._resize_sample)

    def _resize_sample(self, image, label):
        """Resize one image to ``image_size`` x ``image_size``.

        The image is cast to ``float32``, rescaled with ``x / 127.5 - 1``
        (which maps 0..255 onto -1..1 -- assumes 8-bit pixel input, TODO
        confirm) and then resized. The label is passed through unchanged.

        Args:
            image: Image tensor of one sample.
            label: Label of the same sample.

        Returns:
            An ``(image, label)`` tuple with the processed image.
        """
        image = tf.cast(image, tf.float32)
        image = (image / 127.5) - 1
        image = tf.image.resize(image, (self.image_size, self.image_size))
        return image, label

    def _prepare_batches(self):
        """Batch the processed splits by ``batch_size``.

        Only the training split is shuffled (buffer of 1000) before
        batching; validation and test keep their order.
        """
        self.train_batches = self.train_data.shuffle(1000).batch(self.batch_size)
        self.validation_batches = self.validation_data.batch(self.batch_size)
        self.test_batches = self.test_data.batch(self.batch_size)

    def get_random_raw_images(self, num_of_images):
        """Get a defined number of not processed images.

        Args:
            num_of_images: How many samples to take.

        Returns:
            A dataset holding ``num_of_images`` samples taken from the raw
            training split after shuffling it with a buffer of 1000.
        """
        random_train_raw_data = self.train_data_raw.shuffle(1000)
        return random_train_raw_data.take(num_of_images)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment