Created December 28, 2022 06:56
-
-
Save protortyp/6b2e58dc220c6b585fa2f9223bc283a7 to your computer and use it in GitHub Desktop.
Split datasets into train, val, and test subdirectories
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Splits a dataset of png images into train, val, and test sets. | |
Creates the folders and randomly chooses files. | |
""" | |
import os | |
import random | |
import argparse | |
try: | |
import termcolor | |
TERMCOLOR_AVAILABLE = True | |
except: | |
TERMCOLOR_AVAILABLE = False | |
parser = argparse.ArgumentParser( | |
description="Splits a dataset of images into train, val, and test sets. Creates the folders and randomly chooses files.", | |
) | |
# root directory | |
parser.add_argument( | |
"dir", type=str, default=".", help="root directory that contains all files" | |
) | |
# dataset splits | |
parser.add_argument( | |
"--split", | |
type=int, | |
nargs=3, | |
default=[60, 30, 10], | |
help="train, val, test splits in percentage", | |
) | |
# file types | |
parser.add_argument( | |
"--types", | |
type=str, | |
nargs="+", | |
default=["png", "jpg", "jpeg"], | |
help="file types to include", | |
) | |
# Access the root directory and splits using the args object | |
args = parser.parse_args() | |
root_dir = args.dir | |
train_pct, val_pct, test_pct = args.split | |
file_types = args.types | |
# Check if the root directory exists | |
if not os.path.isdir(root_dir): | |
print(f"Error: directory '{root_dir}' does not exist.") | |
exit(1) | |
# Change to the root directory | |
os.chdir(root_dir) | |
# get the list of image files in the current directories | |
files = [f for f in os.listdir() if any(f.endswith(t) for t in file_types)] | |
# check if files even exist | |
if len(files) == 0: | |
if TERMCOLOR_AVAILABLE: | |
print(termcolor.colored("No files found in {}".format(root_dir), color="red")) | |
else: | |
print(f"No files found in {root_dir}") | |
exit(0) | |
# create the directories | |
os.makedirs("train", exist_ok=True) | |
os.makedirs("val", exist_ok=True) | |
os.makedirs("test", exist_ok=True) | |
# randomly shuffle list of image files | |
random.shuffle(files) | |
# calculate the number of files | |
num_train = int(len(files) * train_pct / 100) | |
num_val = int(len(files) * val_pct / 100) | |
num_test = len(files) - num_train - num_val | |
assert num_train + num_val + num_test == len(files) | |
# split the files | |
print("Moving test files") | |
for i in range(num_test): | |
os.rename(files[i], "test/" + files[i]) | |
print("Moving train files") | |
for i in range(num_test, num_test + num_train): | |
os.rename(files[i], "train/" + files[i]) | |
print("Moving val files") | |
for i in range(num_test + num_train, num_test + num_train + num_val): | |
os.rename(files[i], "val/" + files[i]) | |
if TERMCOLOR_AVAILABLE: | |
print( | |
termcolor.colored( | |
"Split finished. Moved {} files".format(len(files)), color="green" | |
) | |
) | |
else: | |
print(f"Split finished. Moved {len(files)} files") |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.