Skip to content

Instantly share code, notes, and snippets.

@sdcubber
Created September 15, 2018 12:58
Show Gist options
  • Save sdcubber/810591500efb3979cb3418bed6cc143c to your computer and use it in GitHub Desktop.
Save sdcubber/810591500efb3979cb3418bed6cc143c to your computer and use it in GitHub Desktop.
from keras.utils import Sequence
class DataSequence(Sequence):
    """
    Keras Sequence object to train a model on larger-than-memory data.

    Batches are produced lazily: each ``__getitem__`` call reads only the
    images belonging to that batch from disk, so the whole dataset never has
    to be loaded into memory at once.
    """

    def __init__(self, df, batch_size, mode='train'):
        # Setup is elided in this snippet. The methods below read
        # ``self.im_list`` (items passed to ``imread``) and ``self.bsz``
        # (batch size), so this constructor presumably sets them from ``df``
        # and ``batch_size`` -- TODO confirm against the full implementation.
        ...

    def __len__(self):
        """Return the number of batches in one epoch.

        Keras requires every ``Sequence`` to implement ``__len__`` (it is
        used to size an epoch). Ceiling division is used so the final,
        possibly smaller, batch produced by ``get_batch_images`` is counted
        rather than dropped.
        """
        return (len(self.im_list) + self.bsz - 1) // self.bsz

    def get_batch_images(self, idx):
        """Read the images of batch ``idx`` from disk and stack them.

        Args:
            idx: Zero-based batch index; selects items
                ``[idx * bsz, (idx + 1) * bsz)`` of ``self.im_list``.

        Returns:
            ``np.array`` built from the loaded images. The last batch may
            hold fewer than ``bsz`` items.
        """
        start = idx * self.bsz
        stop = (1 + idx) * self.bsz
        # NOTE(review): stacking assumes every image in the batch has the
        # same shape; otherwise NumPy cannot build a regular array -- confirm.
        return np.array([imread(im) for im in self.im_list[start:stop]])

    def __getitem__(self, idx):
        """Return the ``(batch_x, batch_y)`` pair for batch ``idx``.

        ``get_batch_labels`` is not defined in this class; it must be
        provided elsewhere (e.g. by a subclass) for this to work.
        """
        batch_x = self.get_batch_images(idx)
        batch_y = self.get_batch_labels(idx)
        return batch_x, batch_y
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment