Created
December 21, 2013 06:28
-
-
Save hyponymous/8066122 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
#
# File: extract.py
# Author: Michael David Plotz
# Date: Fri Dec 20 21:01:41 PST 2013
#
# Extract images (to PNG files) and labels (to JSON) from MNIST database files
# (http://yann.lecun.com/exdb/mnist/).
#
import argparse
import json
import math
import struct
import sys

import png
def chunks(l, n):
    """Split the sequence *l* into consecutive slices of length *n*.

    The last slice is shorter when len(l) is not a multiple of n; each slice
    has the same type as *l* (list, tuple, str, bytes, ...).
    """
    def _slice_at(offset):
        # One window of n items beginning at `offset`.
        return l[offset:offset + n]

    return list(map(_slice_at, range(0, len(l), n)))
def extract_labels(data, args):
    """Decode an MNIST label file and write its labels to ``<filename>.json``.

    Layout read here: a 4-byte magic number (checked by the caller), a
    big-endian unsigned 32-bit label count at offset 4, then one unsigned
    byte per label starting at offset 8.

    Args:
        data: Raw contents of the label file (a byte string).
        args: Parsed command-line arguments; only ``args.filename`` is used,
            for the progress message and to name the JSON output file.

    Raises:
        ValueError: If ``data`` is too short for the header, or for the
            number of labels the header announces.
    """
    if len(data) < 8:
        raise ValueError("label file is too short to contain a header")
    label_count = struct.unpack('>I', data[4:8])[0]
    if len(data) < 8 + label_count:
        raise ValueError("label file holds {} bytes but its header needs {}".format(
            len(data), 8 + label_count))
    print("extracting {} labels from {}...".format(label_count, args.filename))
    # A single unpack of label_count unsigned bytes yields plain ints for both
    # Python 2 str and Python 3 bytes input (indexing bytes yields ints on
    # Python 3, which broke the old per-byte struct.unpack('B', ...) call).
    labels = list(struct.unpack_from('{}B'.format(label_count), data, 8))
    with open("{}.json".format(args.filename), 'w') as f:
        f.write(json.dumps(labels))
def extract_images(data, args):
    """Decode an MNIST image file and write each image as a greyscale PNG.

    Layout read here: a 4-byte magic number (checked by the caller), then
    three big-endian unsigned 32-bit integers at offset 4 -- image count,
    rows, columns -- followed by ``rows * columns`` unsigned bytes per image,
    starting at offset 16.

    Image ``i`` is written to ``<filename>-<i>.png``, with ``i`` zero-padded
    to the number of digits in the image count.

    Args:
        data: Raw contents of the image file (a byte string).
        args: Parsed command-line arguments; only ``args.filename`` is used.

    Raises:
        ValueError: If ``data`` is shorter than its header says it should be.
    """
    if len(data) < 16:
        raise ValueError("image file is too short to contain a header")
    image_count, rows, columns = struct.unpack('>III', data[4:16])
    print("extracting {} images ({}x{}) from {}...".format(
        image_count, rows, columns, args.filename))
    pixel_count = rows * columns
    expected_size = 16 + image_count * pixel_count
    if len(data) < expected_size:
        raise ValueError("image file holds {} bytes but its header needs {}".format(
            len(data), expected_size))
    # Zero-pad indices so file names sort in image order. len(str(n)) replaces
    # int(math.log(n, 10)) + 1, which crashes when n == 0 and is subject to
    # float rounding (log(1000, 10) evaluates to 2.999...).
    width = len(str(image_count))
    image_struct = struct.Struct('{}B'.format(pixel_count))
    writer = png.Writer(columns, rows, greyscale=True)
    for i in range(image_count):
        # Unpack straight from the buffer instead of pre-slicing every image.
        pixels = image_struct.unpack_from(data, 16 + i * pixel_count)
        # PNG is binary: text mode corrupts it on Windows and fails on Python 3.
        with open("{}-{:0{}d}.png".format(args.filename, i, width), 'wb') as f:
            writer.write(f, chunks(pixels, columns))
        if i % 100 == 0:
            print(".")
def main():
    """Command-line entry point: extract one MNIST data file.

    Reads the file named on the command line and dispatches on its leading
    big-endian magic number: 0x00000801 goes to ``extract_labels`` and
    0x00000803 goes to ``extract_images``. Any other file, including one too
    short to hold a magic number, makes the script exit with a non-zero
    status and an error message.
    """
    parser = argparse.ArgumentParser(description='Extract MNIST data files to PNG and JSON')
    parser.add_argument('filename')
    args = parser.parse_args()
    # The input is binary; text mode would decode (Python 3) or translate
    # line endings (Python 2 on Windows) and corrupt the data.
    with open(args.filename, 'rb') as f:
        data = f.read()
    extractors = {
        0x00000801: extract_labels,
        0x00000803: extract_images,
    }
    magic = struct.unpack('>I', data[0:4])[0] if len(data) >= 4 else None
    extractor = extractors.get(magic)
    if extractor is None:
        # sys.exit(str) prints to stderr and exits with status 1.
        sys.exit("Expecting magic number 0x00000801 or 0x00000803 at start of file")
    extractor(data, args)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment