Skip to content

Instantly share code, notes, and snippets.

@Unbinilium
Created February 27, 2022 06:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Unbinilium/1643031931876e4701d99a0a9983a3ac to your computer and use it in GitHub Desktop.
Save Unbinilium/1643031931876e4701d99a0a9983a3ac to your computer and use it in GitHub Desktop.
Datasets to MNIST
#!/usr/bin/env python3
import os, argparse, secrets
import numpy as np
from array import *
from pathlib import Path
from PIL import Image
def _load_grayscale(file_path, width, height):
    """Return the image at *file_path* as a (height, width) uint8 array.

    The image is converted to 8-bit grayscale (mode ``'L'``) before resizing
    so that every pixel is exactly one byte, as the IDX ubyte format requires.
    """
    with Image.open(file_path) as source:
        resized = source.convert('L').resize(size=(width, height), resample=Image.LANCZOS)
    return np.asarray(resized, dtype=np.uint8)


def data_2_mnist(data_folder:Path, img_size:int, output_folder:Path):
    """Convert a folder-per-category image dataset into MNIST-style IDX files.

    Every sub-folder of *data_folder* is treated as one category; categories
    are numbered in sorted name order so the label mapping is reproducible.
    Each image is converted to grayscale, resized to ``img_size * img_size``
    and the samples are shuffled before being written.

    Two files are written into *output_folder* (created if missing):
    ``images.idx3-ubyte`` and ``labels.idx1-ubyte``.

    Args:
        data_folder: folder whose sub-folders each hold the images of one category.
        img_size: edge length in pixels of the square output images.
        output_folder: folder that receives the two IDX files.

    Raises:
        ValueError: if *img_size* is not positive, or there are more than 256
            categories (a label is stored in a single unsigned byte).
    """
    if img_size <= 0:
        raise ValueError("img_size must be a positive integer")

    print("Dataset folder ->", data_folder)
    print("Image size ->", img_size, "*", img_size)
    print("Output folder ->", output_folder)

    # Sorted so that the category -> label index mapping does not depend on
    # the arbitrary order returned by os.listdir().
    categories = sorted(
        name for name in os.listdir(data_folder)
        if os.path.isdir(os.path.join(data_folder, name))
    )
    if len(categories) > 256:
        raise ValueError("at most 256 categories fit in a ubyte label")

    width = img_size
    height = img_size
    data_set = []
    for num_cat, cat in enumerate(categories):
        cat_folder = os.path.join(data_folder, cat)
        print('Using index', num_cat, 'as category', cat)
        for file in os.listdir(cat_folder):
            file_path = os.path.join(cat_folder, file)
            try:
                img = _load_grayscale(file_path, width, height)
            except OSError as err:
                # Stray non-image entries (nested folders, .DS_Store, ...)
                # should not abort the whole conversion.
                print('Skipping', file_path, '->', err)
                continue
            data_set.append((num_cat, img))

    secure_random = secrets.SystemRandom()
    label = bytearray()
    data = bytearray()
    for num_cat, img in secure_random.sample(data_set, len(data_set)):
        label.append(num_cat)
        # Row-major dump of the (height, width) array, one byte per pixel.
        data += img.tobytes()

    # IDX headers: a 4-byte magic number (0x08 = unsigned byte data, last
    # byte = number of dimensions) followed by one 32-bit big-endian integer
    # per dimension.
    count = len(data_set).to_bytes(4, 'big')
    label_header = bytes([0, 0, 8, 1]) + count
    image_header = (
        bytes([0, 0, 8, 3])
        + count
        + height.to_bytes(4, 'big')
        + width.to_bytes(4, 'big')
    )

    os.makedirs(output_folder, exist_ok=True)
    with open(os.path.join(output_folder, 'images.idx3-ubyte'), 'wb') as output_file:
        output_file.write(image_header + data)
    with open(os.path.join(output_folder, 'labels.idx1-ubyte'), 'wb') as output_file:
        output_file.write(label_header + label)
if __name__ == '__main__':
    # Command-line entry point: collect the dataset folder, the target image
    # size and the output folder, then run the conversion.
    cli = argparse.ArgumentParser(
        description='Convert jpg/png datasets to ubyte MNIST.',
    )
    cli.add_argument(
        '-i', '--data_dir',
        type=Path,
        required=True,
        help='the path of datasets folder which have sub-folders contained images with its labels',
    )
    cli.add_argument(
        '-s', '--img_size',
        type=int,
        required=False,
        default=28,
        help='the pixels of rect image in MNIST',
    )
    cli.add_argument(
        '-o', '--out_dir',
        type=Path,
        required=True,
        help='the path of converted MNIST to store',
    )
    options = cli.parse_args()
    data_2_mnist(options.data_dir, options.img_size, options.out_dir)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment