Skip to content

Instantly share code, notes, and snippets.

@mattroz
Last active January 25, 2022 10:00
Show Gist options
  • Save mattroz/348fbed734fcaa2779fa4c80e384f3db to your computer and use it in GitHub Desktop.
Save mattroz/348fbed734fcaa2779fa4c80e384f3db to your computer and use it in GitHub Desktop.
Splits the data in the specified directory and moves it into train/ and val/ directories (tested on images)
import logging
import os
import numpy as np
from pathlib import Path
from math import floor
from typing import *
# File suffixes (leading dot included) accepted when collecting files to split.
EXTENSIONS = {'.jpg', '.jpeg', '.png'}

# Configure the root logger once at import time: a bracketed header line with
# timestamp, logger name and level, followed by the message on its own line.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s - %(name)s - %(levelname)s]\n - %(message)s',
)

# Module-level logger used by the split functions to report resulting file counts.
logger = logging.getLogger(__name__)
def trainval_split(path_to_data: Path, path_to_train: Optional[Path] = None, path_to_val: Optional[Path] = None, split_fraction: float = 0.2):
    """Randomly split the files of a directory into train and validation folders.

    Regular files located directly inside ``path_to_data`` whose suffix
    (compared case-insensitively) is listed in ``EXTENSIONS`` are moved with
    ``Path.rename`` -- they are not copied. A random share of them goes to the
    validation directory and every remaining file goes to the training
    directory. The number of files found in each destination is logged at the
    end.

    Args:
        path_to_data: Directory holding the files to split.
        path_to_train: Destination for training files. Defaults to
            ``path_to_data / 'train'``.
        path_to_val: Destination for validation files. Defaults to
            ``path_to_data / 'val'``.
        split_fraction: Share of the files sent to validation, within
            ``[0, 1]``. The resulting count is rounded down.

    Raises:
        ValueError: If ``split_fraction`` is outside ``[0, 1]``.
    """
    # Reject a bad fraction before touching the file system.
    if not 0 <= split_fraction <= 1:
        raise ValueError(f"split_fraction must be within [0, 1], got {split_fraction}")

    path_to_data = Path(path_to_data)
    path_to_train = Path(path_to_data, 'train') if path_to_train is None else Path(path_to_train)
    path_to_val = Path(path_to_data, 'val') if path_to_val is None else Path(path_to_val)

    # Snapshot the candidates before the destination directories are created.
    # Sorting removes the dependency on the arbitrary order of iterdir(), so the
    # split is reproducible when the NumPy global seed is fixed.
    data = sorted(
        elem for elem in path_to_data.iterdir()
        if elem.is_file() and elem.suffix.lower() in EXTENSIONS
    )

    # parents=True lets callers pass custom destinations whose parents do not exist yet.
    path_to_train.mkdir(parents=True, exist_ok=True)
    path_to_val.mkdir(parents=True, exist_ok=True)

    number_of_val_data = floor(len(data) * split_fraction)

    # Draw positions instead of Path objects: NumPy never has to coerce paths
    # into an array, and set membership replaces the quadratic list.remove().
    if number_of_val_data:
        val_indices = set(np.random.choice(len(data), number_of_val_data, replace=False).tolist())
    else:
        val_indices = set()

    for index, elem in enumerate(data):
        destination = path_to_val if index in val_indices else path_to_train
        elem.rename(destination / elem.name)

    logger.info("Train: %d files in %s", len(os.listdir(path_to_train)), path_to_train)
    logger.info("Val: %d files in %s", len(os.listdir(path_to_val)), path_to_val)
def trainval_split_images_with_labels(path_to_data: Path, split_fraction: float = 0.2):
    """Randomly split paired image and label files into train and validation folders.

    Expects ``path_to_data`` to contain an ``images`` and a ``labels``
    directory. Image files (suffix listed in ``EXTENSIONS``, compared
    case-insensitively) and label files (every regular file in ``labels``) are
    sorted and paired by position, then each pair is moved with ``Path.rename``
    into ``train/images`` + ``train/labels`` or ``val/images`` + ``val/labels``
    under ``path_to_data``. The source directories themselves are left in
    place. The number of files found in each destination is logged at the end.

    Args:
        path_to_data: Directory containing the ``images`` and ``labels``
            sub-directories.
        split_fraction: Share of the pairs sent to validation, within
            ``[0, 1]``. The resulting count is rounded down.

    Raises:
        ValueError: If ``split_fraction`` is outside ``[0, 1]``.
        AssertionError: If the ``images`` or ``labels`` directory is missing,
            or if the number of images differs from the number of labels.
    """
    if not 0 <= split_fraction <= 1:
        raise ValueError(f"split_fraction must be within [0, 1], got {split_fraction}")

    # Define paths
    path_to_data = Path(path_to_data)
    path_to_images = path_to_data / 'images'
    path_to_labels = path_to_data / 'labels'
    path_to_train_images = path_to_data / 'train' / 'images'
    path_to_train_labels = path_to_data / 'train' / 'labels'
    path_to_val_images = path_to_data / 'val' / 'images'
    path_to_val_labels = path_to_data / 'val' / 'labels'

    # AssertionError is raised explicitly instead of through `assert`, so the
    # checks still run under `python -O` while the exception type and messages
    # stay the same for existing callers.
    if not path_to_images.is_dir():
        raise AssertionError('Directory with images not found')
    if not path_to_labels.is_dir():
        raise AssertionError('Directory with labels not found')

    # Create general pools of data; sorting both pools is what aligns image i with label i.
    images = sorted(
        elem for elem in path_to_images.iterdir()
        if elem.is_file() and elem.suffix.lower() in EXTENSIONS
    )
    labels = sorted(elem for elem in path_to_labels.iterdir() if elem.is_file())
    if len(images) != len(labels):
        raise AssertionError('Images length doesnt math labels length')

    # Equal counts do not guarantee that every image sits next to its own label.
    # Report stem mismatches, but keep the positional pairing so inputs that
    # were accepted before are still processed.
    mismatched = [(img.name, lbl.name) for img, lbl in zip(images, labels) if img.stem != lbl.stem]
    if mismatched:
        logger.warning(
            "%d image/label pairs have different file stems (first: %s / %s); "
            "pairs are matched by sorted position",
            len(mismatched), mismatched[0][0], mismatched[0][1],
        )

    # Output directories are created only after validation passed, so a failed
    # call leaves no empty train/ and val/ trees behind.
    for directory in (path_to_train_images, path_to_train_labels, path_to_val_images, path_to_val_labels):
        directory.mkdir(parents=True, exist_ok=True)

    number_of_val_data = floor(len(images) * split_fraction)

    # Select val indices randomly from the general pool of data. Skipping the
    # draw when nothing is selected keeps an empty dataset away from NumPy.
    if number_of_val_data:
        val_indices = set(np.random.choice(len(images), number_of_val_data, replace=False).tolist())
    else:
        val_indices = set()

    # Move data to its directories
    for index, (img, lbl) in enumerate(zip(images, labels)):
        if index in val_indices:
            images_dir, labels_dir = path_to_val_images, path_to_val_labels
        else:
            images_dir, labels_dir = path_to_train_images, path_to_train_labels
        img.rename(images_dir / img.name)
        lbl.rename(labels_dir / lbl.name)

    logger.info("Train images: %d files in %s", len(os.listdir(path_to_train_images)), path_to_train_images)
    logger.info("Train labels: %d files in %s", len(os.listdir(path_to_train_labels)), path_to_train_labels)
    logger.info("Val images: %d files in %s", len(os.listdir(path_to_val_images)), path_to_val_images)
    logger.info("Val labels: %d files in %s", len(os.listdir(path_to_val_labels)), path_to_val_labels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment