Skip to content

Instantly share code, notes, and snippets.

@nullicorn
Forked from akesling/mnist.py
Last active December 13, 2017 20:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nullicorn/00365f94ecf71b5c8b3aa9875f2d2fb8 to your computer and use it in GitHub Desktop.
Save nullicorn/00365f94ecf71b5c8b3aa9875f2d2fb8 to your computer and use it in GitHub Desktop.
Python script for working with the MNIST dataset, with minor edits to support Python 3 and a new ascii_show function.
import os
import struct
import numpy as np
"""
Source: https://gist.github.com/akesling/5358964
Loosely inspired by http://abel.ee.ucla.edu/cvxopt/_downloads/mnist.py
which is GPL licensed.
"""
def read(dataset = "training", path = "."):
    """
    Python function for importing the MNIST data set. It returns an iterator
    of 2-tuples with the first element being the label and the second element
    being a numpy.uint8 2D array of pixel data for the given image.

    Parameters:
        dataset: "training" reads train-images.idx3-ubyte / train-labels.idx1-ubyte,
            "testing" reads t10k-images.idx3-ubyte / t10k-labels.idx1-ubyte.
        path: directory containing the four IDX files.

    Yields:
        (label, pixels) where label is a numpy.int8 scalar and pixels is a
        numpy.uint8 array of shape (rows, cols) taken from the image header.

    Raises:
        ValueError: if `dataset` is not "training" or "testing", or if a file
            header has an unexpected magic number or the label and image
            counts disagree.

    Because this is a generator, files are opened and errors are raised only
    when iteration starts (on the first ``next()``), not at call time.
    """
    # Magic numbers from the IDX headers: 0x00000801 (unsigned byte, 1 dim)
    # for label files, 0x00000803 (unsigned byte, 3 dims) for image files.
    label_magic = 2049
    image_magic = 2051

    # Compare by value: `is` tests object identity and can fail for equal
    # strings that are not the same interned object.
    if dataset == "training":
        fname_img = os.path.join(path, 'train-images.idx3-ubyte')
        fname_lbl = os.path.join(path, 'train-labels.idx1-ubyte')
    elif dataset == "testing":
        fname_img = os.path.join(path, 't10k-images.idx3-ubyte')
        fname_lbl = os.path.join(path, 't10k-labels.idx1-ubyte')
    else:
        # Instantiate the exception; raising a (class, msg) tuple is a
        # TypeError in Python 3.
        raise ValueError("dataset must be 'testing' or 'training'")

    # Load everything in some numpy arrays
    with open(fname_lbl, 'rb') as flbl:
        magic, num_labels = struct.unpack(">II", flbl.read(8))
        if magic != label_magic:
            raise ValueError(
                "unexpected magic number {0} in label file {1!r}".format(magic, fname_lbl))
        lbl = np.fromfile(flbl, dtype=np.int8)

    with open(fname_img, 'rb') as fimg:
        magic, num_images, rows, cols = struct.unpack(">IIII", fimg.read(16))
        if magic != image_magic:
            raise ValueError(
                "unexpected magic number {0} in image file {1!r}".format(magic, fname_img))
        if num_images != len(lbl):
            raise ValueError(
                "image file holds {0} images but label file holds {1} labels".format(
                    num_images, len(lbl)))
        img = np.fromfile(fimg, dtype=np.uint8).reshape(len(lbl), rows, cols)

    # Create an iterator which returns each image in turn
    for label, pixels in zip(lbl, img):
        yield label, pixels
def show(image):
    """
    Display a 2D array of pixel data (such as a numpy.uint8 MNIST image) in
    a matplotlib window: grey colormap, nearest-neighbour interpolation,
    x-axis ticks along the top edge and y-axis ticks along the left edge.
    Blocks until the window is closed, as pyplot.show() does by default.
    """
    # Imported lazily so the rest of the module works without matplotlib.
    import matplotlib as mpl
    from matplotlib import pyplot

    figure = pyplot.figure()
    axes = figure.add_subplot(1, 1, 1)
    axes.imshow(image, cmap=mpl.cm.Greys, interpolation='nearest')

    # Put ticks where array indices read naturally: row 0 at the top.
    axes.xaxis.set_ticks_position('top')
    axes.yaxis.set_ticks_position('left')

    pyplot.show()
def ascii_show(image):
    """
    Print an image as a grid of pixel values, one text row per pixel row,
    with each value left-aligned in a 4-character column.

    Parameters:
        image: either a (label, pixels) tuple as yielded by read(), or a
            2D numpy array of pixel values. Anything other than a 2D
            ndarray is indexed with [1], exactly as before, so existing
            callers passing read() tuples are unaffected.
    """
    # A bare 2D pixel array previously crashed (image[1] is a single row,
    # whose elements are scalars that cannot be iterated); accept it directly.
    if isinstance(image, np.ndarray) and image.ndim == 2:
        pixels = image
    else:
        pixels = image[1]

    for row in pixels:
        # join avoids quadratic string concatenation across the row.
        print("".join('{0: <4}'.format(x) for x in row))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment