Skip to content

Instantly share code, notes, and snippets.

@stas00
Created December 3, 2018 05:32
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save stas00/e7b4d95fb146a9d91afb6f80507ce476 to your computer and use it in GitHub Desktop.
Save stas00/e7b4d95fb146a9d91afb6f80507ce476 to your computer and use it in GitHub Desktop.
Convert the MNIST digits or Fashion-MNIST database into a JPEG image-file dataset laid out like ImageNet: train/valid/test top-level folders, each containing class-number subfolders that act as labels.
import gzip
import os
import pathlib
import random

import numpy as np
import PIL
import PIL.Image
def load_mnist(path, kind='train'):
    """Read one gzip-compressed MNIST-style split from the directory `path`.

    `kind` is the file-name prefix (for example 'train' or 't10k'); the
    files read are `<kind>-labels-idx1-ubyte.gz` and
    `<kind>-images-idx3-ubyte.gz`.

    Returns `(images, labels)`: `images` is a uint8 array of shape
    (number of labels, 784) and `labels` is a 1-D uint8 array.
    """
    def read_payload(stem, skip):
        # Decompress the whole file and view it as bytes, ignoring the
        # first `skip` bytes that precede the actual data.
        archive = os.path.join(path, f'{kind}-{stem}.gz')
        with gzip.open(archive, 'rb') as stream:
            return np.frombuffer(stream.read(), dtype=np.uint8, offset=skip)

    # Labels are read first: their count fixes the number of image rows.
    labels = read_payload('labels-idx1-ubyte', 8)
    pixels = read_payload('images-idx3-ubyte', 16)
    return pixels.reshape(len(labels), 784), labels
def save_mnist(path, images, labels):
    """Write every image to `path/<label>/<index>.jpg`.

    `path` is created if missing, together with one subdirectory for each
    of the class labels 0-9. `images` and `labels` are iterated in
    lockstep; each image is reshaped to 28x28 and saved as an 8-bit
    grayscale JPEG named after its position in the sequence.

    Returns None.
    """
    # Only the `pathlib` module is imported by this file, so the class has
    # to be referenced through it (a bare `Path` raises NameError).
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)
    # prep 10 dirs, one per class label
    for cls in range(10):
        (root / str(cls)).mkdir(parents=True, exist_ok=True)
    for i, (pixels, label) in enumerate(zip(images, labels)):
        dest = root / str(label) / f"{i}.jpg"
        # Needs the `PIL.Image` submodule to be imported; `import PIL`
        # alone does not load it.
        im = PIL.Image.fromarray(pixels.reshape(28, 28), mode='L')
        with dest.open(mode='wb') as f:
            im.save(f)
def split_pct(images, labels, pct=0.8):
    """Randomly partition `images` and `labels` into two aligned subsets.

    A shuffled list of positions is cut at `int(len(images) * pct)`; the
    first part selects the training subset and the remainder the
    validation subset. Both inputs are indexed with lists of positions.

    Returns `(train_images, train_labels, valid_images, valid_labels)`.
    """
    total = len(images)
    order = list(range(total))
    cut = int(total * pct)
    # Uses the global `random` state, so the split is reproducible only
    # if the caller seeds it.
    random.shuffle(order)
    head, tail = order[:cut], order[cut:]
    return images[head], labels[head], images[tail], labels[tail]
def mnist_to_imagenet_format():
    """Convert the raw archives in `data/raw` into per-class JPEG folders.

    The 'train' split is divided 80% / 20% and written to `data/train`
    and `data/valid`; the 't10k' split is written unchanged to
    `data/test`.
    """
    raw_dir = 'data/raw'

    # Training archive: random 80/20 split into train and valid folders.
    pool_x, pool_y = load_mnist(raw_dir, 'train')
    trn_x, trn_y, val_x, val_y = split_pct(pool_x, pool_y, 0.8)
    save_mnist('data/train', trn_x, trn_y)
    save_mnist('data/valid', val_x, val_y)

    # Test archive is exported as-is.
    test_x, test_y = load_mnist(raw_dir, 't10k')
    save_mnist('data/test', test_x, test_y)
# Script entry point: run the conversion. There is no `__main__` guard, so
# this call also executes if the file is imported as a module.
mnist_to_imagenet_format()
@EvanMarie
Copy link

You are an angel! AN ABSOLUTE ANGEL!!!!

@stas00
Copy link
Author

stas00 commented Nov 14, 2022

thank you for the kind words, @EvanMarie

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment