Skip to content

Instantly share code, notes, and snippets.

@pbloem
Last active April 20, 2023 17:15
Show Gist options
  • Save pbloem/e2a46efe5b1fd4c098cd249d8f60d2c2 to your computer and use it in GitHub Desktop.
Save pbloem/e2a46efe5b1fd4c098cd249d8f60d2c2 to your computer and use it in GitHub Desktop.
import numpy as np
import torch
from urllib import request
import gzip
import pickle
import os
def load_mnist(final=False, flatten=True, verbose=False, normalize=True, val_size=5000):
    """
    Load the MNIST data as torch tensors, building the local cache on first use.

    If ``mnist.pkl`` is not present in the current working directory, ``init()`` is called
    first to download the raw files and create it.

    :param final: If true, return the canonical train/test split. If false, hold out the last
        ``val_size`` training instances as validation data and keep the test data hidden.
    :param flatten: If true, each instance is flattened into a vector, so that the data is
        returned as a matrix (784 = 28 * 28 columns for standard MNIST). If false, each
        instance keeps the shape it was stored with ((1, 28, 28) when the pickle was written
        by ``save_mnist``).
    :param verbose: Currently unused; accepted for backwards compatibility.
    :param normalize: If true, pixel values are cast to float and divided by 255 so they lie
        in [0, 1]. If false, the stored integer pixel values are returned unchanged.
    :param val_size: Number of training instances, taken from the end of the training set,
        used as validation data when ``final`` is false. Defaults to 5000.
    :return: Two tuples and an integer: (xtrain, ytrain), (xval, yval), num_cls. The first
        tuple holds the training data and its labels; the second holds the validation data
        (``final=False``) or the test data (``final=True``) in the same format. All four
        are torch tensors. Labels keep the dtype stored in the pickle (uint8 when written by
        ``save_mnist``); losses such as ``cross_entropy`` expect int64 class indices, so
        cast with ``.long()`` where needed. ``num_cls`` is always 10.
    :raises ValueError: If ``final`` is false and ``val_size`` is not strictly between 0 and
        the number of training instances.
    """
    if not os.path.isfile('mnist.pkl'):
        init()

    xtrain, ytrain, xtest, ytest = (torch.from_numpy(arr) for arr in load())

    if normalize:
        xtrain = xtrain.to(torch.float) / 255.
        xtest = xtest.to(torch.float) / 255.

    if flatten:
        xtrain = xtrain.reshape(xtrain.shape[0], -1)
        xtest = xtest.reshape(xtest.shape[0], -1)

    if final:
        return (xtrain, ytrain), (xtest, ytest), 10

    ntrain = xtrain.shape[0]
    if not 0 < val_size < ntrain:
        raise ValueError(
            f"val_size must be between 1 and {ntrain - 1} (got {val_size})"
        )
    # Use an explicit split index: the validation data is the tail of the training set.
    split = ntrain - val_size
    return (xtrain[:split], ytrain[:split]), (xtrain[split:], ytrain[split:]), 10
# Numpy-only MNIST loader. Courtesy of Hyeonseok Jung
# https://github.com/hsjeong5/MNIST-for-Numpy

# Each entry is [pickle key, gzipped IDX file name]. The two image files come
# first and the two label files last; code may slice this list by position
# (e.g. filename[:2] / filename[-2:]), so keep the order.
filename = [
    ["training_images", "train-images-idx3-ubyte.gz"],
    ["test_images", "t10k-images-idx3-ubyte.gz"],
    ["training_labels", "train-labels-idx1-ubyte.gz"],
    ["test_labels", "t10k-labels-idx1-ubyte.gz"],
]
def download_mnist(base_urls=None):
    """
    Download the four gzipped MNIST files listed in ``filename`` into the current directory.

    The original host (yann.lecun.com) has been unreliable, e.g. refusing downloads with
    HTTP 403, so each file is tried against a sequence of mirrors until one succeeds.

    :param base_urls: Optional sequence of base URLs (each ending in ``/``) tried in order.
        Defaults to the S3 mirror used by torchvision, followed by the original host.
    :raises OSError: If a file could not be fetched from any mirror; the error from the last
        attempt is chained as the cause.
    """
    if base_urls is None:
        base_urls = (
            "https://ossci-datasets.s3.amazonaws.com/mnist/",
            "http://yann.lecun.com/exdb/mnist/",
        )
    for name in filename:
        print("Downloading "+name[1]+"...")
        last_error = None
        for base_url in base_urls:
            try:
                request.urlretrieve(base_url + name[1], name[1])
                break
            # URLError, HTTPError and ContentTooShortError are all OSError subclasses.
            except OSError as err:
                last_error = err
        else:
            raise OSError(
                f"Could not download {name[1]} from any of {list(base_urls)}"
            ) from last_error
    print("Download complete.")
def _parse_idx_ubyte(raw, source):
    """Decode an unsigned-byte IDX payload into a numpy array, validating its header."""
    import struct

    if len(raw) < 4:
        raise ValueError(f"{source}: too short to be an IDX file")
    # IDX magic number: two zero bytes, a dtype code (0x08 = unsigned byte), then ndim.
    zero, dtype_code, ndim = struct.unpack_from('>HBB', raw, 0)
    if zero != 0 or dtype_code != 0x08:
        raise ValueError(
            f"{source}: not an unsigned-byte IDX file (magic 0x{raw[:4].hex()})"
        )
    header_len = 4 + 4 * ndim
    if len(raw) < header_len:
        raise ValueError(f"{source}: truncated IDX header")
    # One big-endian uint32 per dimension follows the magic number.
    dims = struct.unpack_from(f'>{ndim}I', raw, 4)
    data = np.frombuffer(raw, np.uint8, offset=header_len)
    expected = int(np.prod(dims, dtype=np.int64))
    if data.size != expected:
        raise ValueError(
            f"{source}: header declares {expected} values but payload has {data.size}"
        )
    if ndim == 3:
        # Image files: insert a singleton channel axis -> (N, 1, rows, cols).
        dims = (dims[0], 1) + tuple(dims[1:])
    return data.reshape(dims)


def save_mnist():
    """
    Decode the downloaded gzipped IDX files and pickle them to ``mnist.pkl``.

    Each file listed in ``filename`` is stored under its key (the first element of the pair).
    Image files (3-d IDX data) are stored as uint8 arrays of shape (N, 1, rows, cols), and
    label files (1-d IDX data) as uint8 arrays of shape (N,).

    The IDX header is parsed rather than skipped with fixed offsets. A file whose header
    does not describe unsigned-byte data, or whose payload size does not match the declared
    dimensions, raises ``ValueError`` instead of silently producing garbage arrays.
    """
    mnist = {}
    for key, fname in filename:
        with gzip.open(fname, 'rb') as f:
            mnist[key] = _parse_idx_ubyte(f.read(), fname)
    with open("mnist.pkl", 'wb') as f:
        pickle.dump(mnist,f)
    print("Save complete.")
def init():
    """Fetch the raw MNIST archives, then convert them into the ``mnist.pkl`` cache."""
    for step in (download_mnist, save_mnist):
        step()
def load():
    """
    Read the cached ``mnist.pkl`` from the current working directory.

    :return: A 4-tuple (training_images, training_labels, test_images, test_labels).
    """
    # NOTE: pickle.load can execute arbitrary code; only load a mnist.pkl you created
    # yourself (e.g. via save_mnist), never one from an untrusted source.
    with open("mnist.pkl",'rb') as handle:
        cache = pickle.load(handle)
    wanted = ("training_images", "training_labels", "test_images", "test_labels")
    return tuple(cache[key] for key in wanted)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment