Created
January 17, 2018 22:16
-
-
Save mrcoles/3b5c536a393b31f08de1546c91ab2660 to your computer and use it in GitHub Desktop.
Convert the MNIST CSV dataset from Kaggle to PNG images
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import csv | |
import os | |
import pathlib | |
import imageio | |
import numpy as np | |
# Size of each digit image in pixels (the MNIST CSV rows hold 28x28 = 784 pixels).
IMG_WIDTH, IMG_HEIGHT = 28, 28
def run(infile, outdir, label=None):
    """Write each row of an MNIST-style CSV as a grayscale PNG under *outdir*.

    Every non-label column is treated as one pixel value. The values of a row
    are converted to ``uint8``, reshaped to ``(IMG_HEIGHT, IMG_WIDTH)`` and
    saved as ``NNNNN.png``, where ``NNNNN`` is the zero-based row index.

    Args:
        infile: Open text file (or any iterable of CSV lines) whose first
            line is a header row.
        outdir: Output directory; created if it does not exist.
        label: Optional name of the label column. When given, its value is
            excluded from the pixel data and used as a sub-directory of
            *outdir*, so images are grouped by label.

    Raises:
        KeyError: If *label* is given but is not a column in the header.
        ValueError: If a row's pixel count cannot be reshaped to
            ``(IMG_HEIGHT, IMG_WIDTH)``.
    """
    _make_dir(outdir)

    reader = csv.DictReader(infile)
    header = reader.fieldnames  # consumes the header line; None for empty input
    if header is None:
        return
    if label and label not in header:
        raise KeyError(f'label column {label!r} not found in CSV header')
    # Computed once from the header instead of on the first loop iteration.
    pixel_columns = [name for name in header if name != label]

    for i, row in enumerate(reader):
        pixels = [row[name] for name in pixel_columns]
        # NumPy arrays are row-major: first axis is rows (height), second is
        # columns (width).
        array = np.array(pixels, dtype=np.uint8).reshape((IMG_HEIGHT, IMG_WIDTH))

        if label:
            # NOTE(review): the label value is used verbatim as a directory
            # name; this assumes the CSV is trusted.
            target_dir = os.path.join(outdir, row[label])
            _make_dir(target_dir)
        else:
            target_dir = outdir

        filename = os.path.join(target_dir, f'{i:05}.png')
        imageio.imwrite(filename, array)
# Directories already known to exist, so repeated calls skip the filesystem check.
_EXISTING_PATHS = set()


def _make_dir(dirpath):
    """Ensure *dirpath* exists, creating missing parents (like ``mkdir -p``).

    Paths are memoized in ``_EXISTING_PATHS``: once a path has been seen, it
    is not checked on disk again.
    """
    if dirpath in _EXISTING_PATHS:
        return
    if not os.path.exists(dirpath):
        pathlib.Path(dirpath).mkdir(parents=True, exist_ok=True)
    # Cache the path whether it was just created or already existed, so
    # pre-existing directories are not re-checked on every call.
    _EXISTING_PATHS.add(dirpath)
def _make_img(grays):
    """Placeholder for building an image from gray values; not implemented.

    It is not called anywhere in the visible code. It ignores *grays* and
    returns ``None``.
    """
    return None
## | |
def main():
    """Command-line entry point: parse arguments and call ``run``.

    Usage: ``script.py [infile] outdir [--label COLUMN]``. When *infile* is
    omitted (or given as ``-``), the CSV is read from standard input.
    """
    import argparse
    import sys

    parser = argparse.ArgumentParser(description='CSV manipulator')
    parser.add_argument('infile', nargs='?', type=argparse.FileType('r'), default=sys.stdin,
                        help='CSV file to parse')
    parser.add_argument('outdir', type=str, help='path to output directory')
    # The short alias must not be a bare '-': argparse matches registered
    # option strings first, so a lone '-' (the usual "read stdin" marker)
    # would be consumed as --label instead of as the infile positional.
    parser.add_argument('--label', '-l', type=str, default=None,
                        help='optional label column name (if training data)')
    args = parser.parse_args()
    # Close the input file opened by argparse once conversion finishes.
    with args.infile:
        run(args.infile, args.outdir, args.label)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I split the CSV training file into two files: 90% went into a training file and 10% into a validation file. The CSV files were located in:
And I ran the following code: