TensorFlow dataset class
# Code adapted from TensorFlow source example:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py
import numpy as np


class DataSet:
    """Base data set class."""

    def __init__(self, shuffle=True, labeled=True, **data_dict):
        assert '_data' in data_dict
        if labeled:
            assert '_labels' in data_dict
            assert data_dict['_data'].shape[0] == data_dict['_labels'].shape[0]
        self._labeled = labeled
        self._shuffle = shuffle
        self.__dict__.update(data_dict)
        self._num_samples = self._data.shape[0]
        self._index_in_epoch = 0
        if self._shuffle:
            self._shuffle_data()

    def __len__(self):
        return len(self._data)

    @property
    def index_in_epoch(self):
        return self._index_in_epoch

    @property
    def num_samples(self):
        return self._num_samples

    @property
    def data(self):
        return self._data

    @property
    def labels(self):
        return self._labels

    @property
    def labeled(self):
        return self._labeled

    @property
    def test_data(self):
        return self._test_data

    @property
    def test_labels(self):
        return self._test_labels

    @classmethod
    def load(cls, filename):
        data_dict = np.load(filename)
        return cls(**data_dict)

    def save(self, filename):
        # Saves every attribute (arrays and scalars alike) into one .npz archive.
        data_dict = self.__dict__
        np.savez_compressed(filename, **data_dict)

    def _shuffle_data(self):
        shuffled_idx = np.arange(self._num_samples)
        np.random.shuffle(shuffled_idx)
        self._data = self._data[shuffled_idx]
        if self._labeled:
            self._labels = self._labels[shuffled_idx]

    def next_batch(self, batch_size):
        assert batch_size <= self._num_samples
        start = self._index_in_epoch
        if start + batch_size > self._num_samples:
            # Epoch boundary: take what is left of this epoch, reshuffle,
            # and top the batch up from the start of the next epoch.
            data_batch = self._data[start:]
            if self._labeled:
                labels_batch = self._labels[start:]
            remaining = batch_size - (self._num_samples - start)
            if self._shuffle:
                self._shuffle_data()
            start = 0
            data_batch = np.concatenate([data_batch, self._data[:remaining]],
                                        axis=0)
            if self._labeled:
                labels_batch = np.concatenate([labels_batch,
                                               self._labels[:remaining]],
                                              axis=0)
            self._index_in_epoch = remaining
        else:
            data_batch = self._data[start:start + batch_size]
            if self._labeled:
                labels_batch = self._labels[start:start + batch_size]
            self._index_in_epoch = start + batch_size
        batch = (data_batch, labels_batch) if self._labeled else data_batch
        return batch
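
Before the iris example below, a minimal sketch (not part of the original gist) of how next_batch behaves across an epoch boundary; shuffle=False keeps the row order deterministic so the indices are easy to follow:

# Illustrative only: epoch wrap-around behaviour of next_batch.
import numpy as np

toy = DataSet(shuffle=False,
              _data=np.arange(10).reshape(5, 2),
              _labels=np.arange(5))
xb, yb = toy.next_batch(3)  # rows 0-2
xb, yb = toy.next_batch(3)  # rows 3-4, topped up with row 0 of the next pass
assert toy.index_in_epoch == 1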
from sklearn import datasets
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X = iris.data
y = iris.target
# train_test_split returns the train and test split of each array, in that order.
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.9)
data_dict = {
    '_data': train_X,
    '_labels': train_y,
    '_test_data': test_X,
    '_test_labels': test_y
}
iris_data = DataSet(**data_dict)
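
A follow-on sketch (also not in the original gist) showing how iris_data could be consumed: draw a few mini-batches, then round-trip the object through save/load. The 'iris_data' filename is only an example; np.savez_compressed appends the '.npz' extension, so load uses the full name:

# Each next_batch call returns (data, labels) because labeled=True by default.
for _ in range(3):
    batch_X, batch_y = iris_data.next_batch(32)
    print(batch_X.shape, batch_y.shape)  # (32, 4) (32,)

# Persist to disk and restore.
iris_data.save('iris_data')
restored = DataSet.load('iris_data.npz')
print(len(restored), restored.num_samples)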