Last active
January 25, 2022 10:00
-
-
Save mattroz/348fbed734fcaa2779fa4c80e384f3db to your computer and use it in GitHub Desktop.
Splits the data in the specified directory and moves it into train/ and val/ subdirectories (tested on images)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import os | |
import numpy as np | |
from pathlib import Path | |
from math import floor | |
from typing import * | |
# File suffixes (including the leading dot) that the split functions accept as images.
EXTENSIONS = {'.jpg', '.jpeg', '.png'}

# Configure root logging at import time: INFO level, with the message printed on its
# own line under a "[time - logger - level]" header. logging.basicConfig is a no-op
# if the root logger already has handlers.
_LOG_FORMAT = '[%(asctime)s - %(name)s - %(levelname)s]\n - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

logger = logging.getLogger(__name__)
def trainval_split(path_to_data: Union[str, Path], path_to_train: Optional[Path] = None,
                   path_to_val: Optional[Path] = None, split_fraction: float = 0.2,
                   seed: Optional[int] = None):
    """Randomly move the image files in ``path_to_data`` into train and val directories.

    Only regular files directly inside ``path_to_data`` whose suffix (compared
    case-insensitively) is in ``EXTENSIONS`` are moved. Everything else is left in place.

    Args:
        path_to_data: Directory holding the files to split (``str`` or ``Path``).
        path_to_train: Destination for training files. Defaults to
            ``path_to_data / 'train'``. Created (with parents) if missing.
        path_to_val: Destination for validation files. Defaults to
            ``path_to_data / 'val'``. Created (with parents) if missing.
        split_fraction: Fraction of files, rounded down, moved to ``path_to_val``.
            Must lie in [0, 1].
        seed: If given, a private ``np.random.RandomState(seed)`` drives the sampling.
            If ``None``, the global NumPy random state is used, so calling
            ``np.random.seed(...)`` beforehand still makes the split reproducible.

    Raises:
        ValueError: If ``split_fraction`` is outside [0, 1].
    """
    import shutil  # local import: the module header does not import shutil

    if not 0.0 <= split_fraction <= 1.0:
        raise ValueError(f"split_fraction must be in [0, 1], got {split_fraction}")

    path_to_data = Path(path_to_data)
    path_to_train = Path(path_to_data, 'train') if path_to_train is None else Path(path_to_train)
    path_to_val = Path(path_to_data, 'val') if path_to_val is None else Path(path_to_val)

    # Collect candidates before creating the destination directories. Sort them so the
    # split depends only on the random state, not on the filesystem's listing order.
    data = sorted(
        elem for elem in path_to_data.iterdir()
        if elem.is_file() and elem.suffix.lower() in EXTENSIONS
    )

    path_to_train.mkdir(parents=True, exist_ok=True)
    path_to_val.mkdir(parents=True, exist_ok=True)

    number_of_val_data = floor(len(data) * split_fraction)
    rng = np.random if seed is None else np.random.RandomState(seed)
    # Sample positions and partition through a set. This is O(n) overall, whereas the
    # previous per-element list.remove was O(n^2).
    val_positions = (
        set(rng.choice(len(data), number_of_val_data, replace=False).tolist())
        if number_of_val_data > 0 else set()
    )

    for position, elem in enumerate(data):
        destination = path_to_val if position in val_positions else path_to_train
        # shutil.move falls back to copy+delete when the destination is on another
        # filesystem, where Path.rename would raise OSError.
        shutil.move(str(elem), str(Path(destination, elem.name)))

    # Counts cover every entry in the destination directories, including any that
    # existed before this call.
    logger.info(f"Train: {len(os.listdir(path_to_train))} files in {path_to_train}")
    logger.info(f"Val: {len(os.listdir(path_to_val))} files in {path_to_val}")
def trainval_split_images_with_labels(path_to_data: Union[str, Path], split_fraction: float = 0.2,
                                      seed: Optional[int] = None):
    """Randomly split paired ``images/`` and ``labels/`` directories into train and val.

    ``path_to_data`` must contain an ``images`` and a ``labels`` directory. The pools are:

    * images: regular files whose suffix (case-insensitive) is in ``EXTENSIONS``;
    * labels: regular, non-hidden files.

    Both pools are sorted by name and paired by position, so the i-th image goes with
    the i-th label. Each pair is moved together into ``path_to_data/train/{images,labels}``
    or ``path_to_data/val/{images,labels}``.

    Args:
        path_to_data: Root directory holding ``images`` and ``labels`` (``str`` or ``Path``).
        split_fraction: Fraction of pairs, rounded down, moved to ``val``. Must lie in [0, 1].
        seed: If given, a private ``np.random.RandomState(seed)`` drives the sampling.
            If ``None``, the global NumPy random state is used.

    Raises:
        ValueError: If ``split_fraction`` is outside [0, 1], or if the image and label
            counts differ.
        FileNotFoundError: If ``images`` or ``labels`` is not a directory under
            ``path_to_data``.
    """
    if not 0.0 <= split_fraction <= 1.0:
        raise ValueError(f"split_fraction must be in [0, 1], got {split_fraction}")

    path_to_data = Path(path_to_data)
    path_to_images = Path(path_to_data, 'images')
    path_to_labels = Path(path_to_data, 'labels')
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if not path_to_images.is_dir():
        raise FileNotFoundError('Directory with images not found')
    if not path_to_labels.is_dir():
        raise FileNotFoundError('Directory with labels not found')

    # Create the general pools of data. Sorting is essential: pairing is purely positional.
    images = sorted(
        elem for elem in path_to_images.iterdir()
        if elem.is_file() and elem.suffix.lower() in EXTENSIONS
    )
    # Skip hidden files (e.g. `.DS_Store`); they are not labels and would shift the pairing.
    labels = sorted(
        elem for elem in path_to_labels.iterdir()
        if elem.is_file() and not elem.name.startswith('.')
    )
    if len(images) != len(labels):
        raise ValueError(
            f"Images length ({len(images)}) doesn't match labels length ({len(labels)})"
        )

    # Positional pairing silently mis-assigns data when names don't line up; surface it.
    mismatched = [(img.name, lbl.name) for img, lbl in zip(images, labels) if img.stem != lbl.stem]
    if mismatched:
        logger.warning(
            f"{len(mismatched)} image/label pairs have different stems "
            f"(first: {mismatched[0]}); pairing is by sorted position"
        )

    number_of_val_data = floor(len(images) * split_fraction)
    rng = np.random if seed is None else np.random.RandomState(seed)
    # Select val positions randomly from the general pool of data.
    val_indices = (
        set(rng.choice(len(images), number_of_val_data, replace=False).tolist())
        if number_of_val_data > 0 else set()
    )

    # Destination directories are created only after validation passes, so a failed
    # call leaves nothing behind.
    path_to_train_images = Path(path_to_data, 'train', 'images')
    path_to_train_labels = Path(path_to_data, 'train', 'labels')
    path_to_val_images = Path(path_to_data, 'val', 'images')
    path_to_val_labels = Path(path_to_data, 'val', 'labels')
    for directory in (path_to_train_images, path_to_train_labels,
                      path_to_val_images, path_to_val_labels):
        directory.mkdir(parents=True, exist_ok=True)

    # Move each image together with its label so a pair never straddles the split.
    for index, (img, lbl) in enumerate(zip(images, labels)):
        if index in val_indices:
            img_dir, lbl_dir = path_to_val_images, path_to_val_labels
        else:
            img_dir, lbl_dir = path_to_train_images, path_to_train_labels
        img.rename(Path(img_dir, img.name))
        lbl.rename(Path(lbl_dir, lbl.name))

    logger.info(f"Train images: {len(os.listdir(path_to_train_images))} files in {path_to_train_images}")
    logger.info(f"Train labels: {len(os.listdir(path_to_train_labels))} files in {path_to_train_labels}")
    logger.info(f"Val images: {len(os.listdir(path_to_val_images))} files in {path_to_val_images}")
    logger.info(f"Val labels: {len(os.listdir(path_to_val_labels))} files in {path_to_val_labels}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment