@aldro61
Created July 24, 2017 20:12
Load the MNIST dataset into numpy arrays
"""
Load the MNIST dataset into numpy arrays
Author: Alexandre Drouin
License: BSD
"""
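import numpy as np
# TensorFlow 1.x only: tensorflow.examples.tutorials is not available in TensorFlow 2.x.
from tensorflow.examples.tutorials.mnist import input_data

# Downloads the MNIST files into MNIST_data/ on first use. Images are returned already
# flattened to rows of 784 float32 pixels scaled to [0, 1]; one_hot=False keeps integer labels 0-9.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
# Training split: 55,000 images (the other 5,000 go to mnist.validation, unused here).
X_train = np.asarray(mnist.train.images)  # shape (55000, 784)
y_train = mnist.train.labels  # shape (55000,)
# Test split: 10,000 images.
X_test = np.asarray(mnist.test.images)  # shape (10000, 784)
y_test = mnist.test.labels  # shape (10000,)
del mnist  # drops the wrapper object; the arrays above stay alive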
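Since `tensorflow.examples.tutorials` cannot be imported under TensorFlow 2.x, here is a minimal NumPy-only sketch (not part of the original gist) that decodes the raw gzipped IDX files MNIST is distributed in. It assumes the four standard files (`train-images-idx3-ubyte.gz`, `train-labels-idx1-ubyte.gz`, `t10k-images-idx3-ubyte.gz`, `t10k-labels-idx1-ubyte.gz`) are already on disk, for example in the `MNIST_data/` directory that `read_data_sets` downloads into. The helper names `parse_idx` and `load_split` are hypothetical. One difference: this keeps all 60,000 training images instead of holding out a 5,000-image validation split.

```python
import gzip
import struct

import numpy as np


# Hypothetical helpers (not from the original gist): decode MNIST's IDX format directly.
def parse_idx(data: bytes) -> np.ndarray:
    """Decode one unsigned-byte IDX file into a numpy array.

    Header layout: two zero bytes, a type code (0x08 = unsigned byte),
    the number of dimensions, then one big-endian uint32 size per dimension.
    """
    zero, dtype_code, ndim = struct.unpack(">HBB", data[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError("expected an unsigned-byte IDX file")
    shape = struct.unpack(">" + "I" * ndim, data[4:4 + 4 * ndim])
    # The pixel/label payload starts right after the header.
    arr = np.frombuffer(data, dtype=np.uint8, offset=4 + 4 * ndim)
    return arr.reshape(shape)


def load_split(images_path: str, labels_path: str):
    """Return (X, y): X is float32 of shape (n, 784) scaled to [0, 1],
    y holds the integer labels as uint8."""
    with gzip.open(images_path, "rb") as f:
        images = parse_idx(f.read())  # shape (n, 28, 28)
    with gzip.open(labels_path, "rb") as f:
        labels = parse_idx(f.read())  # shape (n,)
    X = images.reshape(len(images), -1).astype(np.float32) / 255.0
    # frombuffer views are read-only, so return a writable copy of the labels.
    return X, labels.copy()
```

For example, `X_train, y_train = load_split("MNIST_data/train-images-idx3-ubyte.gz", "MNIST_data/train-labels-idx1-ubyte.gz")` would give a `(60000, 784)` float32 array and a `(60000,)` label array.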
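On TensorFlow 2.x, `tf.keras.datasets.mnist.load_data()` is the usual replacement for `read_data_sets`, but it returns raw uint8 images of shape `(n, 28, 28)` and keeps all 60,000 training images (no validation split). The sketch below, with hypothetical helper names `to_flat_unit_float` and `load_mnist_keras`, converts that output to the same flattened float32 layout the script above produces. TensorFlow is imported inside the loader only, so the conversion helper works with plain NumPy.

```python
import numpy as np


# Hypothetical helpers (not from the original gist).
def to_flat_unit_float(images: np.ndarray) -> np.ndarray:
    """Flatten uint8 images of shape (n, 28, 28) to float32 rows of shape
    (n, 784) scaled to [0, 1], matching the TF 1.x loader's image layout."""
    return images.reshape(len(images), -1).astype(np.float32) / 255.0


def load_mnist_keras():
    """Load MNIST via tf.keras (TensorFlow 2.x) in the flattened layout.

    Returns ((X_train, y_train), (X_test, y_test)); the training split
    has 60,000 images because load_data() does not hold out a validation set.
    """
    import tensorflow as tf  # imported lazily; only this function needs TensorFlow

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    return (to_flat_unit_float(x_train), y_train), (to_flat_unit_float(x_test), y_test)
```

If an exact match with the original 55,000/5,000 split is needed, the last 5,000 rows of `X_train` could be sliced off as a validation set, though whether they are the same images `read_data_sets` held out is not guaranteed here.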