Skip to content

Instantly share code, notes, and snippets.

@sdcubber
Last active August 9, 2018 19:18
Show Gist options
  • Save sdcubber/245f991b1e103fd6485f6f6da3d6e391 to your computer and use it in GitHub Desktop.
Save sdcubber/245f991b1e103fd6485f6f6da3d6e391 to your computer and use it in GitHub Desktop.
import ast
import math
import os
import random

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.image import img_to_array as img_to_array
from tensorflow.keras.preprocessing.image import load_img as load_img


def load_image(image_path, size):
    """Read one image from disk as a pixel array resized to ``size`` x ``size``.

    image_path: location of the image file
    size: side length (in pixels) of the square the image is resized to

    Returns the output of ``img_to_array`` divided by 255, so 8-bit pixel
    values end up in the range [0, 1].
    """
    # Hook for data augmentation (e.g. random rotations): transform
    # `resized` or `pixels` here before normalising.
    resized = load_img(image_path, target_size=(size, size))
    pixels = img_to_array(resized)
    return pixels / 255.


class KagglePlanetSequence(tf.keras.utils.Sequence):
    """
    Custom Sequence object to train a model on out-of-memory datasets.

    Each item is one batch ``(images, [weather_labels, ground_labels])``.
    Only image paths and labels are kept in memory; pixels are loaded from
    disk with ``load_image`` one batch at a time.
    """

    def __init__(self, df_path, data_path, im_size, batch_size, mode='train'):
        """
        df_path: path to a .csv file with columns 'image_name',
            'weather_labels' and 'ground_labels'. The label columns hold
            Python literals parsed with ast.literal_eval (presumably one
            label list per image -- confirm against the CSV).
        data_path: directory that contains the '<image_name>.jpg' files
        im_size: side length (pixels) each image is resized to
        batch_size: number of samples per batch; the last batch may be smaller
        mode: when 'train', the sample order is shuffled at construction and
            again after every epoch; any other value keeps the CSV row order
        """
        # Harmless for the classic Sequence base; newer Keras versions
        # (PyDataset) expect subclasses to call the base constructor.
        super().__init__()
        self.df = pd.read_csv(df_path)
        self.im_size = im_size
        self.batch_size = batch_size
        self.mode = mode
        # Keep labels and image locations (not pixel data) in memory.
        self.wlabels = self.df['weather_labels'].apply(ast.literal_eval).tolist()
        self.glabels = self.df['ground_labels'].apply(ast.literal_eval).tolist()
        self.image_list = self.df['image_name'].apply(
            lambda name: os.path.join(data_path, name + '.jpg')).tolist()
        # Build the initial sample order now; otherwise self.indexes would not
        # exist (and epoch 1 would not be shuffled) until Keras first calls
        # on_epoch_end() at the end of an epoch.
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (a partial final batch counts as one)."""
        return int(math.ceil(len(self.df) / float(self.batch_size)))

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when in training mode."""
        self.indexes = list(range(len(self.image_list)))
        if self.mode == 'train':
            random.shuffle(self.indexes)

    def _batch_indexes(self, idx):
        """Return the sample positions (into the per-sample lists) of batch ``idx``."""
        return self.indexes[idx * self.batch_size: (idx + 1) * self.batch_size]

    def get_batch_labels(self, idx):
        """Fetch the labels of batch ``idx`` as [weather_labels, ground_labels].

        Uses the current (possibly shuffled) order in ``self.indexes`` so the
        labels stay aligned with ``get_batch_features(idx)``.
        """
        positions = self._batch_indexes(idx)
        return [[self.wlabels[i] for i in positions],
                [self.glabels[i] for i in positions]]

    def get_batch_features(self, idx):
        """Load the images of batch ``idx`` and stack them into one numpy array."""
        positions = self._batch_indexes(idx)
        return np.array([load_image(self.image_list[i], self.im_size)
                         for i in positions])

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(images, [weather_labels, ground_labels])``."""
        batch_x = self.get_batch_features(idx)
        batch_y = self.get_batch_labels(idx)
        return batch_x, batch_y
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment