Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
Script to split image data into train/validation/test directories for the Keras `flow_from_directory` method
"""Build keras directory structure for flow_from_dir method"""
import os
import shutil
import numpy as np
from glob import glob
# Input: one sub-directory of images per class under DATA_DIR.
DATA_DIR = 'data/clean'
FILE_EXT_PTN = '*.jpg'  # glob pattern for image files
# Output: root of the train/validation/test directory tree.
TARGET_DIR = 'data/shrine_temple2'
classes = ('shrine', 'temple')
seed = 0
# Fraction of each class's images set aside for test / validation.
ratio_test, ratio_valid = 0.2, 0.2
ratio_args = dict(ratio_test=ratio_test, ratio_valid=ratio_valid)
def split_img_data_paths(img_data_dir, ratio_test=0.2, ratio_valid=0.0,
                         file_ext_ptn='*.jpg', seed=0, verbose=False):
    """Randomly split the image files of one directory into train/test(/valid).

    Files matching ``file_ext_ptn`` directly inside ``img_data_dir`` (not
    recursively) are shuffled with a generator seeded by ``seed`` and split
    into disjoint groups: ``int(ratio_test * n)`` test paths,
    ``int(ratio_valid * n)`` validation paths, and the remainder for training.

    Args:
        img_data_dir: Directory containing the image files.
        ratio_test: Fraction of files for the test split; must be >= 0.
        ratio_valid: Fraction of files for the validation split; must be >= 0.
            ``ratio_test + ratio_valid`` must be < 1.
        file_ext_ptn: Glob pattern the file names must match.
        seed: Seed for ``numpy.random.RandomState``. The same seed and the same
            set of files always give the same split. ``None`` draws a fresh,
            non-reproducible seed.
        verbose: If true, print the directory and the size of every split.

    Returns:
        ``(train, test, valid)`` lists of paths when ``ratio_valid > 0``,
        otherwise ``(train, test)``.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to >= 1.
    """
    # Explicit raises instead of assert: asserts vanish under ``python -O``.
    if ratio_test < 0. or ratio_valid < 0.:
        raise ValueError('ratio_test and ratio_valid must be non-negative')
    if ratio_test + ratio_valid >= 1.:
        raise ValueError('ratio_test + ratio_valid must be < 1')

    # Sort so the split depends only on the seed, not on the
    # filesystem-dependent order in which glob lists the files.
    paths_org = sorted(glob(os.path.join(img_data_dir, file_ext_ptn)))
    n_total = len(paths_org)
    n_test = int(ratio_test * n_total)
    n_valid = int(ratio_valid * n_total)

    # One seeded permutation sliced into consecutive chunks gives disjoint,
    # reproducible groups (no global RNG state, no set hash-ordering).
    order = np.random.RandomState(seed).permutation(n_total)
    shuffled = [paths_org[i] for i in order]
    paths_test = shuffled[:n_test]
    paths_valid = shuffled[n_test:n_test + n_valid]
    paths_train = shuffled[n_test + n_valid:]

    if verbose:
        print('Data dir:', img_data_dir)
        print('N total:', len(paths_org))
        print('N train:', len(paths_train))
        print('N valid:', len(paths_valid))
        print('N test:', len(paths_test))

    if ratio_valid > 0.:
        return paths_train, paths_test, paths_valid
    return paths_train, paths_test
def build_data_img_structure(path_dir, classes):
    """Create the train/validation/test directory tree used for Keras.

    Layout created under ``path_dir``::

        train/<class>/        one folder per class
        validation/<class>/   one folder per class
        test/unknown/         single folder for the unlabeled test images

    ``test`` gets only ``unknown`` because move_imgs_along_structure puts all
    test images there, and ``flow_from_directory`` infers one class per
    sub-directory. Empty per-class folders under ``test`` would show up as
    extra classes.

    Directories that already exist are left untouched (``exist_ok=True``),
    so the call can be repeated safely.

    Args:
        path_dir: Root directory of the tree.
        classes: Iterable of class names. It is materialized once, so a
            one-shot iterator works too.
    """
    classes = tuple(classes)
    for data_type in ('train', 'validation', 'test'):
        dir_datatype = os.path.join(path_dir, data_type)
        subdirs = ('unknown',) if data_type == 'test' else classes
        for sub in subdirs:
            os.makedirs(os.path.join(dir_datatype, sub), exist_ok=True)
# Intended to copy (copy=True, via shutil.copy2) or otherwise move
# (shutil.move) each class's split image files into the tree under path_dir.
#
# img_paths_dict maps class name -> {'train': [...], 'validation': [...],
# 'test': [...]} lists of source paths (the shape main() builds).
# 'train'/'validation' files go to <path_dir>/<split>/<class>/; 'test' files
# are meant to go to <path_dir>/test/unknown/.
#
# NOTE(review): this copy of the function has lost lines and will not run as
# shown. Restore it from the original before use:
#   * the signature ends with a trailing comma after ``copy=True``; the
#     remaining parameter(s) are missing (presumably ``add_class_prefix``,
#     which the _get_dst_path call needs -- confirm name and default);
#   * ``proc = shutil.move`` unconditionally overwrites ``proc``, so as
#     written ``copy`` has no effect -- an ``else:`` appears to be missing;
#   * likewise the per-class ``dst_dir`` overwrites the 'unknown' one for the
#     test split -- an ``else:`` appears to be missing;
#   * the final ``proc(...)`` call is unterminated.
def move_imgs_along_structure(path_dir, img_paths_dict, copy=True,
def _get_dst_path(dst_dir, src_path, class_name, add_class_prefix):
# Destination passed to proc: with add_class_prefix, the file path
# "<class_name>_<basename>" inside dst_dir (presumably so same-named files
# from different classes do not clash in test/unknown); otherwise dst_dir
# itself, in which case shutil keeps the source file's basename.
if add_class_prefix:
return os.path.join(dst_dir,
class_name + '_' + os.path.basename(src_path))
return dst_dir
if copy:
proc = shutil.copy2
proc = shutil.move
for class_nm, data_type_dict in img_paths_dict.items():
for data_type in ('train', 'validation', 'test'):
if data_type == 'test':
dst_dir = os.path.join(path_dir, data_type, 'unknown')
dst_dir = os.path.join(path_dir, data_type, class_nm)
for img in data_type_dict[data_type]:
proc(img, _get_dst_path(dst_dir, img, class_nm,
def main():
    """Split each class's images, build the output tree, and copy the files.

    Driven by the module-level settings DATA_DIR, FILE_EXT_PTN, TARGET_DIR,
    classes, seed and ratio_args.
    """
    # Get image paths and split them per class.
    result_all = {}
    for class_nm in classes:
        src_data_dir = os.path.join(DATA_DIR, class_nm)
        splits = split_img_data_paths(src_data_dir, file_ext_ptn=FILE_EXT_PTN,
                                      seed=seed, verbose=True, **ratio_args)
        # The split returns (train, test[, valid]); the validation list is
        # omitted when ratio_valid is 0, so fall back to an empty list.
        result_all[class_nm] = {
            'train': splits[0],
            'test': splits[1],
            'validation': splits[2] if len(splits) > 2 else [],
        }
    print('Split done.')
    # Build data dir
    build_data_img_structure(TARGET_DIR, classes=classes)
    print('Build directories done.')
    # Move (here: copy) files
    move_imgs_along_structure(TARGET_DIR, result_all, copy=True)
    print('Move files done.')
# Run the split/copy pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.