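"""Download the images of the Google Conceptual Captions (GCC) dataset from a
TSV file of (caption, url) pairs, and write a COCO-style captions JSON next to
the images. Progress is checkpointed so an interrupted run can be resumed."""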
import argparse
import os
import csv
import json
import time
from urllib.request import urlopen
from multiprocessing import Pool

import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument('input', help='Path to GCC TSV input file')
parser.add_argument('output', help='Output directory')
parser.add_argument('-t', help='Number of worker processes, default: 20', default=20, type=int)
parser.add_argument('-f', help='Images per checkpoint batch, default: 1000', default=1000, type=int)
args = parser.parse_args()
# A trailing slash messes with os.path.basename
assert args.output[-1] != '/', "Output path must not end with a trailing slash"
REQUEST_TIMEOUT = 0.5  # seconds; kept small to fail fast on unreachable URLs
# COCO format is like:
# {
#   "images": [{"id": 12, "file_name": "00012.jpg"}],
#   "annotations": [{"id": 12, "image_id": 12, "caption": "some piece of text"}]
# }
anns = {
    'annotations': [],
    'images': []
}
cursor_start = 0
cursor_pos = 0
dest_ann_file = os.path.join(args.output, 'captions_{}.json'.format(os.path.basename(args.output)))
dest_img_dir = os.path.join(args.output, os.path.basename(args.output))
# If the annotation file already exists, resume: recover the last processed index,
# set cursor_start to it, and fast-forward rows with idx < cursor_start below.
if os.path.isfile(dest_ann_file):
    with open(dest_ann_file, 'r') as f:
        anns = json.load(f)
    # imap_unordered does not preserve order, so take the max id, not the last entry
    cursor_start = max((img['id'] for img in anns['images']), default=-1) + 1
    print('Output destination already exists, resuming download from image # %i...' % cursor_start)
else:
    os.makedirs(dest_img_dir, exist_ok=True)
def fetch_image_data(url):
    response = urlopen(url, timeout=REQUEST_TIMEOUT)
    if response.status != 200:
        raise Exception('Bad status code')
    img_data = response.read()
    # JPEG files start with the magic bytes FF D8 FF. If the header does not match
    # (e.g. a redirect to a "broken image" page served with a 200 status), skip it.
    if img_data[:3] != b'\xff\xd8\xff':
        raise Exception('Corrupted image')
    return img_data
def process_image(tup):
    img_id, caption, url = tup
    try:
        img_data = fetch_image_data(url)
    except Exception:
        # Any download or validation failure marks this sample as failed
        return None
    img_filename = "{0:07d}.jpg".format(img_id)
    with open(os.path.join(dest_img_dir, img_filename), 'w+b') as f:
        f.write(img_data)
    ann = {"id": img_id, "image_id": img_id, "caption": caption}
    img = {"id": img_id, "file_name": img_filename}
    return (ann, img)
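# Each row of the GCC TSV is expected to hold a single caption<TAB>url pair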
with open(args.input, 'r') as tsvin:
    buffer = []
    reader = csv.reader(tsvin, delimiter='\t')
    for cursor_pos, (caption, url) in enumerate(reader):
        if cursor_pos < cursor_start:
            # Fast-forward to cursor_start when resuming a download
            continue
        # Queue up the remaining rows; they are batched during processing below
        buffer.append((cursor_pos, caption, url))
processing_cursor = 0
nb_failed = 0
total_processing_time = 0
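# Note: creating the Pool at module level assumes the 'fork' start method (the
# Linux default), so workers inherit args and dest_img_dir. With 'spawn' (the
# default on Windows, and on macOS since Python 3.8) this script would need an
# `if __name__ == '__main__':` guard.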
with Pool(args.t) as p:
    while processing_cursor < len(buffer):
        t0 = time.time()
        batch = buffer[processing_cursor : processing_cursor + args.f]
        res_iterator = p.imap_unordered(process_image, batch)
        valid_data = []
        nb_failed_batch = 0
        for res in res_iterator:
            if res is None:  # Happens if there was an issue downloading the image
                nb_failed_batch += 1
            else:
                valid_data.append(res)
        if valid_data:  # Guard against a batch where every download failed
            valid_data = np.array(valid_data)
            anns['annotations'].extend(valid_data[:, 0].tolist())
            anns['images'].extend(valid_data[:, 1].tolist())
        # Checkpoint the annotations after every batch so the run can be resumed
        with open(dest_ann_file, 'w') as f:
            json.dump(anns, f)
        processing_cursor += len(batch)
        batch_time = time.time() - t0
        total_processing_time += batch_time
        t_per_sample = total_processing_time / processing_cursor
        abs_pos = processing_cursor + cursor_start
        # ~3.3M rows in the GCC training TSV (hard-coded estimate of the dataset size)
        eta = t_per_sample * (3300000 - abs_pos) / 3600
        print("[Step]: {}, [Batch Time]: {:.1f}s, [ETA]: {:.2f} hours".format(abs_pos, batch_time, eta))
        print("%i images failed to download out of %i for this batch" % (nb_failed_batch, len(batch)))
        nb_failed += nb_failed_batch
print("%i images failed to download" % nb_failed)