Created
November 24, 2021 19:34
-
-
Save ricsi98/705b478de4ca87135204dbad975de154 to your computer and use it in GitHub Desktop.
Splits the whole GTSRB Final_Training dataset into a train part and a test part. Usage: `split_classes("./GTSRB/Final_Training/Images", "./train", "./test", 0.1)`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import random | |
import pandas as pd | |
from shutil import copyfile | |
class DataSplitter:
    """Randomly split file names into two parts without breaking up groups.

    Files are grouped by the text before the first underscore in their name
    (for GTSRB images such as ``00003_00017.ppm`` this is presumably the
    track id -- confirm against the dataset docs).  All files of one group
    always end up in the same part.
    """

    def __init__(self, split_ratio):
        """Store *split_ratio*, the fraction of groups put in the first part."""
        self.split_ratio = split_ratio

    def _shuffle_split(self, files):
        """Shuffle a copy of *files* and cut it into two lists.

        The first list holds ``int(split_ratio * len(files))`` items, but
        never fewer than one when *files* is non-empty; the second list
        holds the rest.  The input sequence is not modified.
        """
        shuffled = list(files)
        random.shuffle(shuffled)
        split_idx = max(1, int(self.split_ratio * len(shuffled)))
        return shuffled[:split_idx], shuffled[split_idx:]

    def shuffle_split_files(self, files):
        """Split *files* into two lists, keeping each group in a single part.

        Returns ``(A, B)`` where ``A`` contains the files of the groups
        selected for the first part and ``B`` the remaining files.  Both
        lists keep the relative order of *files*.
        """
        files = list(files)
        # Sort the group keys: iterating a set of strings yields an order
        # that changes between interpreter runs (hash randomization), which
        # would make the split irreproducible even after random.seed().
        group_keys = sorted({f.split("_")[0] for f in files})
        first_keys, second_keys = self._shuffle_split(group_keys)
        # Sets give O(1) membership tests instead of scanning a list per file.
        first_keys = set(first_keys)
        second_keys = set(second_keys)
        A = [f for f in files if f.split("_")[0] in first_keys]
        B = [f for f in files if f.split("_")[0] in second_keys]
        return A, B
def _copy_images(names, src_dir, dst_dir):
    """Copy every file listed in *names* from *src_dir* into *dst_dir*."""
    for name in names:
        copyfile(os.path.join(src_dir, name), os.path.join(dst_dir, name))


def split_classes(root_path, train_path, test_path, ratio, num_classes=43):
    """Split per-class image folders into a train tree and a test tree.

    For every class index ``0 .. num_classes - 1`` the folder
    ``<root_path>/<index as 5 digits>`` is read, its ``.ppm`` files are split
    with ``DataSplitter(ratio)``, and the images are copied into matching
    class folders below *train_path* and *test_path*.  The class annotation
    file ``GT-<class>.csv`` (semicolon separated, with a ``Filename``
    column) is filtered accordingly and written next to the copied images.

    Args:
        root_path: Directory containing one sub-folder per class.
        train_path: Output directory for the train part; must not exist yet.
        test_path: Output directory for the test part; must not exist yet.
        ratio: Fraction passed to ``DataSplitter``; the first part it
            returns is used as the test set.
        num_classes: Number of class folders to process (default 43).

    Raises:
        FileExistsError: If *train_path* or *test_path* already exists.
        FileNotFoundError: If a class folder or its CSV file is missing.
    """
    splitter = DataSplitter(ratio)
    # No exist_ok: writing into a tree left over from an earlier run would
    # silently mix two different random splits.
    os.makedirs(train_path)
    os.makedirs(test_path)

    for c in range(num_classes):
        class_name = format(c, '05d')
        class_dir = os.path.join(root_path, class_name)
        files = [f for f in os.listdir(class_dir) if f.endswith(".ppm")]
        test, train = splitter.shuffle_split_files(files)
        # A class with a single group leaves the train part empty; report
        # nan instead of crashing on a division by zero.
        observed = len(test) / len(train) if train else float("nan")
        print(f"Splitting class {class_name} train: {len(train)}, test: {len(test)}, sum: {len(files)}, ratio: {observed}")

        # create train/test directories
        train_class_folder = os.path.join(train_path, class_name)
        test_class_folder = os.path.join(test_path, class_name)
        os.mkdir(train_class_folder)
        os.mkdir(test_class_folder)

        _copy_images(train, class_dir, train_class_folder)
        _copy_images(test, class_dir, test_class_folder)

        # copy corresponding .csv information
        csv_name = 'GT-' + class_name + '.csv'
        df = pd.read_csv(os.path.join(class_dir, csv_name), delimiter=";")
        df_train = df[df["Filename"].isin(train)]
        df_test = df[df["Filename"].isin(test)]
        df_train.to_csv(os.path.join(train_class_folder, csv_name), index=False, sep=";")
        df_test.to_csv(os.path.join(test_class_folder, csv_name), index=False, sep=";")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment