Last active
June 2, 2018 11:37
-
-
Save arjun-kava/682b77e3a8218541afb758d11a6f481b to your computer and use it in GitHub Desktop.
python3 assemble_data.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
Form a subset of the Flickr Style data, download images to dirname, and write | |
Caffe ImagesDataLayer training file. | |
""" | |
import os | |
import urllib.request | |
import hashlib | |
import argparse | |
import numpy as np | |
import pandas as pd | |
from skimage import io | |
import multiprocessing | |
# Flickr returns a special image if the request is unavailable.
# SHA1 of that placeholder image; used to detect "missing photo" downloads.
MISSING_IMAGE_SHA1 = '6a92790b1c2a301c6e7ddef645dca1f53ea97ac2'
# Directory containing this script; the input CSV is resolved relative to it.
example_dirname = os.path.abspath(os.path.dirname(__file__))
# Presumably this script lives two directory levels below the Caffe root —
# TODO confirm against the repository layout.
caffe_dirname = os.path.abspath(os.path.join(example_dirname, '../..'))
# Destination for downloaded images and the train/test list files.
training_dirname = os.path.join(caffe_dirname, 'data/ShirtTshirt')
def download_image(args_tuple):
    """Download one image; for use with multiprocessing.Pool.map.

    Parameters
    ----------
    args_tuple : tuple
        ``(url, filename)`` — the source URL and the local destination path.

    Returns
    -------
    bool
        True when the file exists (downloaded or cached), is not Flickr's
        "unavailable" placeholder, and decodes as an image; False otherwise.
    """
    try:
        url, filename = args_tuple
        if not os.path.exists(filename):
            urllib.request.urlretrieve(url, filename)
        # BUGFIX: hash the raw bytes. The original opened the file in text
        # mode, so hashlib raised on the decoded str and the bare except
        # reported *every* image as a failure under Python 3.
        with open(filename, 'rb') as f:
            assert hashlib.sha1(f.read()).hexdigest() != MISSING_IMAGE_SHA1
        # Verify the file actually decodes as an image (result unused).
        io.imread(filename)
        return True
    except KeyboardInterrupt:
        print("key ex")
        # multiprocessing workers swallow KeyboardInterrupt; re-raise as a
        # plain Exception so the pool propagates the failure to the parent.
        raise Exception('KeyboardInterrupt in worker')
    except Exception:
        # Any failure (bad URL, placeholder image, undecodable file) marks
        # this row as invalid; the caller filters rows on the result.
        return False
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Download a subset of Flickr Style to a directory')
    parser.add_argument(
        '-s', '--seed', type=int, default=0,
        help="random seed")
    parser.add_argument(
        '-i', '--images', type=int, default=-1,
        help="number of images to use (-1 for all [default])",
    )
    parser.add_argument(
        '-w', '--workers', type=int, default=-1,
        help="num workers used to download images. -x uses (all - x) cores [-1 default]."
    )
    parser.add_argument(
        '-l', '--labels', type=int, default=0,
        help="if set to a positive value, only sample images from the first number of labels."
    )
    args = parser.parse_args()
    np.random.seed(args.seed)

    # Read data, shuffle order, and subsample.
    csv_filename = os.path.join(example_dirname, 'flickr_style.csv.gz')
    df = pd.read_csv(csv_filename, index_col=None, compression='gzip')
    df = df.iloc[np.random.permutation(df.shape[0])]
    if args.labels > 0:
        df = df.loc[df['label'] < args.labels]
    if 0 < args.images < df.shape[0]:
        df = df.iloc[:args.images]

    # Make directory for images and get local filenames.
    # (The original's `if training_dirname is None:` branch was dead code:
    # the module-level constant is always a string.)
    images_dirname = os.path.join(training_dirname, 'images')
    os.makedirs(images_dirname, exist_ok=True)
    df['image_filename'] = [
        os.path.join(images_dirname, name) for name in df['image_name']
    ]

    # Download images in parallel. A non-positive --workers value means
    # "all cores minus |workers|".
    num_workers = args.workers
    if num_workers <= 0:
        num_workers = multiprocessing.cpu_count() + num_workers
    print('Downloading {} images with {} workers...'.format(
        df.shape[0], num_workers))
    pool = multiprocessing.Pool(processes=num_workers)
    try:
        results = pool.map(download_image,
                           zip(df['image_url'], df['image_filename']))
    finally:
        # Always release the worker processes (the original leaked the pool).
        pool.close()
        pool.join()

    # Only keep rows with valid images, and write out training file lists.
    df = df[results]
    for split in ['train', 'test']:
        split_df = df[df['_split'] == split]
        filename = os.path.join(training_dirname, '{}.txt'.format(split))
        split_df[['image_filename', 'label']].to_csv(
            filename, sep=' ', header=False, index=False)
    # Message fixed: the files written above are train/test, not train/val.
    print('Writing train/test for {} successfully downloaded images.'.format(
        df.shape[0]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment