Skip to content

Instantly share code, notes, and snippets.

@xkumiyu
Last active May 12, 2023 23:14
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xkumiyu/c93222f2dce615f4b264a9e71f6d49e0 to your computer and use it in GitHub Desktop.
Save xkumiyu/c93222f2dce615f4b264a9e71f6d49e0 to your computer and use it in GitHub Desktop.
Download and Convert MNIST binary files to image files
import argparse
import gzip
import pathlib
import struct
import numpy as np
import pandas as pd
import requests
from PIL import Image
def donwload(urls, path, timeout=60):
    """Download every URL into ``path``, skipping files already present.

    The local file name is the last path component of each URL.

    Args:
        urls: Iterable of URL strings to fetch.
        path: ``pathlib.Path`` of the destination directory; it is created
            (including parents) when missing.
        timeout: Seconds passed to ``requests.get`` so that a stalled server
            cannot block the script indefinitely.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        RuntimeError: If the server answers with any other non-200 status.
    """
    path.mkdir(parents=True, exist_ok=True)
    for url in urls:
        filepath = path / pathlib.Path(url).name
        if filepath.exists():
            continue
        res = requests.get(url, timeout=timeout)
        # Fail here, with the URL and status in the message, instead of
        # silently skipping the file and failing later when it is opened.
        res.raise_for_status()
        if res.status_code != 200:
            raise RuntimeError(
                'unexpected HTTP status {} for {}'.format(
                    res.status_code, url))
        # Write to a sibling ".part" file and rename afterwards, so that an
        # interrupted write never leaves a truncated file which the
        # exists() check above would treat as a finished download.
        partial = filepath.with_name(filepath.name + '.part')
        with open(partial, 'wb') as f:
            f.write(res.content)
        partial.replace(filepath)
def load(paths):
    """Read a gzipped MNIST image/label file pair into NumPy arrays.

    Args:
        paths: Two-item sequence ``(images_path, labels_path)`` of
            gzip-compressed IDX files.

    Returns:
        Tuple ``(images, labels)``. ``images`` is a writable ``uint8`` array
        of shape ``(N, rows * cols)`` (``(N, 784)`` for 28x28 MNIST) and
        ``labels`` is a writable ``uint8`` array of shape ``(N,)``.

    Raises:
        RuntimeError: If the two files declare different item counts, or if
            either file ends before all declared data has been read.
    """
    x_path, y_path = paths
    with gzip.open(x_path) as fx, gzip.open(y_path) as fy:
        # Skip the 4-byte magic number at the start of each file.
        fx.read(4)
        fy.read(4)
        N, = struct.unpack('>i', fy.read(4))
        if N != struct.unpack('>i', fx.read(4))[0]:
            raise RuntimeError('wrong pair of MNIST images and labels')
        # The image header continues with big-endian row and column counts;
        # use them instead of assuming 28 * 28 = 784 pixels per image.
        rows, cols = struct.unpack('>ii', fx.read(8))
        pixels_per_image = rows * cols
        # One bulk read per file replaces millions of one-byte reads.
        image_bytes = fx.read(N * pixels_per_image)
        label_bytes = fy.read(N)
    if len(image_bytes) != N * pixels_per_image or len(label_bytes) != N:
        raise RuntimeError('truncated MNIST images or labels file')
    # frombuffer() over bytes yields read-only views; copy so that callers
    # get ordinary writable arrays, as they did from np.empty().
    images = np.frombuffer(image_bytes, dtype=np.uint8)
    images = images.reshape(N, pixels_per_image).copy()
    labels = np.frombuffer(label_bytes, dtype=np.uint8).copy()
    return images, labels
def make_images(path, images, labels):
    """Save each image row as ``<label>_<index>.jpg`` inside ``path``.

    Args:
        path: ``pathlib.Path`` of the output directory; created (including
            parents) when missing.
        images: Iterable of arrays, each reshaped to 28x28 before saving.
        labels: Iterable of labels paired with ``images`` by position.
    """
    path.mkdir(parents=True, exist_ok=True)
    for index, (pixels, label) in enumerate(zip(images, labels)):
        name = '{}_{}.jpg'.format(label, index)
        picture = Image.fromarray(pixels.reshape(28, 28))
        picture.save(path / name)
def make_labellist(path, kind, labels):
    """Write ``<kind>.csv`` mapping image file names to their labels.

    Each row is ``<label>_<index>.jpg,<label>``; the file has no header row
    and no index column.

    Args:
        path: ``pathlib.Path`` of the output directory; created (including
            parents) when missing.
        kind: Name used for the CSV file stem (e.g. ``'train'``).
        labels: Array of labels; must provide ``tolist()``.
    """
    path.mkdir(parents=True, exist_ok=True)
    names = []
    for index, label in enumerate(labels):
        names.append('{}_{}.jpg'.format(label, index))
    table = pd.DataFrame({'name': names, 'target': labels.tolist()})
    csv_path = path / '{}.csv'.format(kind)
    table.to_csv(csv_path, index=False, header=False)
def main():
    """Parse the command line, then download and convert both MNIST splits.

    Options:
        -p/--path: Base data directory (default ``./data``). Raw downloads go
            to ``<path>/raw`` and converted output to ``<path>/processed``.
        -o/--out: ``jpg`` (default) writes one image per sample plus a label
            CSV; ``npz`` writes one compressed ``.npz`` archive per split.
    """
    parser = argparse.ArgumentParser(
        description='Download and Convert MNIST binary files to image files')
    parser.add_argument('-p', '--path', type=pathlib.Path, default='./data')
    parser.add_argument('-o', '--out', choices=['npz', 'jpg'], default='jpg')
    args = parser.parse_args()

    raw_dir = args.path / 'raw'
    processed_dir = args.path / 'processed'

    def pipeline(kind):
        # The upstream archive names its test split "t10k".
        prefix = 't10k' if kind == 'test' else kind
        baseurl = 'http://yann.lecun.com/exdb/mnist'
        donwload(
            [
                '{}/{}-images-idx3-ubyte.gz'.format(baseurl, prefix),
                '{}/{}-labels-idx1-ubyte.gz'.format(baseurl, prefix),
            ],
            raw_dir,
        )
        images, labels = load([
            raw_dir / '{}-images-idx3-ubyte.gz'.format(prefix),
            raw_dir / '{}-labels-idx1-ubyte.gz'.format(prefix),
        ])

        if args.out != 'jpg':
            npz_dir = processed_dir / 'npz'
            npz_dir.mkdir(parents=True, exist_ok=True)
            np.savez_compressed(
                npz_dir / '{}.npz'.format(kind), x=images, y=labels)
            return

        make_images(processed_dir / 'images' / kind, images, labels)
        make_labellist(processed_dir / 'labels', kind, labels)

    print('Processing train data ...')
    pipeline('train')
    print('Processing test data ...')
    pipeline('test')
# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment