Skip to content

Instantly share code, notes, and snippets.

@hyponymous
Created December 21, 2013 06:28
Show Gist options
  • Save hyponymous/8066122 to your computer and use it in GitHub Desktop.
Save hyponymous/8066122 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
#
# File: extract.py
# Author: Michael David Plotz
# Date: Fri Dec 20 21:01:41 PST 2013
#
# Extract images and labels from MNIST database (http://yann.lecun.com/exdb/mnist/)
#
import argparse
import math
import json
import png
import struct
def chunks(l, n):
    """Cut the sequence *l* into consecutive slices of at most *n* items.

    Every slice holds exactly *n* items except possibly the last, which gets
    whatever is left over. Slices are taken with ordinary slicing, so they
    have the same sequence type as *l*.
    """
    offsets = range(0, len(l), n)
    return [l[start:start + n] for start in offsets]
def extract_labels(data, args):
    """Write the labels of an MNIST IDX label file to ``<args.filename>.json``.

    *data* is the raw file contents: a 4-byte magic number (skipped here), a
    big-endian uint32 label count, then one unsigned byte per label. The
    output file holds a JSON array of those byte values as ints, in file
    order. Bytes beyond the declared count are ignored.

    Raises ValueError if *data* is shorter than its header requires.
    """
    if len(data) < 8:
        raise ValueError(
            "{}: too short for an IDX label header".format(args.filename))
    label_count = struct.unpack('>I', data[4:8])[0]
    print("extracting {} labels from {}...".format(label_count, args.filename))
    end = 8 + label_count
    if len(data) < end:
        raise ValueError(
            "{}: header describes {} bytes but the file has only {}".format(
                args.filename, end, len(data)))
    # bytearray iterates as ints for both Python 2 str and Python 3 bytes,
    # so one slice replaces a struct.unpack('B', ...) call per label (that
    # per-byte form only worked when indexing returned a 1-char str).
    labels = list(bytearray(data[8:end]))
    with open("{}.json".format(args.filename), 'w') as f:
        f.write(json.dumps(labels))
def extract_images(data, args):
    """Write every image of an MNIST IDX image file out as a greyscale PNG.

    *data* is the raw file contents. After the 4-byte magic number come three
    big-endian uint32 fields (image count, rows, columns), followed by
    image_count * rows * columns unsigned bytes of pixels: one image after
    another, each stored row by row.

    Image ``i`` is saved as ``<args.filename>-<i>.png``. The index ``i`` is
    zero-padded to as many digits as the image count has, so the files sort
    in index order. A "." is printed every 100 images, starting with the
    first, as progress output.

    Raises ValueError if *data* is shorter than its header requires.
    """
    if len(data) < 16:
        raise ValueError(
            "{}: too short for an IDX image header".format(args.filename))
    (image_count, rows, columns) = struct.unpack('>III', data[4:16])
    print("extracting {} images ({}x{}) from {}...".format(
        image_count, rows, columns, args.filename))
    print(len(data))
    pixel_count = rows * columns
    expected_size = 16 + image_count * pixel_count
    # Validate up front so a truncated file fails before any PNG is written,
    # rather than partway through the loop.
    if len(data) < expected_size:
        raise ValueError(
            "{}: header describes {} bytes but the file has only {}".format(
                args.filename, expected_size, len(data)))
    # Compile the per-image layout once; it is reused for every image.
    image_struct = struct.Struct(str(pixel_count) + 'B')
    w = png.Writer(columns, rows, greyscale=True)
    # len(str(n)) is the exact decimal digit count of n.
    # int(math.log(n, 10)) + 1 fails for n == 0 (math domain error), and its
    # result depends on float rounding (math.log(1000, 10) is 2.999...).
    width = len(str(image_count))
    for i in range(image_count):
        # Read this image straight out of *data* at its offset instead of
        # first copying the whole pixel block into per-image slices.
        img_data = image_struct.unpack_from(data, 16 + i * pixel_count)
        outfile = "{name}-{index:0{width}d}.png".format(
            name=args.filename, index=i, width=width)
        # PNG is binary: text mode would corrupt it on Windows and cannot
        # accept bytes on Python 3.
        with open(outfile, 'wb') as f:
            w.write(f, chunks(img_data, columns))
        if i % 100 == 0:
            print(".")
def main():
    """Command-line entry point: extract one MNIST IDX file.

    Reads the file named on the command line and dispatches on its leading
    big-endian 4-byte magic number. 0x00000801 goes to extract_labels and
    0x00000803 goes to extract_images.

    Any other file, including one too short to hold a magic number, is
    rejected via ``parser.error``. That prints the usage line and the message
    to stderr and exits with status 2.
    """
    parser = argparse.ArgumentParser(description='Extract MNIST data files to PNG and JSON')
    parser.add_argument('filename')
    args = parser.parse_args()
    # IDX files are binary. Text mode would translate newlines on Windows
    # and decode the bytes into str on Python 3.
    with open(args.filename, 'rb') as f:
        data = f.read()
    handlers = {
        0x00000801: extract_labels,
        0x00000803: extract_images,
    }
    # A file under 4 bytes has no magic number. Treat it like an
    # unrecognized one instead of letting struct.error escape.
    magic = struct.unpack('>I', data[0:4])[0] if len(data) >= 4 else None
    handler = handlers.get(magic)
    if handler is None:
        # parser.error raises SystemExit, so execution stops here.
        parser.error("Expecting magic number 0x00000801 or 0x00000803 at start of file")
    handler(data, args)
# Run the extractor only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment