Skip to content

Instantly share code, notes, and snippets.



Last active Aug 8, 2017
What would you like to do?
TensorFlow DataSet class
# Code adapted from TensorFlow source example:
class DataSet:
    """Base data set class.

    Holds an in-memory data array (and, optionally, a parallel labels
    array) and serves mini-batches via :meth:`next_batch`, reshuffling
    at each epoch boundary when ``shuffle`` is enabled.

    Keyword arguments in ``data_dict`` are bound directly onto the
    instance, so keys are expected to use the private attribute names:
    ``_data`` (required), ``_labels`` (required when ``labeled``), and
    optionally ``_test_data`` / ``_test_labels``.
    """

    def __init__(self, shuffle=True, labeled=True, **data_dict):
        """Build a data set from pre-split numpy arrays.

        Args:
            shuffle: if True, shuffle on construction and at every
                epoch wrap-around in ``next_batch``.
            labeled: if True, ``_labels`` must be supplied and batches
                are ``(data, labels)`` tuples.
            **data_dict: arrays keyed by private attribute name
                (``_data``, ``_labels``, ``_test_data``, ``_test_labels``).
        """
        assert '_data' in data_dict
        if labeled:
            assert '_labels' in data_dict
            # Data and labels must pair up row-for-row.
            assert data_dict['_data'].shape[0] == data_dict['_labels'].shape[0]
        self._labeled = labeled
        self._shuffle = shuffle
        # Bind _data / _labels / _test_data / _test_labels onto self.
        self.__dict__.update(data_dict)
        self._num_samples = self._data.shape[0]
        self._index_in_epoch = 0
        if self._shuffle:
            self._shuffle_data()

    def __len__(self):
        return len(self._data)

    @property
    def index_in_epoch(self):
        return self._index_in_epoch

    @property
    def num_samples(self):
        return self._num_samples

    @property
    def data(self):
        return self._data

    @property
    def labels(self):
        return self._labels

    @property
    def labeled(self):
        return self._labeled

    @property
    def test_data(self):
        return self._test_data

    @property
    def test_labels(self):
        return self._test_labels

    @classmethod
    def load(cls, filename):
        """Restore a DataSet from an .npz archive written by save()."""
        data_dict = np.load(filename)
        return cls(**data_dict)

    def save(self, filename):
        """Write every instance attribute to a compressed .npz archive."""
        data_dict = self.__dict__
        np.savez_compressed(filename, **data_dict)

    def _shuffle_data(self):
        """Apply one random permutation to data (and labels, in lockstep)."""
        # np.arange alone would be the identity permutation — it must be
        # randomly permuted, and the SAME index used for data and labels.
        shuffled_idx = np.random.permutation(self._num_samples)
        self._data = self._data[shuffled_idx]
        if self._labeled:
            self._labels = self._labels[shuffled_idx]

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` samples, wrapping at epoch end.

        Returns:
            ``(data_batch, labels_batch)`` when labeled, else ``data_batch``.
        """
        assert batch_size <= self._num_samples
        start = self._index_in_epoch
        if start + batch_size > self._num_samples:
            # Epoch boundary: take the tail of this epoch, then top the
            # batch up with the head of the next one.
            data_batch = self._data[start:]
            if self._labeled:
                labels_batch = self._labels[start:]
            remaining = batch_size - (self._num_samples - start)
            if self._shuffle:
                # Reshuffle so the next epoch visits samples in a new order.
                self._shuffle_data()
                start = 0
            data_batch = np.concatenate([data_batch, self._data[:remaining]],
                                        axis=0)
            if self._labeled:
                labels_batch = np.concatenate([labels_batch,
                                               self._labels[:remaining]],
                                              axis=0)
            self._index_in_epoch = remaining
        else:
            data_batch = self._data[start:start + batch_size]
            if self._labeled:
                labels_batch = self._labels[start:start + batch_size]
            self._index_in_epoch = start + batch_size
        batch = (data_batch, labels_batch) if self._labeled else data_batch
        return batch
from sklearn import datasets
from sklearn.model_selection import train_test_split

# Demo: wrap the iris data set in a DataSet instance.
iris = datasets.load_iris()
X = iris.data
y = iris.target
# NOTE: train_test_split returns (train_X, test_X, train_y, test_y) —
# the previous unpack order (train_X, train_y, test_X, test_y) silently
# swapped the labels with the test data.
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.9)
data_dict = {
    '_data': train_X,
    '_labels': train_y,
    '_test_data': test_X,
    '_test_labels': test_y,
}
iris_data = DataSet(**data_dict)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment