Created
December 3, 2018 05:32
-
-
Save stas00/e7b4d95fb146a9d91afb6f80507ce476 to your computer and use it in GitHub Desktop.
Convert the MNIST digits or Fashion-MNIST database into a dataset of JPEG image files laid out like the ImageNet dataset: top-level train/valid/test folders, each containing one subfolder per class number, with the subfolder name acting as the label.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gzip
import os
import pathlib
import random

import numpy as np
import PIL
import PIL.Image
def load_mnist(path, kind='train'):
    """Load one MNIST-format split (images and labels) from gzipped IDX files.

    Args:
        path: Directory containing ``<kind>-labels-idx1-ubyte.gz`` and
            ``<kind>-images-idx3-ubyte.gz``.
        kind: File-name prefix selecting the split, e.g. ``'train'`` or
            ``'t10k'``.

    Returns:
        A tuple ``(images, labels)``. ``labels`` is a 1-D ``uint8`` array and
        ``images`` is a ``uint8`` array of shape ``(len(labels), 784)``, one
        flattened image per row.

    Raises:
        ValueError: If the image payload cannot be reshaped to
            ``(len(labels), 784)`` (raised by ``reshape``).
    """
    labels_path = os.path.join(path, f'{kind}-labels-idx1-ubyte.gz')
    images_path = os.path.join(path, f'{kind}-images-idx3-ubyte.gz')

    # offset=8 skips the leading 8 header bytes of the label file; everything
    # after that is read as one uint8 per label.
    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8)

    # offset=16 skips the leading 16 header bytes of the image file; the rest
    # is raw uint8 pixel data, 784 values per image.
    with gzip.open(images_path, 'rb') as imgpath:
        pixels = np.frombuffer(imgpath.read(), dtype=np.uint8, offset=16)
        images = pixels.reshape(len(labels), 784)

    return images, labels
def save_mnist(path, images, labels):
    """Write each image as a JPEG file under ``path/<label>/<index>.jpg``.

    The index in the file name is the image's position within ``images``.

    Args:
        path: Destination directory; created together with any missing
            parents.
        images: Iterable of arrays that can each be reshaped to 28x28.
        labels: Iterable of class labels, parallel to ``images``; each label's
            string form names the subfolder the image is written to.
    """
    # The file imports the `pathlib` module, not the bare `Path` name.
    root = pathlib.Path(path)
    root.mkdir(parents=True, exist_ok=True)

    # Pre-create the ten class folders 0..9 so each exists even if it ends up
    # with no images.
    for cls in range(10):
        (root / str(cls)).mkdir(parents=True, exist_ok=True)

    for i, (im, label) in enumerate(zip(images, labels)):
        dest = root / str(label) / f"{i}.jpg"
        pixels = im.reshape(28, 28)
        # Mode 'L' is Pillow's 8-bit greyscale mode.
        picture = PIL.Image.fromarray(pixels, mode='L')
        with dest.open(mode='wb') as f:
            # State the format explicitly instead of relying on Pillow
            # inferring it from the file handle's name.
            picture.save(f, format='JPEG')
def split_pct(images, labels, pct=0.8):
    """Randomly split parallel ``images``/``labels`` into two subsets.

    Items are shuffled with the module-level ``random`` generator, so seed it
    with ``random.seed(...)`` beforehand for a reproducible split.

    Args:
        images: Array (or sequence convertible to one) of images.
        labels: Array (or sequence convertible to one) of labels, the same
            length as ``images``.
        pct: Fraction of items, between 0 and 1, placed in the first subset;
            the count is ``int(len(images) * pct)``.

    Returns:
        ``(images_a, labels_a, images_b, labels_b)`` where the ``a`` subset
        holds the first ``pct`` share of the shuffled items and ``b`` the
        remainder. Image/label pairing is preserved.

    Raises:
        ValueError: If ``images`` and ``labels`` differ in length, or ``pct``
            is outside ``[0, 1]``.
    """
    # Index-list selection below needs NumPy arrays; asarray is a no-op for
    # inputs that already are arrays.
    images = np.asarray(images)
    labels = np.asarray(labels)

    items = len(images)
    if len(labels) != items:
        raise ValueError(
            f"images and labels differ in length: {items} != {len(labels)}"
        )
    if not 0 <= pct <= 1:
        raise ValueError(f"pct must be between 0 and 1, got {pct!r}")

    idx = list(range(items))
    random.shuffle(idx)

    split = int(items * pct)
    train_idx = idx[:split]
    valid_idx = idx[split:]
    return images[train_idx], labels[train_idx], images[valid_idx], labels[valid_idx]
def mnist_to_imagenet_format(raw_dir='data/raw', out_dir='data', pct=0.8):
    """Convert raw MNIST-format archives into per-class JPEG folders.

    Loads the ``'train'`` split from ``raw_dir``, divides it randomly into a
    training and a validation part, and writes them to ``<out_dir>/train`` and
    ``<out_dir>/valid``. The ``'t10k'`` split is written unsplit to
    ``<out_dir>/test``.

    Args:
        raw_dir: Directory holding the gzipped source files.
        out_dir: Directory under which ``train``, ``valid`` and ``test`` are
            created.
        pct: Fraction of the ``'train'`` split kept for training; the rest
            becomes the validation set.
    """
    # Training data: split into train / valid (80% / 20% by default).
    images, labels = load_mnist(raw_dir, 'train')
    images_trn, labels_trn, images_val, labels_val = split_pct(images, labels, pct)
    save_mnist(os.path.join(out_dir, 'train'), images_trn, labels_trn)
    save_mnist(os.path.join(out_dir, 'valid'), images_val, labels_val)

    # Test data: written as-is, no split.
    images, labels = load_mnist(raw_dir, 't10k')
    save_mnist(os.path.join(out_dir, 'test'), images, labels)
# Run the full conversion with the default directories when this file is executed.
mnist_to_imagenet_format()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
You are an angel! AN ABSOLUTE ANGEL!!!!