Created
March 7, 2019 16:34
-
-
Save tamanobi/b059a8060e73ece9e6d47936e126c4c4 to your computer and use it in GitHub Desktop.
ディレクトリ内の各jpgファイルを、同じファイル名（拡張子違い）のxml・txtファイルと一緒に、指定した割合（--test_size）でランダムにtrain／testディレクトリへ移動する(inspired by https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding:utf-8
import argparse
import random
import sys
from logging import getLogger, DEBUG, INFO, StreamHandler
from pathlib import Path
from shutil import move

# Module-wide logger: records at DEBUG level and above are written to stdout.
logger = getLogger(__name__)
logger.setLevel(level=DEBUG)
logger.addHandler(StreamHandler(stream=sys.stdout))
def related_paths(sample_path):
    """Return an iterator over the existing files that share *sample_path*'s stem.

    The candidate paths (stem + ``.jpg``, ``.xml``, ``.txt``, in that order)
    are built immediately; whether each one exists is checked lazily, only
    as the returned iterator is consumed. Missing candidates are skipped.
    """
    wanted_suffixes = ('.jpg', '.xml', '.txt')
    candidates = [sample_path.with_suffix(suffix) for suffix in wanted_suffixes]
    return (candidate for candidate in candidates if candidate.exists())
# Resolution of the per-sample random draw: --test_size is effectively
# quantized to a multiple of 1 / RAND_MAX.
RAND_MAX = 1000


def parse_args(argv=None):
    """Parse the command-line options of the train/test split script.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``input_dir``, ``output_test_dir``,
        ``output_train_dir``, ``test_size``, ``random_state`` and ``dryrun``.

    Raises:
        SystemExit: via ``parser.error`` when ``--input_dir`` is missing or
            ``--test_size`` is not within [0.0, 1.0].
    """
    parser = argparse.ArgumentParser(
        description='Move each jpg in a directory, together with its same-stem '
                    'xml/txt files, into either a train or a test directory.')
    # Required: Path(None) would otherwise crash with a TypeError later on.
    parser.add_argument('--input_dir', type=str, required=True, help='input directory')
    parser.add_argument('--output_test_dir', type=str, default='test', help='output test directory')
    parser.add_argument('--output_train_dir', type=str, default='train', help='output train directory')
    parser.add_argument('--test_size', type=float, default=0.25,
                        help='fraction of samples moved to the test directory (0.0-1.0)')
    parser.add_argument('--random_state', type=int, default=0, help='random seed for the split')
    parser.add_argument('--dryrun', action="store_true", default=False,
                        help='only log the planned moves; do not move any file')
    args = parser.parse_args(argv)
    # The comparison below also rejects NaN.
    if not 0.0 <= args.test_size <= 1.0:
        parser.error('--test_size must be between 0.0 and 1.0')
    return args


def main(argv=None):
    """Split the contents of one directory into training and test sets.

    Every ``*.jpg`` in ``--input_dir`` is randomly assigned to the test
    directory with probability ~``--test_size`` (otherwise to the train
    directory), and it is moved there together with its same-stem
    ``.xml``/``.txt`` files. With ``--dryrun`` the moves are only logged.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Exits with status 1 if any of the three directories is not an existing
    directory.
    """
    args = parse_args(argv)
    random.seed(args.random_state)

    input_dir = Path(args.input_dir)
    output_test_dir = Path(args.output_test_dir)
    output_train_dir = Path(args.output_train_dir)
    if not all(d.is_dir() for d in (input_dir, output_test_dir, output_train_dir)):
        logger.error('some directories are not found')
        # Non-zero status so callers/scripts can detect the failure.
        sys.exit(1)

    label = '[dryrun]' if args.dryrun else ''
    # glob() order depends on the filesystem; sort it so that the same
    # --random_state always assigns the same files to the same split.
    for jpg in sorted(input_dir.glob("*.jpg")):
        # One draw per sample, so a jpg and its related files land together.
        if random.randint(1, RAND_MAX) <= RAND_MAX * args.test_size:
            dest_dir = output_test_dir
        else:
            dest_dir = output_train_dir
        # Move every related file (jpg, xml, txt) of this sample.
        for src_path in related_paths(jpg):
            dest = dest_dir / src_path.name
            logger.info(f"{label}move {src_path} {dest}")
            if not args.dryrun:
                move(str(src_path), str(dest))


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment