Skip to content

Instantly share code, notes, and snippets.

@arjun-kava
Last active June 2, 2018 11:37
Show Gist options
  • Save arjun-kava/682b77e3a8218541afb758d11a6f481b to your computer and use it in GitHub Desktop.
Save arjun-kava/682b77e3a8218541afb758d11a6f481b to your computer and use it in GitHub Desktop.
python 3 assemble_data.py
#!/usr/bin/env python
"""
Form a subset of the Flickr Style data, download images to dirname, and write
Caffe ImageDataLayer training file.
"""
import os
import urllib.request
import hashlib
import argparse
import numpy as np
import pandas as pd
from skimage import io
import multiprocessing
# When a photo is unavailable, Flickr answers with a fixed placeholder picture
# instead of an error; downloaded files are compared against this SHA-1 hex
# digest to detect that case.
MISSING_IMAGE_SHA1 = '6a92790b1c2a301c6e7ddef645dca1f53ea97ac2'

# Directory that contains this script.
example_dirname = os.path.abspath(os.path.dirname(__file__))

# Two directory levels above the script (presumably the Caffe checkout root,
# going by the name -- confirm against the repository layout).
caffe_dirname = os.path.abspath(
    os.path.join(example_dirname, os.pardir, os.pardir))

# Output directory for the downloaded images and the train/test file lists.
training_dirname = os.path.join(caffe_dirname, 'data/ShirtTshirt')
def download_image(args_tuple):
    """Download one image and check that it is usable.

    Takes a single tuple argument so it can be used directly with
    ``multiprocessing.Pool.map``.

    Args:
        args_tuple: ``(url, filename)`` pair. The URL is only fetched when
            ``filename`` does not already exist on disk.

    Returns:
        True if the file at ``filename`` is not Flickr's "missing image"
        placeholder and can be decoded by ``skimage.io.imread``; False on any
        failure (malformed arguments, download error, placeholder image,
        unreadable file).

    Raises:
        Exception: when the worker is interrupted with Ctrl-C. The
            KeyboardInterrupt is converted to a plain Exception because
            multiprocessing does not propagate keyboard interrupts.
    """
    try:
        url, filename = args_tuple
        if not os.path.exists(filename):
            urllib.request.urlretrieve(url, filename)
        # The file must be read as bytes: hashlib rejects str input, and a
        # text-mode read of image data can fail with a decode error.
        with open(filename, 'rb') as f:
            digest = hashlib.sha1(f.read()).hexdigest()
        # Explicit comparison rather than assert, which is stripped under -O.
        if digest == MISSING_IMAGE_SHA1:
            return False
        # Decode the image once to make sure the file is not corrupt; the
        # pixel data itself is not needed.
        io.imread(filename)
        return True
    except KeyboardInterrupt:
        print("key ex")
        raise Exception()  # multiprocessing doesn't catch keyboard exceptions
    except Exception:
        # Best effort: any failure simply marks this image as unusable.
        return False
# Script entry point: parse options, sample the CSV, download the images in
# parallel, and write train.txt / test.txt listing the usable images.
# NOTE: this stays at module level (not inside a function) because it rebinds
# the module-level name `training_dirname` below.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Download a subset of Flickr Style to a directory')
    parser.add_argument(
        '-s', '--seed', type=int, default=0,
        help="random seed")
    parser.add_argument(
        '-i', '--images', type=int, default=-1,
        help="number of images to use (-1 for all [default])",
    )
    parser.add_argument(
        '-w', '--workers', type=int, default=-1,
        help="num workers used to download images. -x uses (all - x) cores [-1 default]."
    )
    parser.add_argument(
        '-l', '--labels', type=int, default=0,
        help="if set to a positive value, only sample images from the first number of labels."
    )
    args = parser.parse_args()
    np.random.seed(args.seed)

    # Read data, shuffle order, and subsample.
    csv_filename = os.path.join(example_dirname, 'flickr_style.csv.gz')
    df = pd.read_csv(csv_filename, index_col=None, compression='gzip')
    df = df.iloc[np.random.permutation(df.shape[0])]
    if args.labels > 0:
        df = df.loc[df['label'] < args.labels]
    if 0 < args.images < df.shape[0]:
        df = df.iloc[:args.images]

    # Make directory for images and get local filenames.
    # The None check only matters if the module-level default is edited away.
    if training_dirname is None:
        training_dirname = os.path.join(caffe_dirname, 'data/DogsCats')
    images_dirname = os.path.join(training_dirname, 'images')
    # exist_ok avoids the race between checking for and creating the directory.
    os.makedirs(images_dirname, exist_ok=True)
    df['image_filename'] = [
        os.path.join(images_dirname, value) for value in df['image_name']
    ]

    # Download images.
    num_workers = args.workers
    if num_workers <= 0:
        # Non-positive values mean "all cores minus x"; never go below one
        # worker, since Pool rejects a process count < 1.
        num_workers = max(1, multiprocessing.cpu_count() + num_workers)
    print('Downloading {} images with {} workers...'.format(
        df.shape[0], num_workers))
    map_args = zip(df['image_url'], df['image_filename'])
    # The context manager shuts the worker processes down once map() returns.
    with multiprocessing.Pool(processes=num_workers) as pool:
        results = pool.map(download_image, map_args)

    # Only keep rows with valid images, and write out training file lists.
    # An explicit boolean array keeps this a row mask even when there are no
    # rows (an empty list would be read as an empty column selection).
    valid_mask = np.asarray(results, dtype=bool)
    df = df[valid_mask]
    for split in ['train', 'test']:
        split_df = df[df['_split'] == split]
        filename = os.path.join(training_dirname, '{}.txt'.format(split))
        split_df[['image_filename', 'label']].to_csv(
            filename, sep=' ', header=False, index=False)
    print('Writing train/val for {} successfully downloaded images.'.format(
        df.shape[0]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment