Skip to content

Instantly share code, notes, and snippets.

@protortyp
Created December 28, 2022 06:56
Show Gist options
  • Save protortyp/6b2e58dc220c6b585fa2f9223bc283a7 to your computer and use it in GitHub Desktop.
Save protortyp/6b2e58dc220c6b585fa2f9223bc283a7 to your computer and use it in GitHub Desktop.
Split datasets into train, val, and test subdirectories
"""
Splits a dataset of image files (png, jpg, and jpeg by default) into train, val, and test sets.
Creates the train/val/test folders inside the root directory and moves randomly chosen files into them.
"""
import os
import random
import argparse
# termcolor is an optional third-party dependency that is only used to colour
# the status messages; fall back to plain output when it is not installed.
try:
    import termcolor

    TERMCOLOR_AVAILABLE = True
except ImportError:
    # Catch only the missing-module case: a bare ``except:`` would also
    # swallow KeyboardInterrupt, SystemExit and genuine bugs.
    TERMCOLOR_AVAILABLE = False
# Command-line interface: an optional positional root directory plus options
# for the split percentages and the file extensions to include.
parser = argparse.ArgumentParser(
    description="Splits a dataset of images into train, val, and test sets. Creates the folders and randomly chooses files.",
)
# root directory
# nargs="?" makes the positional optional so that the "." default actually
# applies; without it argparse treats the positional as required and the
# default is never used.
parser.add_argument(
    "dir",
    type=str,
    nargs="?",
    default=".",
    help="root directory that contains all files",
)
# dataset splits
parser.add_argument(
    "--split",
    type=int,
    nargs=3,
    default=[60, 30, 10],
    help="train, val, test splits in percentage",
)
# file types
parser.add_argument(
    "--types",
    type=str,
    nargs="+",
    default=["png", "jpg", "jpeg"],
    help="file types to include",
)
# Access the root directory and splits using the args object
args = parser.parse_args()
root_dir = args.dir
train_pct, val_pct, test_pct = args.split
file_types = args.types
# Check if the root directory exists
if not os.path.isdir(root_dir):
    print(f"Error: directory '{root_dir}' does not exist.")
    # SystemExit is raised directly because the ``exit`` helper is installed
    # by the ``site`` module and is not guaranteed to exist.
    raise SystemExit(1)
# Change to the root directory so that all following paths are relative to it
os.chdir(root_dir)
# Build the accepted suffixes once: lower-cased and with exactly one leading
# dot, so that both "png" and ".png" are accepted on the command line and a
# name such as "foopng" (no extension separator) is not mistaken for an image.
suffixes = tuple("." + t.lower().lstrip(".") for t in file_types)
# get the list of image files in the current directory; only regular files
# are kept (a directory whose name happens to end in an accepted suffix is
# skipped) and the comparison is case-insensitive so "IMG.PNG" is included
files = [
    f for f in os.listdir() if os.path.isfile(f) and f.lower().endswith(suffixes)
]
# check if files even exist
if not files:
    if TERMCOLOR_AVAILABLE:
        print(termcolor.colored("No files found in {}".format(root_dir), color="red"))
    else:
        print(f"No files found in {root_dir}")
    # Nothing to split is reported with exit status 0, as before.
    raise SystemExit(0)
# Validate the percentages before touching the file system. The test set
# receives whatever remains after train and val, so a negative value or
# train + val > 100 would make that remainder negative; the index arithmetic
# below would then wrap around the list, move a file twice and abort half-way
# with FileNotFoundError.
if train_pct < 0 or val_pct < 0 or train_pct + val_pct > 100:
    print(
        "Error: train and val percentages must be non-negative and sum to at most 100."
    )
    raise SystemExit(1)
# create the directories
for subset in ("train", "val", "test"):
    os.makedirs(subset, exist_ok=True)
# randomly shuffle list of image files
random.shuffle(files)
# calculate the number of files; integer floor division avoids the float
# round-trip of int(x / 100), and the test set absorbs the rounding remainder
# so that every file is assigned to exactly one subset
total = len(files)
num_train = total * train_pct // 100
num_val = total * val_pct // 100
num_test = total - num_train - num_val
# split the files: the shuffled list is cut into consecutive slices in the
# order test, train, val
assignments = (
    ("test", files[:num_test]),
    ("train", files[num_test:num_test + num_train]),
    ("val", files[num_test + num_train:]),
)
for subset, names in assignments:
    print(f"Moving {subset} files")
    for name in names:
        # NOTE: on POSIX os.rename silently replaces an existing destination
        # file of the same name.
        os.rename(name, os.path.join(subset, name))
summary = f"Split finished. Moved {total} files"
if TERMCOLOR_AVAILABLE:
    print(termcolor.colored(summary, color="green"))
else:
    print(summary)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment