Created
January 22, 2017 15:20
-
-
Save matt-graham/ee66a23419127b3fcb42465c9e6cc3bd to your computer and use it in GitHub Desktop.
CIFAR data providers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
try:
    import cPickle
except ImportError:
    # Python 3 merged cPickle into the standard pickle module; bind it under
    # the old name so the rest of the module can keep calling cPickle.load.
    import pickle as cPickle
import gzip
import numpy as np
import os

# Seed for the default random number generator created by DataProvider when
# no `rng` argument is supplied.
DEFAULT_SEED = 1234
class DataProvider(object):
    """Generic data provider.

    Splits paired `inputs` / `targets` arrays into fixed-size mini-batches
    and iterates over them one epoch at a time, optionally reshuffling the
    data at the start of every epoch. Supports both the Python 2 (`next`)
    and Python 3 (`__next__`) iterator protocols.
    """

    def __init__(self, inputs, targets, batch_size, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new data provider object.

        Args:
            inputs (ndarray): Array of data input features of shape
                (num_data, input_dim).
            targets (ndarray): Array of data output targets of shape
                (num_data, output_dim) or (num_data,) if output_dim == 1.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator. If None, a
                `np.random.RandomState` seeded with `DEFAULT_SEED` is created.

        Raises:
            ValueError: If `inputs` and `targets` hold a different number of
                data points, if `batch_size < 1`, or if `max_num_batches` is
                0 or less than -1.
        """
        # Inputs and targets are permuted and sliced together, so a length
        # mismatch would otherwise surface later as an IndexError in
        # `shuffle` or as silently dropped targets.
        if inputs.shape[0] != len(targets):
            raise ValueError(
                'inputs and targets must contain the same number of data '
                'points (got {0} and {1})'.format(
                    inputs.shape[0], len(targets)))
        self.inputs = inputs
        self.targets = targets
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1')
        self._batch_size = batch_size
        if max_num_batches == 0 or max_num_batches < -1:
            raise ValueError('max_num_batches must be -1 or > 0')
        self._max_num_batches = max_num_batches
        self._update_num_batches()
        self.shuffle_order = shuffle_order
        # Tracks the original index of each row so `reset` can undo shuffles.
        self._current_order = np.arange(inputs.shape[0])
        if rng is None:
            rng = np.random.RandomState(DEFAULT_SEED)
        self.rng = rng
        self.new_epoch()

    @property
    def batch_size(self):
        """Number of data points to include in each batch."""
        return self._batch_size

    @batch_size.setter
    def batch_size(self, value):
        if value < 1:
            raise ValueError('batch_size must be >= 1')
        self._batch_size = value
        self._update_num_batches()

    @property
    def max_num_batches(self):
        """Maximum number of batches to iterate over in an epoch."""
        return self._max_num_batches

    @max_num_batches.setter
    def max_num_batches(self, value):
        if value == 0 or value < -1:
            raise ValueError('max_num_batches must be -1 or > 0')
        self._max_num_batches = value
        self._update_num_batches()

    def _update_num_batches(self):
        """Updates number of batches to iterate over."""
        # maximum possible number of batches is equal to number of whole times
        # batch_size divides in to the number of data points which can be
        # found using integer division
        possible_num_batches = self.inputs.shape[0] // self.batch_size
        if self.max_num_batches == -1:
            self.num_batches = possible_num_batches
        else:
            self.num_batches = min(self.max_num_batches, possible_num_batches)

    def __iter__(self):
        """Implements Python iterator interface.

        Returns the DataProvider itself, which implements `next` (Python 2)
        and `__next__` (Python 3). Each step returns one batch, and
        `StopIteration` is raised at the end of the epoch.
        """
        return self

    def new_epoch(self):
        """Starts a new epoch (pass through data), possibly shuffling first."""
        self._curr_batch = 0
        if self.shuffle_order:
            self.shuffle()

    def reset(self):
        """Restores the original data order, then starts a new epoch.

        If `shuffle_order` is set, the new epoch reshuffles the data.
        """
        inv_perm = np.argsort(self._current_order)
        self._current_order = self._current_order[inv_perm]
        self.inputs = self.inputs[inv_perm]
        self.targets = self.targets[inv_perm]
        self.new_epoch()

    def shuffle(self):
        """Randomly shuffles order of data (inputs and targets together)."""
        perm = self.rng.permutation(self.inputs.shape[0])
        self._current_order = self._current_order[perm]
        self.inputs = self.inputs[perm]
        self.targets = self.targets[perm]

    def next(self):
        """Returns next data batch or raises `StopIteration` if at end.

        Batches are contiguous slices of the (possibly shuffled) data. Any
        trailing data points that do not fill a whole batch are skipped.
        When the epoch is exhausted, a new epoch is started before
        `StopIteration` is raised, so the provider can be iterated again.
        """
        if self._curr_batch + 1 > self.num_batches:
            # no more batches in current iteration through data set so start
            # new epoch ready for another pass and indicate iteration is at end
            self.new_epoch()
            raise StopIteration()
        # create an index slice corresponding to current batch number
        batch_slice = slice(self._curr_batch * self.batch_size,
                            (self._curr_batch + 1) * self.batch_size)
        inputs_batch = self.inputs[batch_slice]
        targets_batch = self.targets[batch_slice]
        self._curr_batch += 1
        return inputs_batch, targets_batch

    def __next__(self):
        """Python 3 iterator protocol hook.

        Looks up `self.next()` dynamically, so subclasses that override
        `next` (e.g. to re-encode targets) are honoured under Python 3 too.
        """
        return self.next()
class OneOfKDataProvider(DataProvider):
    """Data provider that yields 1-of-K (one-hot) classification targets.

    The integer class labels held by the parent provider are re-encoded as
    binary 1-of-K rows each time a batch is produced. Derived classes must
    set `self.num_classes` to the number of classes.
    """

    def next(self):
        """Return the next `(inputs, one_hot_targets)` batch.

        `StopIteration` raised by the parent `next` at the end of an epoch
        propagates unchanged.
        """
        batch_inputs, batch_labels = super(OneOfKDataProvider, self).next()
        one_hot_labels = self.to_one_of_k(batch_labels)
        return batch_inputs, one_hot_labels

    def to_one_of_k(self, int_targets):
        """Encode integer class labels as 1-of-K rows.

        Args:
            int_targets (ndarray): Integer class labels of shape (num_data,),
                each in the range 0 to `num_classes` - 1.

        Returns:
            Float array of shape (num_data, num_classes). Every row is all
            zeros except for a single one in the column of that row's class.
        """
        num_rows = int_targets.shape[0]
        encoded = np.zeros((num_rows, self.num_classes))
        # Fancy indexing: pair row i with column int_targets[i].
        row_index = np.arange(num_rows)
        encoded[row_index, int_targets] = 1
        return encoded
class MNISTDataProvider(OneOfKDataProvider):
    """Data provider for MNIST handwritten digit images."""

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new MNIST data provider object.

        Args:
            which_set: One of 'train', 'valid' or 'test'. Determines which
                portion of the MNIST data this object should provide.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.

        Raises:
            KeyError: If the MLP_DATA_DIR environment variable is not set.
            AssertionError: If `which_set` is invalid or the data file is
                missing (these checks are skipped under `python -O`).
        """
        import sys

        # check a valid which_set was provided
        assert which_set in ['train', 'valid', 'test'], (
            'Expected which_set to be either train, valid or test. '
            'Got {0}'.format(which_set)
        )
        self.which_set = which_set
        self.num_classes = 10
        # construct path to data using os.path.join to ensure the correct path
        # separator for the current platform / OS is used
        # MLP_DATA_DIR environment variable should point to the data directory
        data_path = os.path.join(
            os.environ['MLP_DATA_DIR'], 'mnist_{0}.pkl.gz'.format(which_set))
        assert os.path.isfile(data_path), (
            'Data file does not exist at expected path: ' + data_path
        )
        # The data file may have been pickled under Python 2. Python 3's
        # pickle needs encoding='latin1' to decode such numpy arrays, and the
        # option only affects Python 2 str objects. Python 2's cPickle.load
        # does not accept the keyword, so pass it only on Python 3.
        load_kwargs = {'encoding': 'latin1'} if sys.version_info[0] >= 3 else {}
        # NOTE: unpickling can execute arbitrary code - only load data files
        # obtained from a trusted source.
        # use a context-manager to ensure the files are properly closed after
        # we are finished with them
        with gzip.open(data_path) as f:
            inputs, targets = cPickle.load(f, **load_kwargs)
        # pass the loaded data to the parent class __init__
        super(MNISTDataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
class CIFAR10DataProvider(OneOfKDataProvider):
    """Data provider for CIFAR-10 object images."""

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new CIFAR-10 data provider object.

        Args:
            which_set: One of 'train' or 'valid'. Determines which
                portion of the CIFAR-10 data this object should provide.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.

        Raises:
            KeyError: If the MLP_DATA_DIR environment variable is not set, or
                the archive lacks one of the expected arrays.
            AssertionError: If `which_set` is invalid or the data file is
                missing (these checks are skipped under `python -O`).
        """
        # check a valid which_set was provided
        assert which_set in ['train', 'valid'], (
            'Expected which_set to be either train or valid. '
            'Got {0}'.format(which_set)
        )
        self.which_set = which_set
        self.num_classes = 10
        # construct path to data using os.path.join to ensure the correct path
        # separator for the current platform / OS is used
        # MLP_DATA_DIR environment variable should point to the data directory
        data_path = os.path.join(
            os.environ['MLP_DATA_DIR'], 'cifar-10-{0}.npz'.format(which_set))
        assert os.path.isfile(data_path), (
            'Data file does not exist at expected path: ' + data_path
        )
        # np.load on an .npz archive returns a lazily-read NpzFile that keeps
        # the file open. Use it as a context manager so the handle is closed
        # once the arrays below have been read into memory.
        with np.load(data_path) as loaded:
            inputs, targets = loaded['inputs'], loaded['targets']
            # label map gives strings corresponding to integer label targets
            self.label_map = loaded['label_map']
        # pass the loaded data to the parent class __init__
        super(CIFAR10DataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
class CIFAR100DataProvider(OneOfKDataProvider):
    """Data provider for CIFAR-100 object images."""

    def __init__(self, which_set='train', use_coarse_targets=False,
                 batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new CIFAR-100 data provider object.

        Args:
            which_set: One of 'train' or 'valid'. Determines which
                portion of the CIFAR-100 data this object should provide.
            use_coarse_targets: Whether to use coarse 'superclass' labels as
                targets (20 classes) instead of standard class labels
                (100 classes).
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.

        Raises:
            KeyError: If the MLP_DATA_DIR environment variable is not set, or
                the archive lacks one of the expected arrays.
            AssertionError: If `which_set` is invalid or the data file is
                missing (these checks are skipped under `python -O`).
        """
        # check a valid which_set was provided
        assert which_set in ['train', 'valid'], (
            'Expected which_set to be either train or valid. '
            'Got {0}'.format(which_set)
        )
        self.which_set = which_set
        self.use_coarse_targets = use_coarse_targets
        self.num_classes = 20 if use_coarse_targets else 100
        # construct path to data using os.path.join to ensure the correct path
        # separator for the current platform / OS is used
        # MLP_DATA_DIR environment variable should point to the data directory
        data_path = os.path.join(
            os.environ['MLP_DATA_DIR'], 'cifar-100-{0}.npz'.format(which_set))
        assert os.path.isfile(data_path), (
            'Data file does not exist at expected path: ' + data_path
        )
        # np.load on an .npz archive returns a lazily-read NpzFile that keeps
        # the file open. Use it as a context manager so the handle is closed
        # once the arrays below have been read into memory.
        with np.load(data_path) as loaded:
            inputs, targets = loaded['inputs'], loaded['targets']
            # label map gives strings corresponding to integer label targets
            self.label_map = (
                loaded['coarse_label_map']
                if use_coarse_targets else
                loaded['fine_label_map']
            )
        # targets holds one label per column: column 1 is used for the coarse
        # labels and column 0 for the fine labels
        targets = targets[:, 1] if use_coarse_targets else targets[:, 0]
        # pass the loaded data to the parent class __init__
        super(CIFAR100DataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment