Skip to content

Instantly share code, notes, and snippets.

@simonamdev
Created July 1, 2023 10:26
Show Gist options
  • Save simonamdev/7df8875642d9ab80b237b61b10de3666 to your computer and use it in GitHub Desktop.
Save simonamdev/7df8875642d9ab80b237b61b10de3666 to your computer and use it in GitHub Desktop.
Python script to generate dataset mixtures
import os
import random
from sklearn.model_selection import train_test_split
# LOCAL FILES
# Writes one text file per (synthetic dataset, percentage) combination. Each
# file lists, one path per line and in this order: the selected synthetic
# images, the selected synthetic labels, the CamVid training images and the
# CamVid training labels.
ROOT_DATASET_DIR = '/home/simon/Desktop/datasets'
print(os.listdir(ROOT_DATASET_DIR))

CAMVID_DIR = os.path.join(ROOT_DATASET_DIR, 'camvid')
SYNTHIA_DIR = os.path.join(ROOT_DATASET_DIR, 'synthia')
PFD_DIR = os.path.join(ROOT_DATASET_DIR, 'pfd')

# Directory that receives the generated mixture listings.
OUTPUT_DIR = './datasets'

# Maps each dataset name to the directories holding its images and labels.
dataset_path_map = {
    'camvid': {
        'train': os.path.join(CAMVID_DIR, 'train'),
        'train_labels': os.path.join(CAMVID_DIR, 'train_labels'),
        'test': os.path.join(CAMVID_DIR, 'test'),
        'test_labels': os.path.join(CAMVID_DIR, 'test_labels'),
        'val': os.path.join(CAMVID_DIR, 'val'),
        'val_labels': os.path.join(CAMVID_DIR, 'val_labels'),
    },
    'pfd': {
        'images': os.path.join(PFD_DIR, 'images'),
        'labels': os.path.join(PFD_DIR, 'labels'),
    },
    'synthia-rand': {
        'images': os.path.join(SYNTHIA_DIR, 'RGB'),
        'labels': os.path.join(SYNTHIA_DIR, 'GT'),
    },
}


def _sorted_file_paths(directory):
    """Return the full path of every entry in *directory*, sorted by path."""
    return sorted(os.path.join(directory, name) for name in os.listdir(directory))


# Report how many entries each configured directory contains.
for dataset, dataset_keys in dataset_path_map.items():
    for dataset_key, directory in dataset_keys.items():
        print(f'Dataset: {dataset} Key: {dataset_key} Amount: {len(os.listdir(directory))}')

random_seed = 12345
random.seed(random_seed)

# os.listdir returns entries in arbitrary order, so sort the CamVid listings:
# this makes the output reproducible and keeps it consistent with the sorted
# synthetic listings below.
camvid_training_images = _sorted_file_paths(dataset_path_map['camvid']['train'])
camvid_training_labels = _sorted_file_paths(dataset_path_map['camvid']['train_labels'])

# The output files are opened for writing below, which fails if the directory
# does not exist yet.
os.makedirs(OUTPUT_DIR, exist_ok=True)

for synthetic_dataset in ('pfd', 'synthia-rand'):
    # The directory listings do not depend on the percentage, so read them
    # once per synthetic dataset rather than once per mixture.
    synthetic_images = _sorted_file_paths(dataset_path_map[synthetic_dataset]['images'])
    synthetic_labels = _sorted_file_paths(dataset_path_map[synthetic_dataset]['labels'])

    for percent_synthetic_added in range(10, 110, 10):
        # The part of the split that is kept is the complement of test_size.
        test_size = 1.0-(percent_synthetic_added/100.0)
        print(f'Creating dataset with {synthetic_dataset} with {percent_synthetic_added}% ({test_size}) synthetic data appended')

        if percent_synthetic_added == 100:
            # A test_size of 0.0 is not passed to train_test_split; the whole
            # synthetic set is used instead.
            data_needed, labels_needed = synthetic_images, synthetic_labels
        else:
            # Splitting both lists in one call with the same random_state
            # selects the same indices from images and labels.
            data_needed, _, labels_needed, _ = train_test_split(
                synthetic_images, synthetic_labels, test_size=test_size, random_state=random_seed
            )

        text_file_path = f'{OUTPUT_DIR}/{str(percent_synthetic_added).zfill(3)}_percent_of_{synthetic_dataset}_added.txt'
        # Mode 'w' truncates a listing left over from a previous run, so the
        # file never needs to be removed or reopened in append mode.
        with open(text_file_path, 'w') as f:
            for source_file_paths in (
                data_needed,
                labels_needed,
                # Add camvid in
                camvid_training_images,
                camvid_training_labels,
            ):
                f.writelines(f'{line}\n' for line in source_file_paths)

print('Done!')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment