Last active
February 16, 2018 03:26
-
-
Save messefor/528b500178e1341f906dcee74ab3b87e to your computer and use it in GitHub Desktop.
Script to split image data into train/validation/test directories for the Keras flow_from_directory method
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Build keras directory structure for flow_from_dir method""" | |
import os | |
import shutil | |
import numpy as np | |
from glob import glob | |
# Root directory of the source images; main() reads one sub-directory per
# class name below it.
DATA_DIR = 'data/clean'

# Glob pattern for image files (same value as the default `file_ext_ptn`
# of split_img_data_paths).
FILE_EXT_PTN = '*.jpg'

# Root directory under which the train/validation/test tree is created.
TARGET_DIR = 'data/shrine_temple2'

# Class names: used as sub-directory names below DATA_DIR and below the
# 'train' and 'validation' directories of TARGET_DIR.
classes = ('shrine', 'temple')

# Random seed for the split (same value as the default `seed` of
# split_img_data_paths).
seed = 0

# Fractions of each class's images assigned to the test and validation sets.
ratio_test = 0.2
ratio_valid = 0.2

# Keyword arguments that main() forwards to split_img_data_paths().
ratio_args = {'ratio_test': ratio_test, 'ratio_valid': ratio_valid}
def split_img_data_paths(img_data_dir, ratio_test=0.2, ratio_valid=0.0,
                         file_ext_ptn='*.jpg', seed=0, verbose=False):
    """Randomly split the image files of one directory into data sets.

    The files matching ``file_ext_ptn`` directly inside ``img_data_dir`` are
    partitioned into disjoint train / test / validation subsets.  The number
    of test and validation files is ``int(ratio * n_total)``; everything
    left over becomes training data.

    Args:
        img_data_dir: Directory that contains the image files.
        ratio_test: Fraction of the files put into the test set (>= 0).
        ratio_valid: Fraction of the files put into the validation set
            (>= 0).  ``ratio_test + ratio_valid`` must be below 1.
        file_ext_ptn: Glob pattern selecting the files in ``img_data_dir``.
        seed: Seed of the random generator used for the split.
        verbose: If true, print the directory and the size of each subset.

    Returns:
        ``(paths_train, paths_test, paths_valid)`` if ``ratio_valid > 0``,
        otherwise ``(paths_train, paths_test)``.  Every element is a list
        of plain ``str`` paths.

    Raises:
        AssertionError: If a ratio is negative or the ratios sum to 1 or more.
    """
    assert ratio_test >= 0. and ratio_valid >= 0., \
        'ratio_test and ratio_valid must be non-negative'
    assert (ratio_test + ratio_valid) < 1., \
        'ratio_test + ratio_valid must be smaller than 1'

    # A private generator keeps numpy's global random state untouched while
    # producing the same stream a freshly seeded global generator would.
    rng = np.random.RandomState(seed)

    def _draw(pool, k):
        """Pick ``k`` distinct elements of the list ``pool`` at random."""
        if k == 0:
            # Also avoids sampling from an empty population.
            return []
        picked = rng.choice(len(pool), size=k, replace=False)
        # Indexing the original list keeps the elements plain ``str``.
        return [pool[i] for i in picked]

    # glob() order depends on the file system; sorting makes the split a
    # function of the file names and the seed only.
    paths_org = sorted(glob(os.path.join(img_data_dir, file_ext_ptn)))
    n_total = len(paths_org)
    n_test = int(ratio_test * n_total)
    n_valid = int(ratio_valid * n_total)

    paths_test = _draw(paths_org, n_test)
    chosen_test = set(paths_test)
    # Ordered list instead of a set: set iteration order of strings varies
    # between interpreter runs and would make the next draw irreproducible.
    paths_not_test = [p for p in paths_org if p not in chosen_test]

    paths_valid = _draw(paths_not_test, n_valid)
    chosen_valid = set(paths_valid)
    paths_train = [p for p in paths_not_test if p not in chosen_valid]

    if verbose:
        print('Data dir:', img_data_dir)
        print('N total:', len(paths_org))
        print('N train:', len(paths_train))
        print('N valid:', len(paths_valid))
        print('N test:', len(paths_test))

    # Sanity check: the three subsets partition the input.
    assert len(paths_org) == \
        len(paths_train) + len(paths_valid) + len(paths_test)

    if ratio_valid > 0.:
        return paths_train, paths_test, paths_valid
    return paths_train, paths_test
def build_data_img_structure(path_dir, classes):
    """Create the train / validation / test directory tree.

    Below ``path_dir`` the directories ``train`` and ``validation`` get one
    sub-directory per entry of ``classes``; ``test`` gets a single
    ``unknown`` sub-directory.  Directories that already exist are left
    untouched.

    Args:
        path_dir: Root directory of the tree; created if missing.
        classes: Iterable of class names used as sub-directory names.
    """
    os.makedirs(path_dir, exist_ok=True)
    for data_type in ('train', 'validation', 'test'):
        dir_datatype = os.path.join(path_dir, data_type)
        os.makedirs(dir_datatype, exist_ok=True)
        # Test images are not separated by class; they share one folder.
        sub_dirs = ('unknown',) if data_type == 'test' else classes
        for sub_dir in sub_dirs:
            os.makedirs(os.path.join(dir_datatype, sub_dir), exist_ok=True)
def move_imgs_along_structure(path_dir, img_paths_dict, copy=True,
                              add_class_prefix=True):
    """Copy or move image files into the train / validation / test tree.

    Train and validation images go to ``<path_dir>/<data_type>/<class>``;
    test images of every class go to the shared ``<path_dir>/test/unknown``
    directory.

    Args:
        path_dir: Root directory of the target tree.
        img_paths_dict: Mapping ``class name -> {data type -> image paths}``
            with the data types ``'train'``, ``'validation'`` and ``'test'``.
            A data type that is missing is treated as having no images.
        copy: If true the files are copied with ``shutil.copy2``, otherwise
            they are moved with ``shutil.move``.
        add_class_prefix: If true the target file name is
            ``<class>_<original name>``.  If false the original file name is
            kept, so equally named test images of different classes end up
            on the same path inside ``test/unknown``.
    """
    transfer = shutil.copy2 if copy else shutil.move

    for class_nm, data_type_dict in img_paths_dict.items():
        for data_type in ('train', 'validation', 'test'):
            # A split without validation data simply has no such key.
            img_paths = list(data_type_dict.get(data_type, ()))
            if not img_paths:
                continue

            sub_dir = 'unknown' if data_type == 'test' else class_nm
            dst_dir = os.path.join(path_dir, data_type, sub_dir)
            # Without an existing directory, copy2/move with a bare
            # directory path would create a *file* of that name instead.
            os.makedirs(dst_dir, exist_ok=True)

            for img in img_paths:
                if add_class_prefix:
                    dst = os.path.join(
                        dst_dir, class_nm + '_' + os.path.basename(img))
                else:
                    # shutil keeps the source file name inside dst_dir.
                    dst = dst_dir
                transfer(img, dst)
def main():
    """Split every class's images and copy them into the target tree.

    For each entry of ``classes`` the images in ``DATA_DIR/<class>`` are
    split into train / test / validation sets, the directory tree is built
    below ``TARGET_DIR`` and the images are copied into it.
    """
    # Forward the module-level settings so that editing them takes effect;
    # entries of ratio_args win if they name the same argument.
    split_args = {'file_ext_ptn': FILE_EXT_PTN, 'seed': seed}
    split_args.update(ratio_args)

    # Get image path
    result_all = {}
    for class_nm in classes:
        src_data_dir = os.path.join(DATA_DIR, class_nm)
        split = split_img_data_paths(src_data_dir, verbose=True, **split_args)
        # The splitter returns only (train, test) when no validation
        # fraction is requested.
        result_all[class_nm] = {
            'train': split[0],
            'test': split[1],
            'validation': split[2] if len(split) > 2 else [],
        }
    print('Split done.')

    # Build data dir
    build_data_img_structure(TARGET_DIR, classes=classes)
    print('Build directories done.')

    # Move files
    move_imgs_along_structure(TARGET_DIR, result_all, copy=True)
    print('Move files done.')


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment