Skip to content

Instantly share code, notes, and snippets.

@matsub
Created June 29, 2016 02:18
Show Gist options
  • Save matsub/206a1dac75093d74d8ae2ab9c5a2ae35 to your computer and use it in GitHub Desktop.
Save matsub/206a1dac75093d74d8ae2ab9c5a2ae35 to your computer and use it in GitHub Desktop.
A parser for the MNIST handwritten-digit dataset. See http://yann.lecun.com/exdb/mnist/.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import struct
class Image:
    """Streaming reader for the MNIST handwritten-digit dataset (IDX format).

    The four uncompressed dataset files are expected in one directory under
    their original names; see http://yann.lecun.com/exdb/mnist/.

    The ``train`` and ``test`` properties each return a *fresh* generator of
    ``((pixels, width, height), label)`` pairs, where ``pixels`` is a tuple of
    ``width * height`` ints in 0-255 and ``label`` is an int. Files are only
    opened once iteration starts. Both are closed when the generator is
    exhausted or explicitly closed.
    """

    # Magic numbers identifying IDX image files and IDX label files.
    _IMAGE_MAGIC = 2051
    _LABEL_MAGIC = 2049

    def __init__(self, dir='./'):
        """Record the paths of the training and test files.

        :param dir: directory holding the four MNIST files. (The name shadows
            the ``dir`` builtin but is kept because callers may pass it by
            keyword.)
        """
        self.train_files = {
            'images': os.path.join(dir, 'train-images-idx3-ubyte'),
            'labels': os.path.join(dir, 'train-labels-idx1-ubyte'),
        }
        self.test_files = {
            'images': os.path.join(dir, 't10k-images-idx3-ubyte'),
            'labels': os.path.join(dir, 't10k-labels-idx1-ubyte'),
        }

    @property
    def train(self):
        """Return a new generator over the training samples."""
        return self._get_dataset(self.train_files)

    @property
    def test(self):
        """Return a new generator over the test (``t10k``) samples."""
        return self._get_dataset(self.test_files)

    def _get_dataset(self, path):
        """Yield ``(image, label)`` pairs from the files named in *path*.

        :param path: dict with ``'images'`` and ``'labels'`` file paths.

        Iteration stops at the shorter of the two files (``zip`` semantics).
        """
        images = self._load_images(path['images'])
        labels = self._load_labels(path['labels'])
        try:
            yield from zip(images, labels)
        finally:
            # zip() never resumes ``labels`` once ``images`` is exhausted.
            # The consumer may also stop early. Close both readers here so
            # their file handles are released deterministically.
            images.close()
            labels.close()

    def _load_images(self, fname):
        """Yield ``(pixels, width, height)`` for every image in *fname*.

        :raises RuntimeError: if the file's magic number is not 2051.
        :raises struct.error: if the header or an image is truncated.
        """
        with open(fname, 'rb') as f:
            # Header: four big-endian 32-bit ints.
            # NOTE(review): the MNIST page lists the last two as rows, then
            # columns. The width/height names are kept for compatibility
            # (MNIST images are square).
            magic, size, width, height = struct.unpack('>4i', f.read(16))
            if magic != self._IMAGE_MAGIC:
                raise RuntimeError("'%s' is not an MNIST image set." % fname)
            chunk = width * height
            # Compile the per-image format once instead of on every image.
            pixel_format = struct.Struct('>%dB' % chunk)
            for _ in range(size):
                yield pixel_format.unpack(f.read(chunk)), width, height

    def _load_labels(self, fname):
        """Yield every label (an int in 0-255) stored in *fname*.

        :raises RuntimeError: if the file's magic number is not 2049.
        :raises struct.error: if the header or the label data is truncated.
        """
        with open(fname, 'rb') as f:
            # Header: two big-endian 32-bit ints.
            magic, size = struct.unpack('>2i', f.read(8))
            if magic != self._LABEL_MAGIC:
                raise RuntimeError("'%s' is not an MNIST label set." % fname)
            # Read exactly ``size`` bytes. A bare f.read() would also pull in
            # any trailing bytes and make the unpack fail.
            labels = struct.unpack('>%dB' % size, f.read(size))
        # All labels are already in memory, so the file is closed before
        # yielding rather than staying open for the whole iteration.
        yield from labels
from mnist import Image
dataset = Image('path/to/MNIST-Dataset')

# Dump the training split first, then the test split. Each sample is an
# (image, label) pair whose image part unpacks to (pixels, width, height).
for split in (dataset.train, dataset.test):
    for (pixels, w, h), digit in split:
        print(pixels, w, h, digit)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment