Created February 25, 2019 21:53
-
-
Save rsepassi/9cdd3e2521f908ade05bd2a88334ddcd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Minimal tour of the TensorFlow Datasets (tfds) API using MNIST:
# build the dataset builder, inspect its metadata, download/prepare the data,
# load the splits as tf.data.Datasets, then iterate one split as NumPy arrays.
import numpy as np  # needed for the ndarray type check below
import tensorflow as tf  # needed for the tf.data.Dataset type check below
import tensorflow_datasets as tfds

# Fetch the dataset builder directly by class...
# NOTE(review): newer tfds releases may expose this class as
# tfds.image_classification.MNIST instead - verify against the installed version.
mnist = tfds.image.MNIST()
# ...or, equivalently, by its registered string name.
mnist = tfds.builder('mnist')

# Describe the dataset with DatasetInfo: feature shapes, class count, split sizes.
assert mnist.info.features['image'].shape == (28, 28, 1)
assert mnist.info.features['label'].num_classes == 10
assert mnist.info.splits['train'].num_examples == 60000

# Download the data, prepare it, and write it to disk.
mnist.download_and_prepare()

# Load data from disk as tf.data.Datasets, keyed by split name.
datasets = mnist.as_dataset()
train_dataset, test_dataset = datasets['train'], datasets['test']
assert isinstance(train_dataset, tf.data.Dataset)

# And convert the Dataset to NumPy arrays if you'd like.
for example in tfds.as_numpy(train_dataset):
    image, label = example['image'], example['label']
    # np.array is a factory function, not a type, so isinstance() with it
    # raises TypeError; np.ndarray is the actual array class.
    assert isinstance(image, np.ndarray)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.
I think line 24 should be changed to the following:
Currently, this code produces an error as follows: