Skip to content

Instantly share code, notes, and snippets.

@manas-raj-shrestha
Created April 25, 2023 07:26
Show Gist options
  • Save manas-raj-shrestha/ad6760d5358c13bef6be5be470355484 to your computer and use it in GitHub Desktop.
Save manas-raj-shrestha/ad6760d5358c13bef6be5be470355484 to your computer and use it in GitHub Desktop.
### Python script to split a labeled image dataset into Train, Validation, and Test folders.
# Author: Evan Juras, EJ Technology Consultants
# Date: 4/10/21
# Randomly splits images to 80% train, 10% validation, and 10% test, and moves them to their respective folders.
# This script is intended to be used in the TFLite Object Detection Colab notebook here:
# https://colab.research.google.com/github/EdjeElectronics/TensorFlow-Lite-Object-Detection-on-Android-and-Raspberry-Pi/blob/master/Train_TFLite2_Object_Detction_Model.ipynb
import glob
from pathlib import Path
import random
import os
# Define paths to image folders
image_path = '/content/images/all'
train_path = '/content/images/train'
val_path = '/content/images/validation'
test_path = '/content/images/test'

# Glob patterns for the image types this script collects. Whether '*.jpg' also
# matches 'X.JPG' depends on the platform's filename matching rules, which is
# why find_images() de-duplicates its results.
image_patterns = ('*.jpg', '*.jpeg', '*.JPG', '*.png', '*.bmp')


def find_images(root, patterns=image_patterns):
    """Return every image file found under `root`, searching recursively.

    Args:
        root: Directory to search (string or Path).
        patterns: Glob patterns to match, tried in order.

    Returns:
        A list of Path objects with no duplicates. The list is empty when
        nothing matches any of the patterns.
    """
    found = []
    seen = set()
    for pattern in patterns:
        for path in Path(root).rglob(pattern):
            # The same file can be matched by more than one pattern (e.g.
            # '*.jpg' and '*.JPG' where matching is case-insensitive). A
            # duplicate entry would later be moved twice and fail.
            if path not in seen:
                seen.add(path)
                found.append(path)
    return found


# Get list of all images
file_list = find_images(image_path)
file_num = len(file_list)
print('Total images: %d' % file_num)
# Split ratios: 80% train, 10% validation, 10% test.
# test_percent is not used in the arithmetic below; the test split simply
# receives whatever is left after train and validation are taken.
train_percent, val_percent, test_percent = 0.8, 0.1, 0.1

# int() truncates, so any leftover images from rounding land in the test split.
train_num, val_num = (
    int(file_num * share) for share in (train_percent, val_percent)
)
test_num = file_num - (train_num + val_num)

# Report how many images go to each folder
print('Images moving to train: %d' % train_num)
print('Images moving to validation: %d' % val_num)
print('Images moving to test: %d' % test_num)
def _move_random_files(files, count, dest_dir):
    """Move `count` randomly chosen images, with their annotations, to `dest_dir`.

    Each moved image is removed from `files`, so consecutive calls on the same
    list never pick the same image twice. For every image, a file with the
    same stem and a '.xml' extension in the same folder is moved along with it.

    Args:
        files: List of Path objects to draw from. Shuffled and shortened
            in place.
        count: Number of images to move.
        dest_dir: Destination folder. Created if it does not exist.

    Raises:
        IndexError: If `count` is larger than the number of entries in `files`.
        OSError: If a file cannot be renamed into `dest_dir`.
    """
    os.makedirs(dest_dir, exist_ok=True)
    # Shuffle once, then pop from the end: the same uniform selection without
    # replacement as repeated random.choice() + list.remove(), but each
    # removal is O(1).
    random.shuffle(files)
    for _ in range(count):
        image = files.pop()
        xml_fn = image.stem + '.xml'
        xml_src = os.path.join(image.parent, xml_fn)
        # NOTE: images are collected recursively, so two images with the same
        # file name in different subfolders map to the same destination path.
        os.rename(image, os.path.join(dest_dir, image.name))
        if os.path.exists(xml_src):
            os.rename(xml_src, os.path.join(dest_dir, xml_fn))
        else:
            # Warn rather than abort: stopping here would leave the dataset
            # half-split, with this image already moved.
            print('Warning: no annotation file found for %s' % image)


# Select train_num files randomly and move them to train folder
_move_random_files(file_list, train_num, train_path)
# Select val_num of the remaining files and move them to validation folder
_move_random_files(file_list, val_num, val_path)
# Move remaining files to test folder
_move_random_files(file_list, test_num, test_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment