Skip to content

Instantly share code, notes, and snippets.

@charlesreid1
Last active October 14, 2017 22:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save charlesreid1/eefc22defc8c6bd07c6bd0ac222c9781 to your computer and use it in GitHub Desktop.
Save charlesreid1/eefc22defc8c6bd07c6bd0ac222c9781 to your computer and use it in GitHub Desktop.
Example usage of Dataset objects for the Fuel library. https://github.com/mila-udem/fuel
from fuel.datasets import IndexableDataset
from collections import OrderedDict
import numpy
# Seeded generator so the fake data below is the same on every run.
seed = 1234
rng = numpy.random.RandomState(seed)

# Fake data: 8 examples, each a 2x2 block of integers plus one integer target.
features = rng.randint(256, size=(8, 2, 2))
targets = rng.randint(4, size=(8, 1))

# Source name -> array, in a fixed order.
indexables = OrderedDict()
indexables['features'] = features
indexables['targets'] = targets

# Source name -> one label per axis of the matching array.
axis_labels = OrderedDict()
axis_labels['features'] = ('batch', 'height', 'width')
axis_labels['targets'] = ('batch', 'index')

# Make a Dataset - in particular, an IndexableDataset
dataset = IndexableDataset(indexables=indexables, axis_labels=axis_labels)

# What sets an IndexableDataset apart from an IterableDataset is random
# access: the examples to fetch are named explicitly through the
# `request` argument of get_data().
state = dataset.open()
print("State is {}".format(state))
print("NOTE: None state returned, because there is no state to maintain!")

wanted = [3, 1, 0]
print(dataset.get_data(state=state, request=wanted))

# Clean up
dataset.close(state=state)
from fuel.datasets import IterableDataset
from collections import OrderedDict
import numpy
from pprint import pprint

# Seeded generator so the fake data below is the same on every run.
seed = 1234
rng = numpy.random.RandomState(seed)

# Make some fake data
features = rng.randint(256, size=(8, 2, 2))
targets = rng.randint(4, size=(8, 1))

# Make a Dataset - in particular, an IterableDataset
dataset = IterableDataset(
    iterables=OrderedDict([('features', features), ('targets', targets)]),
    axis_labels=OrderedDict([('features', ('batch', 'height', 'width')),
                             ('targets', ('batch', 'index'))]))

# How to access features of the dataset
print('Provided sources are {}.'.format(dataset.provides_sources))
print('Sources are {}.'.format(dataset.sources))
print('Axis labels are {}.'.format(dataset.axis_labels))
print('Dataset contains {} examples.'.format(dataset.num_examples))

# Print all available attributes
pprint(dir(dataset))

# Access the data. The state returned by open() is handed back on every
# call; the try/finally guarantees close() runs even if a call fails with
# something other than the StopIteration handled below.
state = dataset.open()
try:
    # No `request` is passed: each get_data() call yields the next example
    # until the dataset signals exhaustion with StopIteration.
    while True:
        try:
            print(dataset.get_data(state=state))
        except StopIteration:
            print('Iterator finished')
            break

    # Reset the accessor/state, then read one more example from the
    # freshly returned state.
    state = dataset.reset(state=state)
    print(dataset.get_data(state=state))
finally:
    # Clean up
    dataset.close(state=state)
from fuel.datasets import IndexableDataset
from fuel.schemes import ShuffledScheme
from collections import OrderedDict
import numpy
# Seeded generator so the fake data below is the same on every run.
seed = 1234
rng = numpy.random.RandomState(seed)
n = 32

# Make some fake data
features = rng.randint(256, size=(n, 2, 2))
targets = rng.randint(4, size=(n, 1))

# Make a Dataset - in particular, an IndexableDataset
dataset = IndexableDataset(
    indexables=OrderedDict([('features', features), ('targets', targets)]),
    axis_labels=OrderedDict([('features', ('batch', 'height', 'width')),
                             ('targets', ('batch', 'index'))]))

state = dataset.open()
try:
    scheme = ShuffledScheme(examples=dataset.num_examples, batch_size=4)

    # Use get_request_iterator() to generate requests
    # in shuffled order using the ShuffledScheme.
    for request in scheme.get_request_iterator():
        print(request)

    print("\n")

    # NOTE(review): this is a second, independent get_request_iterator()
    # call; presumably it reshuffles, so the batches fetched here need not
    # match the requests printed above - confirm against the Fuel docs.
    for request in scheme.get_request_iterator():
        data = dataset.get_data(state=state, request=request)
        print(data[0].shape, data[1].shape)
        ## Print inputs
        #print(data[0])
        ## Print outputs
        #print(data[1])
finally:
    # Clean up - mirrors the other dataset examples, which close the
    # state they opened.
    dataset.close(state=state)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment