Skip to content

Instantly share code, notes, and snippets.

@travishsu
Created February 14, 2017 16:00
Show Gist options
  • Save travishsu/ba3f95c149aa04fcd6e9cb9df9b169a2 to your computer and use it in GitHub Desktop.
Save travishsu/ba3f95c149aa04fcd6e9cb9df9b169a2 to your computer and use it in GitHub Desktop.
修改自 TensorFlow 提供的 MNIST DataSet 類別，可以放入自己的數據集，並用 next_batch 逐批提取數據集的子集。
import numpy
class DatasetNoLabel(object):
    """In-memory unlabeled dataset that is served in mini-batches.

    ``data`` is expected to be a numpy array: the code relies on ``.shape``
    and on integer-array (fancy) indexing along axis 0. Its dtype is left
    unchanged.
    """

    def __init__(self, data):
        self._data = data
        self._num_examples = data.shape[0]
        self._epochs_completed = 0
        # Position of the next unread example within the current epoch.
        self._index_in_epoch = 0

    @property
    def data(self):
        """The stored array, in its current (possibly shuffled) order."""
        return self._data

    @property
    def num_examples(self):
        """Number of examples (length of ``data`` along axis 0)."""
        return self._num_examples

    @property
    def epochs_completed(self):
        """How many times ``next_batch`` has wrapped around the data."""
        return self._epochs_completed

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` examples.

        Batches are taken sequentially from the stored array. When fewer
        than ``batch_size`` examples remain in the current epoch, those
        leftovers are skipped and ``epochs_completed`` is incremented. The
        stored array is then replaced by a randomly permuted copy, and the
        batch is taken from the start of the new order. The first epoch is
        served in the original order.

        Raises:
            ValueError: if ``batch_size`` is negative or larger than
                ``num_examples``. The check runs before any state changes,
                so a rejected call leaves the dataset untouched (the old
                ``assert`` ran after shuffling and vanished under ``-O``).
        """
        if not 0 <= batch_size <= self._num_examples:
            raise ValueError(
                'batch_size must be between 0 and %d, got %r'
                % (self._num_examples, batch_size))
        start = self._index_in_epoch
        end = start + batch_size
        if end > self._num_examples:
            # Finished epoch: reshuffle and restart from the beginning.
            self._epochs_completed += 1
            perm = numpy.random.permutation(self._num_examples)
            self._data = self._data[perm]
            start = 0
            end = batch_size
        self._index_in_epoch = end
        return self._data[start:end]
class Dataset(object):
    """In-memory (features, labels) dataset that is served in mini-batches.

    ``features`` and ``labels`` are expected to be numpy arrays with the
    same length along axis 0: the code relies on ``.shape``, ``.astype`` and
    integer-array (fancy) indexing. ``features`` is stored as a ``float32``
    copy; ``labels`` is stored as given.
    """

    def __init__(self, features, labels):
        # Explicit check instead of ``assert`` so it also runs under -O.
        if features.shape[0] != labels.shape[0]:
            raise ValueError(
                'features.shape: %s labels.shape: %s'
                % (features.shape, labels.shape))
        self._num_examples = features.shape[0]
        self._features = features.astype(numpy.float32)
        self._labels = labels
        self._epochs_completed = 0
        # Position of the next unread example within the current epoch.
        self._index_in_epoch = 0

    @property
    def features(self):
        """The float32 feature array, in its current (possibly shuffled) order."""
        return self._features

    @property
    def labels(self):
        """The label array, kept aligned with ``features``."""
        return self._labels

    @property
    def num_examples(self):
        """Number of examples (length of ``features`` along axis 0)."""
        return self._num_examples

    @property
    def epochs_completed(self):
        """How many times ``next_batch`` has wrapped around the data."""
        return self._epochs_completed

    def next_batch(self, batch_size):
        """Return the next ``(features, labels)`` batch of ``batch_size`` rows.

        Batches are taken sequentially. When fewer than ``batch_size``
        examples remain in the current epoch, those leftovers are skipped and
        ``epochs_completed`` is incremented. Features and labels are then
        reordered with one shared random permutation, so pairs stay aligned,
        and the batch is taken from the start of the new order. The first
        epoch is served in the original order.

        Raises:
            ValueError: if ``batch_size`` is negative or larger than
                ``num_examples``. The check runs before any state changes,
                so a rejected call leaves the dataset untouched (the old
                ``assert`` ran after shuffling and vanished under ``-O``).
        """
        if not 0 <= batch_size <= self._num_examples:
            raise ValueError(
                'batch_size must be between 0 and %d, got %r'
                % (self._num_examples, batch_size))
        start = self._index_in_epoch
        end = start + batch_size
        if end > self._num_examples:
            # Finished epoch: reshuffle features and labels together.
            self._epochs_completed += 1
            perm = numpy.random.permutation(self._num_examples)
            self._features = self._features[perm]
            self._labels = self._labels[perm]
            start = 0
            end = batch_size
        self._index_in_epoch = end
        return self._features[start:end], self._labels[start:end]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment