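"""Parallel downloader for YFCC100M-style images.

Reads gzipped JSON-lines metadata shards (named '000' .. 'fff'; presumably the
YFCC100M metadata shards used with the datadings tools imported below), resolves
each sample's Flickr page to an image URL with the requested size suffix,
downloads the image into a randomly chosen output sub-directory, skips images
whose shorter side is below ``image_min_size``, and rescales those whose shorter
side exceeds ``image_max_size``. This docstring doubles as the help text that
``main()`` passes to ``make_parser(__doc__, ...)``.
"""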
import io
import zipfile
from pathlib import Path
import threading as th
from multiprocessing.pool import ThreadPool
import cv2
import urllib.request
from contextlib import contextmanager
from datadings.tools.cached_property import cached_property
from string import hexdigits
import itertools as it
import multiprocessing as mp
from operator import itemgetter
import gzip
import json
import requests
import os
import random
from tqdm import tqdm


def load_metadata(indir, shard):
    # each metadata shard is a gzipped JSON-lines file: one sample per line
    with gzip.open(indir / (shard + '.gz'), 'rt', encoding='utf-8') as fp:
        for line in fp:
            yield json.loads(line)


def make_shard_names():
    # shards are named by all 3-digit lowercase hex strings: '000' .. 'fff'
    digits = set(hexdigits.lower())
    return {''.join(c) for c in it.product(digits, digits, digits)}


def shard_finished(shard, outdir):
    with (outdir / 'finished_shards').open('at', encoding='utf-8') as fp:
        fp.write(shard + '\n')


def load_finished_shards(outdir):
    try:
        with (outdir / 'finished_shards').open('rt', encoding='utf-8') as fp:
            return {line.strip('\n') for line in fp}
    except FileNotFoundError:
        return set()


def find_meta_shards(indir):
    return set(p.stem for p in indir.glob('*.gz')) & make_shard_names()


def resize(im, desired_size):
    old_size = im.shape[:2]  # (height, width)
    ratio = float(desired_size) / min(old_size)
    new_size = tuple(int(x * ratio) for x in old_size)  # still (height, width)
    # cv2.resize expects (width, height), so swap the dimensions
    return cv2.resize(im, (new_size[1], new_size[0]))
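
# Example of resize(): with desired_size=1024, an input of shape (2000, 3000)
# is scaled by 1024/2000 to (1024, 1536); the shorter side always ends up at
# desired_size and the aspect ratio is preserved.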


class WorkerBase:
    def __init__(
        self,
        shards,
        indir,
        metadir,
        outdir,
        kinds,
        filter_code,
        processes,
        threads,
        num_out_dirs,
        size,
        image_max_size,
        image_min_size,
    ):
        self.positions = None
        self.shards = shards
        self.indir = indir
        self.metadir = metadir
        self.outdir = outdir
        self.kinds = kinds
        self.filter_code = filter_code
        self.processes = processes
        self.threads = threads
        self.num_out_dirs = num_out_dirs
        self.size = size
        self.image_max_size = image_max_size
        self.image_min_size = image_min_size

    @cached_property
    def filter_fun(self):
        # the filter is given as source code (e.g. 'lambda x: True')
        # and evaluated once per worker
        return eval(self.filter_code)

    @contextmanager
    def positioned(self):
        # hand out distinct tqdm positions (progress-bar rows) to workers
        with mp.Manager() as manager:
            if self.processes > 0:
                self.positions = manager.Queue()
                for p in range(1, self.processes + 1):
                    self.positions.put(p)
            yield

    @contextmanager
    def position(self):
        if self.positions:
            position = self.positions.get()
            try:
                yield position
            finally:
                self.positions.put(position)
        else:
            yield 0

    def prepare_metadata(self, shard):
        # load metadata, keep the requested kinds (0 = images, 1 = videos),
        # apply the user filter, and sort by key
        metadata = load_metadata(self.metadir, shard)
        metadata = filter(lambda s: s['marker'] in self.kinds, metadata)
        metadata = filter(self.filter_fun, metadata)
        return sorted(metadata, key=itemgetter('key'))

    def tqdm(self, iterable, desc, position, length=None):
        if length is None:
            length = len(iterable)
        return tqdm(
            iterable,
            desc=f'{"├" if position < self.processes else "└"} {desc}',
            total=length,
            smoothing=0,
            position=position,
            leave=position == 0,
        )

    def pool(self):
        # share one write lock so that progress bars from different
        # processes do not garble each other's output
        write_lock = mp.Lock()
        tqdm.set_lock(write_lock)
        return mp.Pool(
            self.processes,
            initializer=tqdm.set_lock,
            initargs=(write_lock,),
            maxtasksperchild=1,
        )
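
# Parallelism layout: WorkerBase.pool() runs one process per shard, and inside
# each process Downloader.download_files() uses a ThreadPool of `threads`
# workers to fetch individual files.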


class Downloader(WorkerBase):
    def download(self, sample):
        # `sample` contains all the metadata of a given sample.
        # 1. retrieve the page url for the regular yfcc image from the metadata
        # 2. download the page source
        # 3. find the url for the yfcc image size of your choice
        r = requests.get(sample["pageurl"]).text
        count = 0
        for seq in r.split(","):
            if f"_{self.size}.jpg" in seq:
                filename = seq.split("/")[-1][:-1]
                url = "https://live.staticflickr.com/1/" + filename
                count += 1
                break
        if count == 0:
            return None
        # -------------------------------
        # download the image from the resolved url
        request = urllib.request.Request(
            url,
            data=None,
            headers={'User-Agent': 'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0'},
        )
        content = urllib.request.urlopen(request, timeout=10).read()
        # pick a random output sub-directory (0 .. num_out_dirs - 1)
        img_path = os.path.join(self.outdir, str(random.randint(0, self.num_out_dirs - 1)), filename)
        with open(img_path, 'wb') as outfile:
            outfile.write(content)
        # -------------------------------
        # skip images that are too small, resize if necessary, and save as jpg
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        min_size = min(img.shape[:2])
        if min_size < self.image_min_size:
            return None
        elif min_size > self.image_max_size:
            img = resize(img, self.image_max_size)
        cv2.imwrite(img_path, img, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
        return url

    def download_file(self, sample):
        # any failure (network error, broken image, ...) simply skips the sample
        try:
            return self.download(sample)
        except Exception:
            pass

    def download_files(self, shard, files):
        pool = ThreadPool(self.threads)
        with self.position() as position:
            yield from self.tqdm(
                pool.imap(self.download_file, files),
                shard,
                position,
                length=len(files),
            )

    def download_shard(self, shard):
        metadata = self.prepare_metadata(shard)
        for data in self.download_files(shard, metadata):
            pass
        return shard

    def download_shards(self):
        with self.positioned(), self.pool() as pool:
            yield from pool.imap_unordered(self.download_shard, self.shards)


def download_parallel(
    indir,
    outdir,
    shards=None,
    kinds=(0, 1),
    filter_code='lambda x: True',
    processes=8,
    threads=32,
    num_out_dirs=10000,
    size="h",
    image_max_size=1024,
    image_min_size=1024,
):
    if not shards:
        shards = find_meta_shards(indir)
    # skip shards that were completed in a previous run
    finished_shards = load_finished_shards(outdir)
    shards = set(shards) - finished_shards
    shards = sorted(shards)
    downloader = Downloader(
        shards, indir, indir, outdir, kinds, filter_code, processes, threads,
        num_out_dirs, size, image_max_size, image_min_size,
    )
    gen = tqdm(
        downloader.download_shards(),
        desc='total',
        total=len(shards),
        smoothing=0,
        position=0,
    )
    for shard in gen:
        shard_finished(shard, outdir)
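
# Resumption: every completed shard is appended to `<outdir>/finished_shards`,
# so re-running with the same output directory only processes the shards that
# are still missing.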


def main():
    from datadings.tools.argparse import make_parser
    parser = make_parser(
        __doc__,
        no_confirm=False,
        skip_verification=False,
        shuffle=False,
    )
    parser.add_argument(
        '-p', '--processes',
        default=8,
        type=int,
        help='Number of shards downloaded in parallel.',
    )
    parser.add_argument(
        '-t', '--threads',
        default=32,
        type=int,
        help='Number of threads used to download each shard.',
    )
    parser.add_argument(
        '--num_out_dirs',
        default=10000,
        type=int,
        help='Number of sub-directories to download images to.',
    )
    parser.add_argument(
        '--size',
        default="h",
        type=str,
        help='Suffix for the image size to download; see '
             'https://www.flickr.com/services/api/misc.urls.html '
             'for the available options.',
    )
    parser.add_argument(
        '--image_max_size',
        default=1024,
        type=int,
        help='Maximum size of the smaller image side; larger images are scaled down.',
    )
    parser.add_argument(
        '--image_min_size',
        default=1024,
        type=int,
        help='Minimum size of the smaller image side; smaller images are skipped.',
    )

    def _kind(v):
        return {'images': 0, 'videos': 1}[v]

    # noinspection PyTypeChecker
    parser.add_argument(
        '--kind',
        nargs='+',
        type=_kind,
        choices=('images', 'videos'),
        default=(0,),
        help='Kinds of files to download. Defaults to images.',
    )
    parser.add_argument(
        '--shard',
        default=(),
        type=str,
        nargs='+',
        help='Specify individual shards to download.',
    )

    def _evalable(v):
        # check that the filter string is valid Python before accepting it
        eval(v)
        return v

    parser.add_argument(
        '--filter',
        type=_evalable,
        default='lambda x: True',
        help='Lambda function to select samples.',
    )
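    # Example filter (hypothetical, for illustration): keep only samples that
    # have a page url; the lambda receives one metadata record (a dict):
    #   --filter 'lambda s: s.get("pageurl")'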

    args = parser.parse_args()
    indir = Path(args.indir)
    outdir = Path(f"{args.outdir}/yfcc_{args.size}") if args.outdir else Path(args.indir)
    os.makedirs(outdir, exist_ok=True)
    # pre-create the output sub-directories that download() picks from at random
    for idx in range(args.num_out_dirs):
        os.makedirs(os.path.join(outdir, str(idx)), exist_ok=True)
    try:
        download_parallel(
            indir,
            outdir,
            args.shard,
            args.kind,
            args.filter,
            args.processes,
            args.threads,
            args.num_out_dirs,
            args.size,
            args.image_max_size,
            args.image_min_size,
        )
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()
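
# Example invocation (a sketch; the input/output directory argument names come
# from datadings' make_parser and are assumed here, not verified):
#   python yfcc_download.py /path/to/metadata -o /path/to/images -p 8 -t 32 --size h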