Skip to content

Instantly share code, notes, and snippets.

@ricsi98
Created November 24, 2021 19:34
Show Gist options
  • Save ricsi98/705b478de4ca87135204dbad975de154 to your computer and use it in GitHub Desktop.
Save ricsi98/705b478de4ca87135204dbad975de154 to your computer and use it in GitHub Desktop.
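Splits the whole GTSRB Final_Training dataset into a train subset and a test subset, class by class. Images that share a filename prefix (the part before the first "_") always stay in the same subset. Usage: `split_classes("./GTSRB/Final_Training/Images", "./train", "./test", 0.1)`, where 0.1 is the fraction of filename-prefix groups per class that goes to the test subset.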
import os
import random
import pandas as pd
from shutil import copyfile
class DataSplitter:
    """Randomly split file names into two parts without breaking up groups.

    Files are grouped by the text before the first underscore in their name
    (presumably the track id for GTSRB images -- confirm against the dataset
    description). A whole group always lands in the same part, so closely
    related images cannot end up on both sides of the split.

    Args:
        split_ratio: Fraction of the *groups* (not of the files) assigned to
            the first part. The first part always receives at least one group
            when there is any input.
        rng: Optional object with a ``shuffle(list)`` method, e.g. a
            ``random.Random(seed)`` instance, for reproducible splits. When
            omitted, the module-level ``random`` generator is used.
    """

    def __init__(self, split_ratio, rng=None):
        self.split_ratio = split_ratio
        # Default to the global generator so callers relying on random.seed()
        # keep working unchanged.
        self._rng = random if rng is None else rng

    @staticmethod
    def _group_key(filename):
        """Return the grouping key: everything before the first underscore."""
        return filename.split("_")[0]

    def _shuffle_split(self, files):
        """Shuffle a copy of *files* and cut it in two at the split ratio.

        Returns:
            A ``(first, second)`` tuple of lists. ``first`` holds
            ``max(1, int(split_ratio * len(files)))`` items, ``second`` the
            remainder. The input sequence is left untouched.
        """
        shuffled = list(files)
        self._rng.shuffle(shuffled)
        split_idx = max(1, int(self.split_ratio * len(shuffled)))
        return shuffled[:split_idx], shuffled[split_idx:]

    def shuffle_split_files(self, files):
        """Split *files* into two lists, keeping each group in one list.

        Args:
            files: Iterable of file names.

        Returns:
            A ``(first, second)`` tuple of lists. Each list keeps the relative
            order of *files*; ``first`` contains every file whose group was
            drawn into the ``split_ratio`` share.
        """
        files = list(files)
        # Sort the unique keys: plain set iteration order depends on string
        # hashing, which would make the split differ between runs even when
        # the random generator is seeded.
        groups = sorted({self._group_key(name) for name in files})
        first_groups, _ = self._shuffle_split(groups)
        # Set lookup keeps the partition pass linear in the number of files.
        first_groups = set(first_groups)

        first, second = [], []
        for name in files:
            if self._group_key(name) in first_groups:
                first.append(name)
            else:
                second.append(name)
        return first, second
def split_classes(root_path, train_path, test_path, ratio, num_classes=43):
    """Copy a per-class image dataset into separate train and test trees.

    For every class directory ``<root_path>/<NNNNN>`` (zero-padded, five
    digits) the ``.ppm`` images are split with ``DataSplitter(ratio)``; the
    share selected by ``ratio`` becomes the test part, the rest the train
    part. Images are copied into ``<train_path>/<NNNNN>`` and
    ``<test_path>/<NNNNN>``, and the class annotation file ``GT-<NNNNN>.csv``
    (semicolon-separated, with a ``Filename`` column) is filtered to the
    matching rows and written next to the copied images.

    Args:
        root_path: Directory containing the class sub-directories.
        train_path: Output directory for the train part; must not exist yet.
        test_path: Output directory for the test part; must not exist yet.
        ratio: Split ratio handed to ``DataSplitter``.
        num_classes: Number of class directories to process, starting at
            ``00000``. Defaults to 43.

    Raises:
        FileExistsError: If ``train_path`` or ``test_path`` already exists.
        FileNotFoundError: If a class directory or its CSV file is missing.
    """
    splitter = DataSplitter(ratio)

    # makedirs creates missing parent directories but still refuses an
    # existing target, so a second run cannot mix two different random splits
    # (which would leak test images into the train set).
    os.makedirs(train_path)
    os.makedirs(test_path)

    for class_index in range(num_classes):
        class_name = format(class_index, '05d')
        class_dir = os.path.join(root_path, class_name)

        # os.listdir returns entries in arbitrary order; sort for stable output.
        files = sorted(f for f in os.listdir(class_dir) if f.endswith(".ppm"))
        test, train = splitter.shuffle_split_files(files)

        # The train part is empty when a class has a single filename group;
        # avoid dividing by zero in the progress message.
        observed_ratio = len(test) / len(train) if train else float("inf")
        print(f"Splitting class {class_name} train: {len(train)}, test: {len(test)}, sum: {len(files)}, ratio: {observed_ratio}")

        # Load the annotations first so a missing CSV fails before any copying.
        csv_name = 'GT-' + class_name + '.csv'
        annotations = pd.read_csv(os.path.join(class_dir, csv_name), delimiter=";")

        destinations = (
            (os.path.join(train_path, class_name), train),
            (os.path.join(test_path, class_name), test),
        )
        for class_folder, images in destinations:
            os.mkdir(class_folder)
            for image in images:
                copyfile(os.path.join(class_dir, image), os.path.join(class_folder, image))
            # Keep only the annotation rows of the images copied to this folder.
            subset = annotations[annotations["Filename"].isin(images)]
            subset.to_csv(os.path.join(class_folder, csv_name), index=False, sep=";")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment