Skip to content

Instantly share code, notes, and snippets.

@dennisobrien
Created November 13, 2016 03:00
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save dennisobrien/c72ef0f0c1fe125bb49e07b6b2834927 to your computer and use it in GitHub Desktop.
Save dennisobrien/c72ef0f0c1fe125bb49e07b6b2834927 to your computer and use it in GitHub Desktop.
from getpass import getpass
from glob import glob
import numpy as np
import os
import sh
import shutil
def get_data_dir(*args, relative=False):
    """Return the path to the data directory.

    Any additional positional arguments are interpreted as path components
    below the root data directory. The root data directory is assumed to be
    `data/redux`, located relative to this notebook (i.e. relative to the
    current working directory).

    Args:
        *args: Path components appended to the root data directory.
        relative: If true, return the path relative to the working
            directory; otherwise return an absolute path.

    Returns:
        The joined path as a string.
    """
    if relative:
        return os.path.join('data', 'redux', *args)
    # Absolute form: resolve the root against the current working directory.
    return os.path.join(os.path.abspath('data/redux'), *args)
def make_dirs(dir_path, create_msg=None, ignore_msg=None):
    """Create the directory `dir_path` if necessary.

    Intermediate directories are created as needed.

    Args:
        dir_path: Directory to create.
        create_msg: Optional message printed when the directory was created.
        ignore_msg: Optional message printed when the path already existed
            and nothing was created.
    """
    try:
        os.makedirs(dir_path)
    except FileExistsError:
        # The path is already there: treat as success, optionally report it.
        if ignore_msg:
            print(ignore_msg)
    else:
        if create_msg:
            print(create_msg)
def prepare_data(n_valid=2000):
    """Fetch data (if necessary) from kaggle using kaggle-cli.

    Steps, each skipped when its result is already on disk:

    1. Download the competition archives into the data directory if
       'train.zip' is missing (prompts for kaggle credentials).
    2. Unzip 'train.zip' and 'test.zip' if the 'train' directory is missing.
    3. Move every image from 'valid/' back into 'train/', then move a fresh
       random selection of `n_valid` train images into 'valid/'.
    4. Create a directory 'unknown/' in 'test' and move all images from
       'test/' into it.

    Args:
        n_valid: Number of train images to move into the 'valid/' directory.

    Raises:
        ValueError: If `n_valid` exceeds the number of available train images.
    """
    data_dir = get_data_dir()
    make_dirs(data_dir)

    if not os.path.exists(os.path.join(data_dir, 'train.zip')):
        # Equivalent to:
        # `kg download -u USERNAME -p PASSWORD -c dogs-vs-cats-redux-kernels-edition`
        print('fetching data from kaggle...')
        username = getpass('kaggle username:')
        password = getpass('kaggle password:')
        kg = sh.Command('kg')
        output = kg.download(u=username, p=password, c='dogs-vs-cats-redux-kernels-edition', _cwd=data_dir)
        print(output)
    else:
        print('zip files already exist')

    if not os.path.exists(os.path.join(data_dir, 'train')):
        print('unzip the files')
        for filename in ('train.zip', 'test.zip'):
            output = sh.unzip(filename, _cwd=data_dir)
            # Only show the head of unzip's output to keep the log short.
            print(output[:5])
    else:
        print('archives already unzipped')

    valid_dir = get_data_dir('valid')
    train_dir = get_data_dir('train')
    make_dirs(valid_dir, 'create the valid directory')

    # Reset any previous split so repeated runs draw from the full train set.
    print('move all files from valid/ to train/')
    move_files(glob(os.path.join(valid_dir, '*.jpg')), train_dir)

    print('move some files from train/ to valid/')
    train_files = glob(os.path.join(train_dir, '*.jpg'))
    if n_valid > len(train_files):
        # Fail early with a clear message; sampling without replacement
        # cannot pick more files than exist.
        raise ValueError(
            'n_valid={} exceeds the {} images available in {}'.format(
                n_valid, len(train_files), train_dir))
    g = np.random.choice(train_files, n_valid, replace=False)
    move_files(g, valid_dir)

    # Test images have no label; collect them under a single 'unknown'
    # category directory.
    test_dir = get_data_dir('test')
    test_category_dir = os.path.join(test_dir, 'unknown')
    make_dirs(test_category_dir)
    move_files(glob(os.path.join(test_dir, '*.jpg')), test_category_dir)
def move_files(src_list, dest_dir):
    """Move all the filepaths in `src_list` to `dest_dir`.

    Each file keeps its base name. A summary line with the number of moved
    files is printed at the end.

    Args:
        src_list: Iterable of source file paths (any iterable, including
            generators and numpy arrays).
        dest_dir: Existing directory that receives the files.
    """
    # Count while iterating instead of calling len(), so that iterables
    # without a length are supported as well.
    n_moved = 0
    for src in src_list:
        dest = os.path.join(dest_dir, os.path.basename(src))
        os.rename(src, dest)
        n_moved += 1
    print('moved {} files to {}'.format(n_moved, dest_dir))
def prepare_sample_data(f_sample=0.1, f_train=0.75):
    """Copy a subset of the data from 'train/' to 'sample/train/' and 'sample/valid/'.

    Any existing 'sample/' directory is deleted first. A random fraction
    `f_sample` of the images in 'train/' is selected; the first `f_train`
    share of that selection is copied below 'sample/train/', the rest below
    'sample/valid/'. Within each split, files whose names start with 'cat'
    go to 'cats/' and those starting with 'dog' go to 'dogs/'. The original
    files in 'train/' are left in place.

    Args:
        f_sample: Fraction of the train images to include in the sample.
        f_train: Fraction of the sample that goes to 'sample/train/'.
    """
    sample_dir = get_data_dir('sample')
    try:
        shutil.rmtree(sample_dir)
    except FileNotFoundError:
        # No previous sample directory: nothing to clean up.
        pass
    else:
        print('deleted {}'.format(sample_dir))

    train_dir = os.path.join(sample_dir, 'train')
    valid_dir = os.path.join(sample_dir, 'valid')
    # (filename prefix, category directory) pairs.
    categories = (('cat', 'cats'), ('dog', 'dogs'))
    for split_dir in (train_dir, valid_dir):
        for _, category in categories:
            os.makedirs(os.path.join(split_dir, category))

    filepaths = glob(get_data_dir('train', '*.jpg'))
    n_files = int(len(filepaths) * f_sample)
    sample = np.random.choice(filepaths, n_files, replace=False)
    # Honour the caller's `f_train`; it must not be overwritten here.
    n_sample_train = int(len(sample) * f_train)

    splits = (
        (train_dir, sample[:n_sample_train]),
        (valid_dir, sample[n_sample_train:]),
    )
    for split_dir, split_files in splits:
        for prefix, category in categories:
            matching = [f for f in split_files
                        if os.path.basename(f).startswith(prefix)]
            copy_files(matching, os.path.join(split_dir, category))
def copy_files(src_list, dest_dir):
    """Copy all the filepaths in `src_list` to `dest_dir`.

    Each file keeps its base name; the source files are left untouched.
    A summary line with the number of copied files is printed at the end.

    Args:
        src_list: Iterable of source file paths (any iterable, including
            generators and numpy arrays).
        dest_dir: Existing directory that receives the copies.
    """
    # Count while iterating instead of calling len(), so that iterables
    # without a length are supported as well.
    n_copied = 0
    for src in src_list:
        dest = os.path.join(dest_dir, os.path.basename(src))
        shutil.copyfile(src, dest)
        n_copied += 1
    print('copied {} files to {}'.format(n_copied, dest_dir))
# Download/unpack the data and split train/valid using the default n_valid.
prepare_data()
# Build the sample/ tree from all images currently in train/ (f_sample=1.0).
prepare_sample_data(f_sample=1.0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment