Created
April 25, 2023 07:26
-
-
Save manas-raj-shrestha/ad6760d5358c13bef6be5be470355484 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Split a labeled image dataset into Train, Validation, and Test folders.

Author: Evan Juras, EJ Technology Consultants
Date: 4/10/21

Randomly splits the images found under ``image_path`` into 80% train,
10% validation, and 10% test. Each image is moved, together with the
``.xml`` annotation file that shares its stem and folder, into the
respective output folder.

This script is intended to be used in the TFLite Object Detection Colab
notebook here:
https://colab.research.google.com/github/EdjeElectronics/TensorFlow-Lite-Object-Detection-on-Android-and-Raspberry-Pi/blob/master/Train_TFLite2_Object_Detction_Model.ipynb
"""
import glob
import os
import random
import shutil
from pathlib import Path

# Define paths to image folders
image_path = '/content/images/all'
train_path = '/content/images/train'
val_path = '/content/images/validation'
test_path = '/content/images/test'

# Fractions of the images that go to train and validation. The test folder
# receives everything that is left over (10% with these defaults).
train_percent = 0.8
val_percent = 0.1

# Image file extensions to collect. Suffixes are compared case-insensitively,
# so ".JPG" and ".jpg" are both found without listing every spelling.
IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp'})


def find_images(root):
    """Return a sorted list of image files found recursively under *root*.

    A file qualifies when its lower-cased suffix is in ``IMAGE_EXTENSIONS``.
    Each file is listed exactly once, even on case-insensitive filesystems
    where separate ``*.jpg`` / ``*.JPG`` globs would match the same file
    twice. Sorting makes a seeded split reproducible.
    """
    return sorted(
        path for path in Path(root).rglob('*')
        if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
    )


def split_counts(file_num, train_percent=0.8, val_percent=0.1):
    """Return ``(train_num, val_num, test_num)`` for *file_num* images.

    The train and validation counts are rounded down. The test count is the
    remainder, so the three counts always sum to *file_num*.
    """
    train_num = int(file_num * train_percent)
    val_num = int(file_num * val_percent)
    test_num = file_num - train_num - val_num
    return train_num, val_num, test_num


def _move_with_annotation(image, dest_dir):
    """Move *image* and its same-stem ``.xml`` annotation into *dest_dir*.

    If the annotation does not exist, the image is still moved and a warning
    is printed. Raising at that point would leave the dataset half-split,
    because the image has already been moved.
    """
    shutil.move(str(image), str(dest_dir / image.name))
    xml_file = image.with_suffix('.xml')
    if xml_file.exists():
        shutil.move(str(xml_file), str(dest_dir / xml_file.name))
    else:
        print('Warning: no annotation file found for %s' % image)


def split_dataset(image_path, train_path, val_path, test_path,
                  train_percent=0.8, val_percent=0.1, seed=None):
    """Randomly move images (and their annotations) into train/val/test.

    Args:
        image_path: Folder searched recursively for images.
        train_path, val_path, test_path: Destination folders. They are
            created if they do not exist. Images are placed flat in each
            destination, so files with the same name in different subfolders
            of *image_path* would overwrite one another.
        train_percent: Fraction of images moved to *train_path*.
        val_percent: Fraction of images moved to *val_path*. The remainder
            goes to *test_path*.
        seed: Optional seed for a reproducible split. When ``None``, the
            global ``random`` module state is used, so a ``random.seed()``
            call made beforehand is honoured.

    Returns:
        The tuple ``(train_num, val_num, test_num)`` of images moved.
    """
    file_list = find_images(image_path)
    file_num = len(file_list)
    print('Total images: %d' % file_num)

    train_num, val_num, test_num = split_counts(
        file_num, train_percent, val_percent)
    print('Images moving to train: %d' % train_num)
    print('Images moving to validation: %d' % val_num)
    print('Images moving to test: %d' % test_num)

    # One shuffle followed by slicing gives a uniform random split in O(n).
    # It replaces repeated random.choice() + list.remove(), which is O(n^2).
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(file_list)

    val_end = train_num + val_num
    splits = (
        (train_path, file_list[:train_num]),
        (val_path, file_list[train_num:val_end]),
        (test_path, file_list[val_end:]),
    )
    for dest, images in splits:
        dest_dir = Path(dest)
        dest_dir.mkdir(parents=True, exist_ok=True)
        for image in images:
            _move_with_annotation(image, dest_dir)

    return train_num, val_num, test_num


def main():
    """Split the dataset using the folder paths and fractions configured above."""
    split_dataset(image_path, train_path, val_path, test_path,
                  train_percent=train_percent, val_percent=val_percent)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment