Skip to content

Instantly share code, notes, and snippets.

@matt-graham
Created January 22, 2017 15:20
Show Gist options
  • Save matt-graham/ee66a23419127b3fcb42465c9e6cc3bd to your computer and use it in GitHub Desktop.
Save matt-graham/ee66a23419127b3fcb42465c9e6cc3bd to your computer and use it in GitHub Desktop.
CIFAR data providers
import cPickle
import gzip
import numpy as np
import os
# Seed for the fallback random number generator that a data provider creates
# when it is constructed without an explicit `rng` argument.
DEFAULT_SEED = 1234
class DataProvider(object):
    """Generic data provider.

    Wraps an array of inputs and an array of targets and iterates over them
    in fixed-size batches. Only whole batches are produced: any data points
    left over after dividing the data into batches of `batch_size` are not
    returned during an epoch.

    The object is its own iterator and supports both the Python 2 (`next`)
    and Python 3 (`__next__`) iterator protocols.
    """

    def __init__(self, inputs, targets, batch_size, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new data provider object.

        Args:
            inputs (ndarray): Array of data input features of shape
                (num_data, input_dim).
            targets (ndarray): Array of data output targets of shape
                (num_data, output_dim) or (num_data,) if output_dim == 1.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator. If `None`
                a new generator seeded with `DEFAULT_SEED` is created.

        Raises:
            ValueError: If `batch_size < 1` or if `max_num_batches` is
                neither -1 nor a positive number.
        """
        self.inputs = inputs
        self.targets = targets
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1')
        self._batch_size = batch_size
        if max_num_batches == 0 or max_num_batches < -1:
            raise ValueError('max_num_batches must be -1 or > 0')
        self._max_num_batches = max_num_batches
        self._update_num_batches()
        self.shuffle_order = shuffle_order
        # tracks where each original data point currently sits so that the
        # original ordering can be restored by `reset`
        self._current_order = np.arange(inputs.shape[0])
        if rng is None:
            rng = np.random.RandomState(DEFAULT_SEED)
        self.rng = rng
        self.new_epoch()

    @property
    def batch_size(self):
        """Number of data points to include in each batch."""
        return self._batch_size

    @batch_size.setter
    def batch_size(self, value):
        if value < 1:
            raise ValueError('batch_size must be >= 1')
        self._batch_size = value
        # number of batches depends on the batch size so recompute it
        self._update_num_batches()

    @property
    def max_num_batches(self):
        """Maximum number of batches to iterate over in an epoch."""
        return self._max_num_batches

    @max_num_batches.setter
    def max_num_batches(self, value):
        if value == 0 or value < -1:
            raise ValueError('max_num_batches must be -1 or > 0')
        self._max_num_batches = value
        self._update_num_batches()

    def _update_num_batches(self):
        """Updates number of batches to iterate over."""
        # maximum possible number of batches is equal to number of whole times
        # batch_size divides in to the number of data points which can be
        # found using integer division
        possible_num_batches = self.inputs.shape[0] // self.batch_size
        if self.max_num_batches == -1:
            self.num_batches = possible_num_batches
        else:
            self.num_batches = min(self.max_num_batches, possible_num_batches)

    def __iter__(self):
        """Implements Python iterator interface.

        This should return an object implementing a `next` (Python 2) or
        `__next__` (Python 3) method which steps through a sequence returning
        one element at a time and raising `StopIteration` when at the end of
        the sequence. Here the object returned is the DataProvider itself.
        """
        return self

    def new_epoch(self):
        """Starts a new epoch (pass through data), possibly shuffling first."""
        self._curr_batch = 0
        if self.shuffle_order:
            self.shuffle()

    def reset(self):
        """Restores the original data order and starts a new epoch.

        The inverse of the accumulated permutation is applied to the inputs
        and targets so they return to the order they were supplied in. A new
        epoch is then started, which shuffles the data again when
        `shuffle_order` is true. The state of `rng` is not rewound.
        """
        inv_perm = np.argsort(self._current_order)
        self._current_order = self._current_order[inv_perm]
        self.inputs = self.inputs[inv_perm]
        self.targets = self.targets[inv_perm]
        self.new_epoch()

    def shuffle(self):
        """Randomly shuffles order of data."""
        perm = self.rng.permutation(self.inputs.shape[0])
        # apply the same permutation to all three arrays so that inputs and
        # targets stay aligned and the ordering remains invertible
        self._current_order = self._current_order[perm]
        self.inputs = self.inputs[perm]
        self.targets = self.targets[perm]

    def next(self):
        """Returns next data batch or raises `StopIteration` if at end.

        Returns:
            Tuple `(inputs_batch, targets_batch)` of the slices of the inputs
            and targets arrays corresponding to the current batch.

        Raises:
            StopIteration: When all batches of the current epoch have been
                returned. A new epoch is started before raising so that the
                provider can be iterated over again immediately.
        """
        if self._curr_batch + 1 > self.num_batches:
            # no more batches in current iteration through data set so start
            # new epoch ready for another pass and indicate iteration is at end
            self.new_epoch()
            raise StopIteration()
        # create an index slice corresponding to current batch number
        batch_slice = slice(self._curr_batch * self.batch_size,
                            (self._curr_batch + 1) * self.batch_size)
        inputs_batch = self.inputs[batch_slice]
        targets_batch = self.targets[batch_slice]
        self._curr_batch += 1
        return inputs_batch, targets_batch

    def __next__(self):
        """Python 3 iterator protocol entry point.

        Delegates to `next` via `self` (rather than aliasing the function) so
        that subclasses which override `next` are honoured.
        """
        return self.next()
class OneOfKDataProvider(DataProvider):
    """Data provider producing 1-of-K encoded classification targets.

    Integer class labels held by the provider are converted to binary
    1-of-K rows each time a batch is returned. Subclasses are responsible
    for assigning `self.num_classes`.
    """

    def next(self):
        """Returns next data batch or raises `StopIteration` if at end.

        The targets of the batch obtained from the parent class are converted
        to 1-of-K form before being returned.
        """
        batch = super(OneOfKDataProvider, self).next()
        inputs_batch, int_targets = batch
        return inputs_batch, self.to_one_of_k(int_targets)

    def to_one_of_k(self, int_targets):
        """Converts integer coded class targets to 1-of-K coded targets.

        Args:
            int_targets (ndarray): Array of integer coded class targets (i.e.
                where an integer from 0 to `num_classes` - 1 is used to
                indicate which is the correct class). This should be of shape
                (num_data,).

        Returns:
            Array of shape (num_data, num_classes) in which every element of
            a row is zero apart from the column of the correct class, which
            is one.
        """
        num_points = int_targets.shape[0]
        encoded = np.zeros((num_points, self.num_classes))
        # pair each row index with its class label to set a single entry
        # per row
        row_indices = np.arange(num_points)
        encoded[row_indices, int_targets] = 1
        return encoded
class MNISTDataProvider(OneOfKDataProvider):
    """Data provider for MNIST handwritten digit images."""

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new MNIST data provider object.

        The data is read from the gzipped pickle file
        `mnist_<which_set>.pkl.gz` located in the directory named by the
        `MLP_DATA_DIR` environment variable.

        Args:
            which_set: One of 'train', 'valid' or 'test'. Determines which
                portion of the MNIST data this object should provide.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.
        """
        # reject anything other than the three known data set portions
        assert which_set in ('train', 'valid', 'test'), (
            'Expected which_set to be either train, valid or test. '
            'Got {0}'.format(which_set)
        )
        self.which_set = which_set
        self.num_classes = 10
        # the data directory is taken from the MLP_DATA_DIR environment
        # variable; os.path.join supplies the platform's path separator
        data_dir = os.environ['MLP_DATA_DIR']
        file_name = 'mnist_{0}.pkl.gz'.format(which_set)
        path = os.path.join(data_dir, file_name)
        assert os.path.isfile(path), (
            'Data file does not exist at expected path: ' + path
        )
        # the context manager closes the file once the data has been read
        with gzip.open(path) as data_file:
            inputs, targets = cPickle.load(data_file)
        # hand the loaded arrays over to the parent class initialiser
        super(MNISTDataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
class CIFAR10DataProvider(OneOfKDataProvider):
    """Data provider for CIFAR-10 object images."""

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new CIFAR-10 data provider object.

        The data is read from the file `cifar-10-<which_set>.npz` located in
        the directory named by the `MLP_DATA_DIR` environment variable. The
        archive is expected to hold `inputs`, `targets` and `label_map`
        arrays.

        Args:
            which_set: One of 'train' or 'valid'. Determines which
                portion of the CIFAR-10 data this object should provide.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.
        """
        # check a valid which_set was provided
        assert which_set in ['train', 'valid'], (
            'Expected which_set to be either train or valid. '
            'Got {0}'.format(which_set)
        )
        self.which_set = which_set
        self.num_classes = 10
        # construct path to data using os.path.join to ensure the correct path
        # separator for the current platform / OS is used
        # MLP_DATA_DIR environment variable should point to the data directory
        data_path = os.path.join(
            os.environ['MLP_DATA_DIR'], 'cifar-10-{0}.npz'.format(which_set))
        assert os.path.isfile(data_path), (
            'Data file does not exist at expected path: ' + data_path
        )
        # load data from compressed numpy file; the archive object keeps the
        # underlying file open, so use it as a context manager to make sure
        # the file handle is released once the arrays have been read
        with np.load(data_path) as loaded:
            inputs, targets = loaded['inputs'], loaded['targets']
            # label map gives strings corresponding to integer label targets
            self.label_map = loaded['label_map']
        # pass the loaded data to the parent class __init__
        super(CIFAR10DataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
class CIFAR100DataProvider(OneOfKDataProvider):
    """Data provider for CIFAR-100 object images."""

    def __init__(self, which_set='train', use_coarse_targets=False,
                 batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new CIFAR-100 data provider object.

        The data is read from the file `cifar-100-<which_set>.npz` located in
        the directory named by the `MLP_DATA_DIR` environment variable. The
        archive is expected to hold `inputs`, `targets`, `coarse_label_map`
        and `fine_label_map` arrays, with the first column of `targets` used
        for the standard labels and the second for the coarse labels.

        Args:
            which_set: One of 'train' or 'valid'. Determines which
                portion of the CIFAR-100 data this object should provide.
            use_coarse_targets: Whether to use coarse 'superclass' labels as
                targets instead of standard class labels.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.
        """
        # check a valid which_set was provided
        assert which_set in ['train', 'valid'], (
            'Expected which_set to be either train or valid. '
            'Got {0}'.format(which_set)
        )
        self.which_set = which_set
        self.use_coarse_targets = use_coarse_targets
        self.num_classes = 20 if use_coarse_targets else 100
        # construct path to data using os.path.join to ensure the correct path
        # separator for the current platform / OS is used
        # MLP_DATA_DIR environment variable should point to the data directory
        data_path = os.path.join(
            os.environ['MLP_DATA_DIR'], 'cifar-100-{0}.npz'.format(which_set))
        assert os.path.isfile(data_path), (
            'Data file does not exist at expected path: ' + data_path
        )
        # load data from compressed numpy file; the archive object keeps the
        # underlying file open, so use it as a context manager to make sure
        # the file handle is released once the arrays have been read
        with np.load(data_path) as loaded:
            inputs, targets = loaded['inputs'], loaded['targets']
            # second targets column holds the coarse labels, first the
            # standard labels
            targets = targets[:, 1] if use_coarse_targets else targets[:, 0]
            # label map gives strings corresponding to integer label targets
            self.label_map = (
                loaded['coarse_label_map']
                if use_coarse_targets else
                loaded['fine_label_map']
            )
        # pass the loaded data to the parent class __init__
        super(CIFAR100DataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment