Skip to content

Instantly share code, notes, and snippets.

@tegg89
Created October 31, 2018 05:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tegg89/7375cb211953800bd708b71d073f6a43 to your computer and use it in GitHub Desktop.
Save tegg89/7375cb211953800bd708b71d073f6a43 to your computer and use it in GitHub Desktop.
MNIST data
import os
import shutil
import urllib
import urllib.request

import numpy as np
import scipy
import torch
from sklearn.datasets import fetch_mldata
# Load MNIST into the module-level ``mnist`` mapping.
#
# First try scikit-learn's ``fetch_mldata`` (served by mldata.org). If that
# fails with a network/OS-level error, fall back to downloading the
# ``mnist-original.mat`` file from a GitHub mirror and rebuild a dict with the
# keys the rest of this script reads ("data" and "target").
try:
    mnist = fetch_mldata('MNIST original')
except OSError:
    # urllib.error.URLError -- and its subclass HTTPError -- derive from
    # OSError, as do socket timeouts and connection resets. Catching OSError
    # therefore covers "host unreachable" failures too, not only HTTP error
    # statuses (catching HTTPError alone let a DNS failure crash the script).
    print("Could not download MNIST data from mldata.org, trying alternative...")
    # Alternative method to load MNIST, if mldata.org is down
    from scipy.io import loadmat
    mnist_alternative_url = "https://github.com/amplab/datascience-sp14/raw/master/lab7/mldata/mnist-original.mat"
    mnist_path = "./mnist-original.mat"
    if not os.path.exists(mnist_path):
        # Stream into a temporary name and rename only after the download
        # completes, so an interrupted run never leaves a truncated .mat file
        # that would be picked up on the next run.
        partial_path = mnist_path + ".part"
        with urllib.request.urlopen(mnist_alternative_url) as response:
            with open(partial_path, "wb") as f:
                shutil.copyfileobj(response, f)
        os.replace(partial_path, mnist_path)
    mnist_raw = loadmat(mnist_path)
    # Transpose so rows are samples; presumably the .mat stores one sample
    # per column (TODO confirm against the file). "label" is indexed [0] to
    # drop its leading axis.
    mnist = {
        "data": mnist_raw["data"].T,
        "target": mnist_raw["label"][0],
        "COL_NAMES": ["label", "data"],
        "DESCR": "mldata.org dataset: mnist-original",
    }
    print("Success!")
def build_batches(x, n):
    """Split the leading axis of ``x`` into consecutive batches of ``n`` items.

    ``x`` is first converted with ``np.asarray``. Items that do not fill a
    complete final batch are discarded, so the result has a new leading
    "batch" axis followed by an axis of length ``n`` and then the original
    per-item shape ``x.shape[1:]``.
    """
    arr = np.asarray(x)
    num_batches = arr.shape[0] // n
    kept = arr[:num_batches * n]  # drop the incomplete tail batch
    batch_shape = (-1, n) + arr.shape[1:]
    return kept.reshape(batch_shape)
def get_mnist32_batches(batch_size, data_format='channels_first', data=None, rng=None):
    """Return shuffled MNIST images padded to 32x32, grouped into batches.

    Each image is reshaped to 28x28, converted to float32, divided by 255 and
    zero-padded by two pixels on every side (giving 32x32). A singleton channel
    axis is inserted. Samples are shuffled once with a single permutation
    shared by images and labels, then split into consecutive batches of
    ``batch_size``; a trailing partial batch is dropped.

    Args:
        batch_size: Number of samples per batch.
        data_format: ``'channels_first'`` gives per-sample shape
            ``(1, 32, 32)``; ``'channels_last'`` gives ``(32, 32, 1)``.
        data: Mapping with ``'data'`` (images, 28*28 values each) and
            ``'target'`` (one label per image). Defaults to the module-level
            ``mnist`` dataset.
        rng: Object with a ``permutation`` method (for example
            ``np.random.Generator`` or ``np.random.RandomState``) used for the
            shuffle. Defaults to NumPy's global random state.

    Returns:
        ``(x_batches, y_batches)``: ``x_batches`` has shape
        ``(num_batches, batch_size) + per-sample shape``; ``y_batches`` has
        shape ``(num_batches, batch_size)``.

    Raises:
        ValueError: If ``data_format`` is not one of the two supported values,
            or if the number of labels does not match the number of images.
    """
    if data_format == 'channels_first':
        channel_index = 1
    elif data_format == 'channels_last':
        channel_index = 3
    else:
        # Previously any unrecognized value silently meant channels-last.
        raise ValueError(
            "data_format must be 'channels_first' or 'channels_last', got %r"
            % (data_format,))

    if data is None:
        data = mnist

    images = np.asarray(data['data']).reshape(-1, 28, 28).astype(np.float32) / 255.
    images = np.pad(images, ((0, 0), (2, 2), (2, 2)), mode='constant')
    images = np.expand_dims(images, channel_index)
    labels = np.asarray(data['target'])
    if len(labels) != len(images):
        raise ValueError(
            "got %d labels for %d images" % (len(labels), len(images)))

    # np.random.permutation(n) shuffles arange(n) with the global state, so the
    # default path draws the same random stream as before.
    shuffler = np.random if rng is None else rng
    order = shuffler.permutation(len(images))

    y_batches = build_batches(labels[order], batch_size)
    x_batches = build_batches(images[order], batch_size)
    return x_batches, y_batches
# Build shuffled MNIST batches, then move the image batches onto the device.
# NOTE(review): ``args`` is not defined in this file; presumably a settings
# dict supplied elsewhere (TODO confirm). It must provide 'batch_size' and
# 'device'. ``y_batches`` is left as a NumPy array.
x_batches, y_batches = get_mnist32_batches(batch_size=args['batch_size'])
x_batches = torch.FloatTensor(x_batches).to(device=args['device'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment