Skip to content

Instantly share code, notes, and snippets.

@therne
Created July 9, 2016 13:31
Show Gist options
  • Save therne/3b6f1db728b78d6125647229884a574b to your computer and use it in GitHub Desktop.
Save therne/3b6f1db728b78d6125647229884a574b to your computer and use it in GitHub Desktop.
Batch data loader for minibatch training
import copy
import numpy as np
class DataSet:
    """Batch data loader for minibatch training.

    The underlying ``data`` is never copied or reordered. Iteration order is
    driven by ``self.indices``, a list of positions into ``data`` that is
    reshuffled on every :meth:`reset` when ``shuffle`` is true. Samples that
    do not fill a complete batch at the end of an epoch are skipped.
    """

    def __init__(self, data, batch_size=1, shuffle=True, name="dataset"):
        """ Create a loader over ``data``.
        :param data: indexable collection supporting ``len()`` and ``data[i]``
        :param batch_size: number of samples returned by each ``next_batch()``
        :param shuffle: reshuffle the sample order on every ``reset()``
        :param name: free-form label for this set
        """
        assert batch_size >= 1, "batch size must be a positive integer."
        assert batch_size <= len(data), "batch size cannot be greater than data size."
        self.name = name
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.count = len(self.data)
        self.setup()

    def setup(self):
        """ (Re)build the index list for the first ``self.count`` samples
        and start a fresh epoch.
        """
        self.indices = list(range(self.count))  # used in shuffling
        self.current_index = 0
        # Floor division: avoids the float round-trip of int(a / b).
        self.num_batches = self.count // self.batch_size
        self.reset()

    def next_batch(self):
        """ Get next batch data.
        :return: list of ``batch_size`` samples
        :raises AssertionError: if no complete batch is left in this epoch
        """
        assert self.has_next_batch(), "End of epoch. Call 'reset()' to reset."
        from_, to = self.current_index, self.current_index + self.batch_size
        cur_idxs = self.indices[from_:to]
        batch = [self.data[i] for i in cur_idxs]
        self.current_index = to
        return batch

    def has_next_batch(self):
        """ Tell whether a complete batch is still available in this epoch.
        :return: True if ``next_batch()`` can be called
        """
        return self.current_index + self.batch_size <= self.count

    def split_dataset(self, split_ratio):
        """ Splits a data set by split_ratio.
        (ex: split_ratio = 0.3 -> this set (70%) and splitted (30%))

        The split is positional: this set keeps the leading part of its
        samples and the returned set receives the trailing part. Both sets
        start a fresh epoch afterwards.

        :param split_ratio: fraction of this set to move into the returned set
        :return: val_set
        :raises AssertionError: if split_ratio is outside [0, 1] or the
            splitted part would be smaller than one batch
        """
        assert 0 <= split_ratio <= 1, "split ratio must be between 0 and 1."
        end_index = int(self.count * (1 - split_ratio))
        splitted_count = self.count - end_index
        assert splitted_count >= self.batch_size, "splitted data size cannot be smaller than batch size."

        # Partition the indices this set actually owns, in data order, so the
        # result does not depend on the current shuffle and a set that was
        # itself produced by a split can be split again correctly.
        ordered = sorted(self.indices)

        # do not (deep) copy data - just modify index list!
        splitted = copy.copy(self)
        splitted.indices = ordered[end_index:]
        splitted.count = splitted_count
        splitted.num_batches = splitted.count // splitted.batch_size
        splitted.reset()  # the copy inherited our current_index; start fresh

        self.indices = ordered[:end_index]
        self.count = end_index
        self.num_batches = self.count // self.batch_size
        self.reset()
        return splitted

    def reset(self):
        """ Start a new epoch: rewind to the first batch and, if shuffling
        is enabled, reshuffle the sample order in place.
        """
        self.current_index = 0
        if self.shuffle:
            np.random.shuffle(self.indices)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment