Skip to content

Instantly share code, notes, and snippets.

@robinkraft
Last active November 21, 2016 01:46
Show Gist options
  • Save robinkraft/82e565a9b8a84fdf1e5c43c53142be51 to your computer and use it in GitHub Desktop.
Save robinkraft/82e565a9b8a84fdf1e5c43c53142be51 to your computer and use it in GitHub Desktop.
Data prep for Kaggle's State Farm competition
import os
import shutil
from glob import glob

import numpy as np
import pandas as pd
def choose_subjects(df, count):
    """Pick ``count`` random subjects and return their rows, one frame each.

    Args:
        df: DataFrame with a ``subject`` column.
        count: number of distinct subjects to pick. If it exceeds the number
            of subjects present, every subject is returned (in random order).

    Returns:
        A list of DataFrames, one per picked subject, each holding all rows
        of ``df`` belonging to that subject.
    """
    grouped = df.groupby('subject')
    # Materialise the keys as a list: on Python 3 ``groups.keys()`` is a view,
    # which np.random.permutation turns into a 0-d array and rejects.
    keys = list(grouped.groups)
    # Shuffle positions rather than the keys themselves so each picked key
    # keeps its original type (numpy would coerce it to a numpy scalar).
    order = np.random.permutation(len(keys))[:count]
    return [grouped.get_group(keys[i]) for i in order]
def mk_class_dirs(basepath, start, end, prefix='c'):
    """Create the class folders ``<prefix><start>`` .. ``<prefix><end>``.

    Each folder is placed directly under ``basepath``; ``os.makedirs`` also
    creates ``basepath`` itself if it is missing. Folders that already exist
    are skipped, so repeated calls are harmless.
    """
    class_dirs = (os.path.join(basepath, '{}{}'.format(prefix, idx))
                  for idx in range(start, end + 1))
    for class_dir in class_dirs:
        if not os.path.exists(class_dir):
            os.makedirs(class_dir)
def revert_valid(path):
    """Move every entry under ``path/valid/<class>/`` back to ``path/train/<class>/``.

    Once nothing matches ``path/valid/*/*`` any more, the whole ``valid``
    tree is deleted.

    Args:
        path: dataset root containing the ``valid`` (and ``train``) folders.

    Returns:
        True if ``valid`` was emptied and removed, False if entries remained.
    """
    valid_path = os.path.join(path, 'valid')
    train_path = os.path.join(path, 'train')
    glob_path = os.path.join(valid_path, '*', '*')
    for source in glob(glob_path):
        # Rebuild the target from the part of the path below ``valid``; a
        # plain str.replace('valid', 'train') would also rewrite any "valid"
        # occurring in ``path`` itself or in the file name.
        target = os.path.join(train_path, os.path.relpath(source, valid_path))
        target_dir = os.path.dirname(target)
        # os.rename does not create intermediate directories.
        if not os.path.isdir(target_dir):
            os.makedirs(target_dir)
        os.rename(source, target)
    if not glob(glob_path):
        print('No files left - finishing cleanup')
        shutil.rmtree(valid_path)
        return True
    return False
def mk_valid(path, subject_count=1, cleanup_only=False):
    """Build a validation split by moving whole subjects out of ``train``.

    Any existing ``path/valid`` tree is first moved back into ``path/train``
    via ``revert_valid``. Unless ``cleanup_only`` is set, ``subject_count``
    subjects are then picked at random from ``path/driver_imgs_list.csv``,
    and every image of each picked subject is moved from
    ``path/train/<class>/<img>`` to ``path/valid/<class>/<img>``.

    Args:
        path: dataset root containing ``train/`` and ``driver_imgs_list.csv``.
        subject_count: number of subjects to move into ``valid``.
        cleanup_only: only undo a previous split; do not create a new one.

    Raises:
        RuntimeError: if a previous ``valid`` tree could not be fully
            reverted (a subclass of Exception, so existing handlers still
            catch it).
    """
    valid_path = os.path.join(path, 'valid')
    train_path = os.path.join(path, 'train')
    if os.path.exists(valid_path) and not revert_valid(path):
        raise RuntimeError('Not all previous files cleaned up properly')
    if cleanup_only:
        return
    # os.path.join rather than string concatenation so ``path`` works with
    # or without a trailing separator.
    df = pd.read_csv(os.path.join(path, 'driver_imgs_list.csv'))
    subjects = choose_subjects(df, subject_count)
    # Set up the c0-c9 directory structure in the valid folder.
    mk_class_dirs(valid_path, 0, 9, 'c')
    for subject_df in subjects:
        print('Expect {} files to move'.format(subject_df.shape[0]))
        moved = 0
        # Rows are unpacked positionally as (subject, class dir, image file).
        for _subject, cls, img in subject_df.values:
            source = os.path.join(train_path, cls, img)
            target = os.path.join(valid_path, cls, img)
            os.rename(source, target)
            moved += 1
        print('Moved {} files'.format(moved))
def copy_sample(random_paths, basepath, train_or_valid, maxidx, minidx=0):
    """Copy ``random_paths[minidx:maxidx]`` into ``basepath/sample/<train_or_valid>``.

    Only the last two components (``<class>/<file>``) of each entry in
    ``random_paths`` are used. A file is read from
    ``basepath/train/<class>/<file>`` when ``train_or_valid == 'valid'``, and
    from ``basepath/<train_or_valid>/<class>/<file>`` otherwise. It is copied
    to ``basepath/sample/<train_or_valid>/<class>/<file>``.

    When the slice is non-empty, class folders c0..c9 are created under the
    destination first, so every class folder exists even if the sample misses
    a class.
    """
    sample = random_paths[minidx:maxidx]
    dest_root = os.path.join(basepath, 'sample', train_or_valid)
    # Samples destined for 'valid' are read out of train/.
    source_split = 'train' if train_or_valid == 'valid' else train_or_valid
    # len() instead of truthiness: ``sample`` may be a numpy array, whose
    # truth value is ambiguous.
    if len(sample) > 0:
        mk_class_dirs(dest_root, 0, 9, 'c')
    for f in sample:
        parent, fname = os.path.split(f)
        img_class = os.path.basename(parent)
        source = os.path.join(basepath, source_split, img_class, fname)
        target_dir = os.path.join(dest_root, img_class)
        # Classes outside c0..c9 would otherwise make shutil.copy fail.
        if not os.path.isdir(target_dir):
            os.makedirs(target_dir)
        shutil.copy(source, os.path.join(target_dir, fname))
    print('Copied {} files to {} sample'.format(len(sample), train_or_valid))
def mk_sample(path, train_size=200, valid_size=40):
    """Rebuild ``path/sample`` with small random train/valid subsets.

    Any existing ``path/sample`` tree is deleted first. All files matching
    ``path/train/*/*`` are shuffled once: the first ``train_size`` are copied
    into ``sample/train`` and the next ``valid_size`` into ``sample/valid``,
    so no image appears in both sample sets.

    Args:
        path: dataset root containing ``train/<class>/<image>``.
        train_size: number of images for ``sample/train``.
        valid_size: number of images for ``sample/valid``.
    """
    sample_path = os.path.join(path, 'sample')
    if os.path.exists(sample_path):
        print('Deleting existing sample data')
        shutil.rmtree(sample_path)
    os.makedirs(os.path.join(sample_path, 'train'))
    os.makedirs(os.path.join(sample_path, 'valid'))
    raw_paths = glob(os.path.join(path, 'train', '*', '*'))
    random_paths = np.random.permutation(raw_paths)
    copy_sample(random_paths, path, 'train', train_size)
    # Start the valid slice where the train slice ends; slicing both from 0
    # would make the valid sample a subset of the train sample.
    copy_sample(random_paths, path, 'valid', train_size + valid_size,
                minidx=train_size)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment