Skip to content

Instantly share code, notes, and snippets.

@tamanobi
Created March 7, 2019 16:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tamanobi/b059a8060e73ece9e6d47936e126c4c4 to your computer and use it in GitHub Desktop.
Save tamanobi/b059a8060e73ece9e6d47936e126c4c4 to your computer and use it in GitHub Desktop.
ディレクトリ内の各jpgを、同名のxml・txtとまとめて、指定した割合でtrain/testディレクトリに振り分けて移動する(inspired by https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html)
# coding:utf-8
from pathlib import Path
import sys
import argparse
from shutil import move
import random
from logging import getLogger, DEBUG, INFO, StreamHandler
# Script-wide logger: records at DEBUG and above are echoed to stdout.
logger = getLogger(__name__)
logger.addHandler(StreamHandler(sys.stdout))
logger.setLevel(DEBUG)
def related_paths(sample_path, suffixes=('.jpg', '.xml', '.txt')):
    """Return the existing files that share *sample_path*'s stem.

    Args:
        sample_path: A ``pathlib.Path`` to one file of a sample (e.g. ``x.jpg``).
        suffixes: Extensions to look for, in order. Defaults to
            ``.jpg``, ``.xml`` and ``.txt``.

    Returns:
        A list of the paths ``sample_path.with_suffix(s)`` (for each ``s`` in
        ``suffixes``) that exist on disk, in ``suffixes`` order. The list is
        built eagerly so every existence check happens before the caller
        starts moving files, and the result can be iterated more than once.
    """
    candidates = (sample_path.with_suffix(suffix) for suffix in suffixes)
    return [path for path in candidates if path.exists()]
def _parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. Exits with status 2 (via
        ``parser.error``) when ``--input_dir`` is missing or ``--test_size``
        is outside 0.0-1.0.
    """
    parser = argparse.ArgumentParser(
        description='Move each *.jpg in a directory, together with its '
                    'same-named .xml/.txt files, into a train or test directory.')
    parser.add_argument('--input_dir', type=str, required=True, help='input directory')
    parser.add_argument('--output_test_dir', type=str, default='test', help='output test directory')
    parser.add_argument('--output_train_dir', type=str, default='train', help='output train directory')
    parser.add_argument('--test_size', type=float, default=0.25,
                        help='fraction of samples moved to the test directory (0.0-1.0)')
    parser.add_argument('--random_state', type=int, default=0, help='random seed for the split')
    parser.add_argument('--dryrun', action="store_true", default=False, help='dryrun')
    args = parser.parse_args(argv)
    # The comparison is also False for NaN, so NaN is rejected here too.
    if not 0.0 <= args.test_size <= 1.0:
        parser.error('--test_size must be between 0.0 and 1.0')
    return args


def main(argv=None):
    """Split the contents of one directory into training and test sets.

    Every ``*.jpg`` in ``--input_dir`` is sent to the test directory with
    probability of roughly ``--test_size`` (otherwise to the train directory)
    and is moved there together with the files returned by ``related_paths``.
    With ``--dryrun`` the moves are only logged.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if any directory is missing.
    """
    args = _parse_args(argv)
    random.seed(args.random_state)

    input_dir = Path(args.input_dir)
    output_test_dir = Path(args.output_test_dir)
    output_train_dir = Path(args.output_train_dir)
    missing = [str(d) for d in (input_dir, output_test_dir, output_train_dir)
               if not d.is_dir()]
    if missing:
        logger.error('some directories are not found: %s', ', '.join(missing))
        return 1

    # Each sample is drawn independently, so the realised test fraction only
    # approximates test_size (and the threshold is quantised to 1 / rand_max).
    rand_max = 1000
    label = '[dryrun]' if args.dryrun else ''
    # Sorting makes the split depend only on the seed and the file names, not
    # on the order in which the filesystem lists entries; it also snapshots
    # the listing before any file is moved out of input_dir.
    for jpg in sorted(input_dir.glob('*.jpg')):
        if random.randint(1, rand_max) <= rand_max * args.test_size:
            dest_dir = output_test_dir
        else:
            dest_dir = output_train_dir
        # Move all related files to the same place so a sample is never split.
        for src_path in related_paths(jpg):
            dest_path = dest_dir / src_path.name
            logger.info('%smove %s %s', label, src_path, dest_path)
            if not args.dryrun:
                # NOTE: an existing file at dest_path may be overwritten
                # (shutil.move falls back on os.rename semantics).
                move(str(src_path), str(dest_path))
    return 0


if __name__ == '__main__':
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment