Last active
October 14, 2017 22:21
-
-
Save charlesreid1/eefc22defc8c6bd07c6bd0ac222c9781 to your computer and use it in GitHub Desktop.
Example usage of Dataset objects from the Fuel library (https://github.com/mila-udem/fuel).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example: random access into a Fuel IndexableDataset by passing a
# `request` (a list of example indices) to get_data().
from fuel.datasets import IndexableDataset
from collections import OrderedDict
import numpy

seed = 1234
rng = numpy.random.RandomState(seed)

# Make some fake data: 8 examples of 2x2 integer "images" with values in
# [0, 256), and one integer target in [0, 4) per example.
features = rng.randint(256, size=(8, 2, 2))
targets = rng.randint(4, size=(8, 1))

# Make a Dataset - in particular, an IndexableDataset. The OrderedDicts fix
# the order of the sources and name the axes of each array.
dataset = IndexableDataset(
    indexables=OrderedDict([('features', features), ('targets', targets)]),
    axis_labels=OrderedDict([('features', ('batch', 'height', 'width')),
                             ('targets', ('batch', 'index'))]))

# The main difference between IterableDataset and IndexableDataset is
# random access: the caller chooses which examples to fetch by passing a
# request argument to get_data().
state = dataset.open()
try:
    print("State is {}".format(state))
    # Only claim there is no state when open() really returned None.
    if state is None:
        print("NOTE: None state returned, because there is no state to maintain!")
    print(dataset.get_data(state=state, request=[3, 1, 0]))
finally:
    # Clean up, even if get_data() raised.
    dataset.close(state=state)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example: sequential access to a Fuel IterableDataset, inspection of its
# metadata, and resetting its iteration state.
from fuel.datasets import IterableDataset
from collections import OrderedDict
from pprint import pprint
import numpy

seed = 1234
rng = numpy.random.RandomState(seed)

# Make some fake data: 8 examples of 2x2 integer "images" with values in
# [0, 256), and one integer target in [0, 4) per example.
features = rng.randint(256, size=(8, 2, 2))
targets = rng.randint(4, size=(8, 1))

# Make a Dataset - in particular, an IterableDataset
dataset = IterableDataset(
    iterables=OrderedDict([('features', features), ('targets', targets)]),
    axis_labels=OrderedDict([('features', ('batch', 'height', 'width')),
                             ('targets', ('batch', 'index'))]))

# How to access features of the dataset
print('Provided sources are {}.'.format(dataset.provides_sources))
print('Sources are {}.'.format(dataset.sources))
print('Axis labels are {}.'.format(dataset.axis_labels))
print('Dataset contains {} examples.'.format(dataset.num_examples))

# Print all available attributes
pprint(dir(dataset))

# Access the data: get_data() is called without a request and yields the
# examples in order, signalling the end with StopIteration.
state = dataset.open()
try:
    while True:
        # Keep the try narrow so only exhaustion of the data ends the loop.
        try:
            example = dataset.get_data(state=state)
        except StopIteration:
            print('Iterator finished')
            break
        print(example)

    # Reset the accessor/state so the data can be read again from the start.
    state = dataset.reset(state=state)
    print(dataset.get_data(state=state))
finally:
    # Clean up whichever state is current, even if an access failed.
    dataset.close(state=state)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example: drive a Fuel IndexableDataset with a ShuffledScheme, which
# produces shuffled mini-batch requests (lists of example indices) that are
# passed to get_data().
from fuel.datasets import IndexableDataset
from fuel.schemes import ShuffledScheme
from collections import OrderedDict
import numpy

seed = 1234
rng = numpy.random.RandomState(seed)
n = 32

# Make some fake data: n examples of 2x2 integer "images" with values in
# [0, 256), and one integer target in [0, 4) per example.
features = rng.randint(256, size=(n, 2, 2))
targets = rng.randint(4, size=(n, 1))

# Make a Dataset - in particular, an IndexableDataset
dataset = IndexableDataset(
    indexables=OrderedDict([('features', features), ('targets', targets)]),
    axis_labels=OrderedDict([('features', ('batch', 'height', 'width')),
                             ('targets', ('batch', 'index'))]))

state = dataset.open()
scheme = ShuffledScheme(examples=dataset.num_examples, batch_size=4)

# Use get_request_iterator() to generate requests in shuffled order.
# ShuffledScheme draws a new shuffle on every get_request_iterator() call,
# so the requests are materialized once and reused. Otherwise the batches
# fetched below would not match the requests printed here.
requests = list(scheme.get_request_iterator())
for request in requests:
    print(request)
print("\n")

try:
    for request in requests:
        data = dataset.get_data(state=state, request=request)
        # data presumably follows dataset.sources order: (features, targets).
        # Print data[0] / data[1] to inspect the actual values.
        print(data[0].shape, data[1].shape)
finally:
    # Release the state returned by open(), as the other examples do.
    dataset.close(state=state)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment