Skip to content

Instantly share code, notes, and snippets.

@aakashns
Created March 12, 2018 09:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aakashns/a9a0eba8bb9541d31114ba1f3531944c to your computer and use it in GitHub Desktop.
Save aakashns/a9a0eba8bb9541d31114ba1f3531944c to your computer and use it in GitHub Desktop.
#!/usr/bin/python
# Note to Kagglers: This script will not run directly in Kaggle kernels. You
# need to download it and run it on your local machine.
# Downloads images from the Google Landmarks dataset using a pool of worker
# processes. Images that already exist will not be downloaded again, so the
# script can resume a partially completed download. All images will be saved
# in JPEG format at quality setting 90.
import sys, os, multiprocessing, urllib2, csv
from PIL import Image
from StringIO import StringIO
def ParseData(data_file):
    """Read the CSV index file and return its [key, url] pairs.

    Args:
        data_file: Path to a CSV file whose first row is a header and whose
            first two columns are the image key and the image URL.

    Returns:
        A list of two-element lists ``[key, url]``, one per data row. The
        header row is dropped, and rows with fewer than two fields (for
        example blank lines) are skipped because they cannot be unpacked
        into a key and a URL.
    """
    # The context manager guarantees the file handle is released even if
    # the CSV parser raises part-way through.
    with open(data_file, 'r') as csvfile:
        csvreader = csv.reader(csvfile)
        # Drop the header; the default keeps an empty file from raising.
        next(csvreader, None)
        return [row[:2] for row in csvreader if len(row) >= 2]
def DownloadImage(key_url):
    """Download one image and store it as ``<output_dir>/<key>.jpg``.

    The image is fetched from its URL, decoded with PIL, converted to RGB and
    written as a JPEG with quality 90. Every failure is reported with a
    warning on stdout and the image is skipped; nothing is raised for
    ordinary errors, so one bad URL does not abort the whole run.

    Args:
        key_url: A ``(key, url)`` pair. ``key`` becomes the file name stem
            and ``url`` is the address the image is fetched from.

    Returns:
        None.
    """
    # The output directory is taken from the command line; presumably this
    # is because pool.map passes each worker call a single argument.
    out_dir = sys.argv[2]
    (key, url) = key_url
    filename = os.path.join(out_dir, '%s.jpg' % key)

    if os.path.exists(filename):
        print('Image %s already exists. Skipping download.' % filename)
        return

    # `except Exception` (rather than a bare `except`) keeps the best-effort
    # behaviour while letting KeyboardInterrupt/SystemExit stop the worker.
    try:
        response = urllib2.urlopen(url)
        try:
            image_data = response.read()
        finally:
            # Release the connection even when the read fails.
            response.close()
    except Exception:
        print('Warning: Could not download image %s from %s' % (key, url))
        return

    try:
        pil_image = Image.open(StringIO(image_data))
    except Exception:
        print('Warning: Failed to parse image %s' % key)
        return

    try:
        pil_image_rgb = pil_image.convert('RGB')
    except Exception:
        print('Warning: Failed to convert image %s to RGB' % key)
        return

    # Write to a per-process temporary name and rename on success, so a
    # failed or interrupted save never leaves a truncated file at `filename`
    # that a later run would skip as "already exists".
    tmp_filename = '%s.%d.tmp' % (filename, os.getpid())
    try:
        pil_image_rgb.save(tmp_filename, format='JPEG', quality=90)
        os.rename(tmp_filename, filename)
    except Exception:
        print('Warning: Failed to save image %s' % filename)
        if os.path.exists(tmp_filename):
            try:
                os.remove(tmp_filename)
            except OSError:
                # Cleanup is best effort; the leftover has a .tmp suffix and
                # is never mistaken for a finished image.
                pass
def Run():
    """Parse the command line and download every image listed in the CSV.

    Expects exactly two command-line arguments: the CSV index file and the
    output directory. The output directory is created (including missing
    parent directories) when it does not exist, and the downloads are spread
    over a pool of 50 worker processes.

    Raises:
        SystemExit: With status 1 when the argument count is wrong, after
            printing a usage message.
    """
    if len(sys.argv) != 3:
        print('Syntax: %s <data_file.csv> <output_dir/>' % sys.argv[0])
        # A usage error is a failure, so report a non-zero exit status.
        sys.exit(1)
    (data_file, out_dir) = sys.argv[1:]

    # makedirs also handles nested output paths such as "a/b/c".
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    key_url_list = ParseData(data_file)
    pool = multiprocessing.Pool(processes=50)
    try:
        pool.map(DownloadImage, key_url_list)
        pool.close()
    except BaseException:
        # Stop the workers immediately on any error or interrupt instead of
        # leaving them running, then let the error propagate.
        pool.terminate()
        raise
    finally:
        pool.join()
# Script entry point: start the download only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    Run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment