Skip to content

Instantly share code, notes, and snippets.

@beatobongco
Created September 16, 2018 14:00
Show Gist options
  • Save beatobongco/e66dde2568bafb68d25b3712753a09e4 to your computer and use it in GitHub Desktop.
Save beatobongco/e66dde2568bafb68d25b3712753a09e4 to your computer and use it in GitHub Desktop.
Split datasets
import random
import shutil
from pathlib import Path
from typing import List, Optional
def dataset_splitter(output_dir: str = 'output', classes: Optional[List[str]] = None, num_train=0, num_validation=0):
    """Randomly copy files of each class into train/validation folders.

    For every class name in ``classes``, the regular files in the current
    working directory whose names start with that class name are shuffled.
    The first ``num_train`` of them are copied to
    ``<output_dir>/train/<class>/`` and the next ``num_validation`` to
    ``<output_dir>/validation/<class>/``. The two splits never share a file.
    If a class has fewer files than requested, training is filled first and
    validation receives whatever is left.

    Args:
        output_dir: Name of the output directory. It is created next to
            (as a sibling of) the current working directory.
        classes: Class names, used both as file-name prefixes for matching
            and as sub-directory names. ``None`` means no classes.
        num_train: Maximum number of files per class copied into ``train``.
        num_validation: Maximum number of files per class copied into
            ``validation``.

    Raises:
        ValueError: If ``num_train`` or ``num_validation`` is negative.

    TODO: add support for classes in folders already to begin with
    """
    if num_train < 0 or num_validation < 0:
        raise ValueError('num_train and num_validation must be non-negative')

    class_names = list(classes) if classes is not None else []

    in_path = Path.cwd()
    # The output directory lives next to the input directory, not inside it,
    # so copied files are never picked up again by the glob below.
    out_path = in_path.parent / output_dir
    out_path.mkdir(parents=True, exist_ok=True)

    for cls in class_names:
        # Only regular files can be copied with shutil.copy; a directory
        # whose name happens to start with the class name is skipped.
        # Note: the whole listing for a class is held in memory to shuffle it.
        images = [p for p in in_path.glob('{0}*'.format(cls)) if p.is_file()]
        random.shuffle(images)

        # Slice disjoint ranges of the shuffled list so no image is shared
        # between splits and none is consumed without being copied.
        splits = {
            'train': images[:num_train],
            'validation': images[num_train:num_train + num_validation],
        }
        for dataset_type, selected in splits.items():
            class_directory = out_path / dataset_type / cls
            class_directory.mkdir(parents=True, exist_ok=True)
            for index, image_path in enumerate(selected):
                print(index, str(class_directory))
                shutil.copy(str(image_path), str(class_directory))
# Expected layout of the output directory produced by the call below:
#
#   train/
#       cat/
#           ...1k random cat images
#       dog/
#           ...
#   validation/
#       cat/
#           ...400 random cat images
#       dog/
#           ...
dataset_splitter(
    output_dir='cats_vs_dogs_1k',
    classes=['cat', 'dog'],
    num_train=1000,
    num_validation=400,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment