

@delta2323
Last active April 23, 2019 13:58
How to create your own dataset and use it in Chainer

The goal of this document is to explain the basic mechanism of datasets in Chainer and how to prepare a customized dataset that works with Trainer.

See the official documentation for the full details of the specifications.

Interface of dataset

For a dataset to work with the Trainer, it must have two methods:

  • __getitem__, which is used for indexing like dataset[i] or slicing like dataset[i:j]
  • __len__, which lets the dataset be passed to len()

Conversely, the Trainer can treat any object as a dataset as long as it has both __getitem__ and __len__.
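As a minimal sketch of this interface, the class below (a hypothetical SquaresDataset written for illustration, not part of Chainer) implements only __getitem__ and __len__ in plain Python, handling both an integer index and a slice.

```python
class SquaresDataset:
    """Hypothetical dataset: the i-th sample is the pair (i, i * i)."""

    def __init__(self, n):
        self._n = n

    def __len__(self):
        # Number of samples; this is what len(dataset) returns.
        return self._n

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Slicing returns a list of samples, as in dataset[i:j].
            return [self[i] for i in range(*index.indices(self._n))]
        # Integer indexing; negative indices count from the end, like a list.
        if index < 0:
            index += self._n
        if not 0 <= index < self._n:
            raise IndexError('dataset index out of range')
        return (index, index * index)


dataset = SquaresDataset(5)
print(len(dataset))   # 5
print(dataset[3])     # (3, 9)
print(dataset[1:3])   # [(1, 1), (2, 4)]
```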

Example of dataset

For example, suppose we want to use a NumPy ndarray as a dataset:

import numpy

dataset = numpy.random.uniform(-1, 1, (100, 50)).astype(numpy.float32)

Then this dataset consists of 100 (= len(dataset)) samples, and its i-th sample is dataset[i], a 1-dimensional array of length 50.
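Because an ndarray already implements both __getitem__ and __len__, these properties can be checked directly. A quick sketch (the printed shapes follow from the (100, 50) array created above):

```python
import numpy

dataset = numpy.random.uniform(-1, 1, (100, 50)).astype(numpy.float32)

print(len(dataset))            # 100 samples
print(dataset[0].shape)        # (50,): one sample is a 1-D vector
print(dataset[10:20].shape)    # (10, 50): a slice yields 10 samples at once
```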

If we want to attach ground-truth labels, chainer.datasets.TupleDataset is suitable for the job. In the example below, each sample is a pair of a feature vector of length 50 and an integer label:

import chainer
import numpy

feature_vector = numpy.random.uniform(-1, 1, (100, 50)).astype(numpy.float32)
label = numpy.random.randint(0, 10, (100,)).astype(numpy.int32)
dataset = chainer.datasets.TupleDataset(feature_vector, label)
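To show what indexing such a dataset returns, here is a simplified, standalone sketch of the behavior. MiniTupleDataset is a hypothetical stand-in written for illustration, not Chainer's actual source, so it runs without Chainer installed. An integer index yields one (feature_vector, label) tuple, and a slice yields a list of such tuples.

```python
import numpy


class MiniTupleDataset:
    """Hypothetical stand-in mimicking chainer.datasets.TupleDataset indexing."""

    def __init__(self, *datasets):
        if not datasets:
            raise ValueError('no datasets are given')
        length = len(datasets[0])
        for d in datasets:
            if len(d) != length:
                raise ValueError('all datasets must have the same length')
        self._datasets = datasets
        self._length = length

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        # Index every underlying dataset with the same index.
        batches = [d[index] for d in self._datasets]
        if isinstance(index, slice):
            # Regroup the per-dataset slices into a list of per-sample tuples.
            return [tuple(b[i] for b in batches) for i in range(len(batches[0]))]
        return tuple(batches)


feature_vector = numpy.random.uniform(-1, 1, (100, 50)).astype(numpy.float32)
label = numpy.random.randint(0, 10, (100,)).astype(numpy.int32)
dataset = MiniTupleDataset(feature_vector, label)

x, t = dataset[0]     # x is a vector of length 50, t is its label
batch = dataset[0:4]  # a list of 4 (x, t) tuples
```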

Chainer also provides a utility function, get_mnist, which is used in the official MNIST example. If get_mnist is called with withlabel=True, it returns a tuple of two TupleDatasets, one for training and the other for testing. Each sample in these TupleDatasets is a pair of a grayscale image and a label (a single integer).

For large datasets

If your dataset is too large to fit in memory, one way to create a dataset is to write a class that takes a list of paths to samples (e.g. in its __init__ method) and loads samples lazily from disk in its __getitem__ method. chainer.datasets.ImageDataset, which Chainer provides officially, takes this strategy.
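A minimal sketch of this strategy, assuming each sample is stored as its own NumPy .npy file. The class name LazyNpyDataset and the file layout are hypothetical (ImageDataset itself reads image files). Only the list of paths is kept in memory, and a file is read in __getitem__ only when its sample is requested.

```python
import os
import tempfile

import numpy


class LazyNpyDataset:
    """Hypothetical dataset that keeps only file paths in memory."""

    def __init__(self, paths):
        # Store the paths only; no sample data is loaded here.
        self._paths = list(paths)

    def __len__(self):
        return len(self._paths)

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Load each sample in the slice from disk, one file at a time.
            return [numpy.load(p) for p in self._paths[index]]
        # Load a single sample from disk on demand.
        return numpy.load(self._paths[index])


# Usage: write three small sample files to a temporary directory.
tmpdir = tempfile.mkdtemp()
paths = []
for i in range(3):
    path = os.path.join(tmpdir, 'sample_%d.npy' % i)
    numpy.save(path, numpy.full((4,), i, dtype=numpy.float32))
    paths.append(path)

dataset = LazyNpyDataset(paths)
first = dataset[0]      # read from disk only at this point
```

Because nothing is cached, every access hits the disk; whether to add caching is a trade-off between memory use and speed.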

DatasetMixin

It is sometimes cumbersome to implement __getitem__ because it must support not only indexing by integer but also slicing. In that case, users can create a dataset by inheriting from chainer.dataset.DatasetMixin and implementing the get_example method instead of implementing __getitem__ directly. get_example(i) should return the i-th sample. DatasetMixin then lets the dataset be indexed by integer (e.g. dataset[i]) and also sliced (e.g. dataset[i:j]).
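The snippet below is a standalone sketch of the idea, not Chainer's code: a hypothetical MiniDatasetMixin derives __getitem__ from get_example, so a subclass only writes __len__ and get_example. With Chainer installed, you would inherit from chainer.dataset.DatasetMixin instead.

```python
class MiniDatasetMixin:
    """Hypothetical stand-in for chainer.dataset.DatasetMixin."""

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Resolve start/stop/step (including negatives) against len(self),
            # then fetch the samples one by one via get_example.
            start, stop, step = index.indices(len(self))
            return [self.get_example(i) for i in range(start, stop, step)]
        return self.get_example(index)

    def __len__(self):
        raise NotImplementedError

    def get_example(self, i):
        raise NotImplementedError


class SquaresByExample(MiniDatasetMixin):
    """Hypothetical dataset: only __len__ and get_example are written."""

    def __init__(self, n):
        self._n = n

    def __len__(self):
        return self._n

    def get_example(self, i):
        # Return the i-th sample; slicing is handled by the mixin.
        return (i, i * i)


dataset = SquaresByExample(10)
print(dataset[3])     # (3, 9)
print(dataset[2:5])   # [(2, 4), (3, 9), (4, 16)]
```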

The implementation of chainer.datasets.SubDataset is a helpful example of how to use DatasetMixin.
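In the same spirit, a hypothetical MiniSubDataset can expose the range [start, finish) of a base dataset through get_example, optionally permuted by an index array. This is a simplified sketch loosely modeled on what SubDataset does, not its actual source. A tiny mixin is repeated so that the block runs on its own.

```python
class MiniDatasetMixin:
    """Hypothetical stand-in for chainer.dataset.DatasetMixin."""

    def __getitem__(self, index):
        if isinstance(index, slice):
            start, stop, step = index.indices(len(self))
            return [self.get_example(i) for i in range(start, stop, step)]
        return self.get_example(index)


class MiniSubDataset(MiniDatasetMixin):
    """Hypothetical view of base[start:finish], optionally permuted by order."""

    def __init__(self, dataset, start, finish, order=None):
        if start < 0 or finish > len(dataset) or start > finish:
            raise ValueError('subset range is invalid for the base dataset')
        if order is not None and len(order) != len(dataset):
            raise ValueError('order must have the same length as the base dataset')
        self._dataset = dataset
        self._start = start
        self._finish = finish
        self._order = order

    def __len__(self):
        return self._finish - self._start

    def get_example(self, i):
        # Map the local index i (negative allowed) to a position in the base.
        if i >= 0:
            if i >= len(self):
                raise IndexError('dataset index out of range')
            index = self._start + i
        else:
            if i < -len(self):
                raise IndexError('dataset index out of range')
            index = self._finish + i
        if self._order is not None:
            # Follow the permutation to the actual sample in the base dataset.
            index = self._order[index]
        return self._dataset[index]


base = list(range(100, 110))   # any object with __getitem__ and __len__
train = MiniSubDataset(base, 0, 7)
test = MiniSubDataset(base, 7, 10)
```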

Note that if the object you want to treat as a dataset already supports slicing, you do not need DatasetMixin. Note also that slicing via get_example can be slower, because the result is built by fetching samples one by one with get_example. If your application needs better performance, you may need to implement __getitem__ directly.
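As an illustration of that trade-off, the hypothetical ArrayBackedDataset below implements __getitem__ directly on top of a NumPy array, so a slice is served by a single array indexing operation instead of one get_example call per sample. Unlike the get_example route, a slice here returns an ndarray rather than a list of samples; whether that suits the iterator and converter you use is worth checking in your own setup.

```python
import numpy


class ArrayBackedDataset:
    """Hypothetical dataset whose __getitem__ delegates to NumPy indexing."""

    def __init__(self, data):
        self._data = numpy.asarray(data)

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index):
        # A single NumPy operation handles an integer or a slice;
        # basic slicing returns a view, so samples are not copied one by one.
        return self._data[index]


data = numpy.arange(20, dtype=numpy.float32).reshape(10, 2)
dataset = ArrayBackedDataset(data)
batch = dataset[2:5]   # ndarray of shape (3, 2)
```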

@WinterKwei

Thank you very much, this document helped me.
