A module to download (image) URLs using concurrent.futures
# -*- coding: utf-8 -*-
"""
picit is a module to download URLs contained in an input file
to a local directory.

>>> import picit
>>> downloader = picit.Downloader("/home/pcaulagi/src/picit/input.dat",
...                               output_dir="/home/pcaulagi/Downloads")
>>> downloader.download()
>>>
"""
import urllib.request
import logging
import os
import shutil
import tempfile
from urllib.error import ContentTooShortError, URLError
from urllib.parse import urlparse

# Set a global timeout so that we don't wait for a single
# misbehaving server for too long. Note that timeout doesn't
# apply for DNS resolution.
import socket
socket.setdefaulttimeout(5)

class InvalidInputFileException(ValueError):
    """Exception class to signify invalid input file"""
    pass


class InvalidOutputDirectoryException(ValueError):
    """Exception class to signify invalid output directory"""
    pass

class Downloader:
    """Class to download all URLs in an input file to a local directory"""

    # Number of urls we will process at a time
    CHUNK_SIZE = 50

    def __init__(self, input_file, **kwargs):
        """Initialize the downloader with the input file and optional params"""
        if not os.path.isfile(input_file):
            raise InvalidInputFileException("Input file is invalid.")
        output_dir = kwargs.get("output_dir") or os.getcwd()
        if not Downloader.check_output_dir(output_dir):
            raise InvalidOutputDirectoryException("Output directory is invalid")
        if not kwargs.get("logger"):
            log_level = kwargs.get("log_level") or logging.INFO
            logger = self.get_new_logger(log_level)
        else:
            logger = kwargs.get("logger")
        self._input_file = input_file
        self.logger = logger
        self._output_dir = output_dir

    @classmethod
    def check_output_dir(cls, output_dir):
        """Return True if the directory exists and we have write permissions.
        If it does not exist, try to create it and return True on success,
        False otherwise."""
        if not os.path.isdir(output_dir):
            try:
                os.makedirs(output_dir)
            except OSError:
                return False
            return True
        # Directory exists. Check we have write permissions
        try:
            fd, path = tempfile.mkstemp(dir=output_dir)
            os.close(fd)
            os.remove(path)
        except PermissionError:
            return False
        return True

    @classmethod
    def get_new_logger(cls, log_level):
        """Create a new logger that writes to the console (stderr)"""
        logger = logging.getLogger(__name__)
        logger.setLevel(log_level)
        ch = logging.StreamHandler()
        logger.addHandler(ch)
        return logger

    @classmethod
    def get_original_filename(cls, url):
        """Get the filename corresponding to this url"""
        _, _, path, _, _, _ = urlparse(url)
        return path.split("/")[-1].strip()

    def _move_file(self, tmp_file, to_file):
        """Move the downloaded file from its temporary location into the
        output_dir"""
        try:
            shutil.move(tmp_file, os.path.join(self._output_dir, to_file))
        except OSError as e:
            self.logger.warning("Error moving %s to %s: %s",
                                tmp_file, self._output_dir, e.strerror)

    def get_urls(self, urls):
        """Download each url in the list of urls"""
        for url in urls:
            self.logger.debug("Downloading: %s", url)
            try:
                local_file, _ = urllib.request.urlretrieve(url)
            except (ContentTooShortError, URLError) as e:
                self.logger.warning("Ignoring %s: %s", url, e.reason)
                continue
            self._move_file(local_file, self.get_original_filename(url))
        urllib.request.urlcleanup()
        self.logger.debug("done")

    def download(self):
        """Open the input file and download each url in turn"""
        self.logger.info("Downloading files from %s to %s",
                         self._input_file, self._output_dir)
        with open(self._input_file, "r") as input_file:
            urls = []
            for line in input_file:
                urls.append(line.strip())
                if len(urls) >= self.CHUNK_SIZE:
                    self.get_urls(urls)
                    urls = []
            # May have some last few urls that haven't been processed
            self.get_urls(urls)
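download() expects the input file to contain one URL per line. A hypothetical input.dat (the file name and URLs are illustrative, not from the gist) might look like:

https://example.com/images/cat.jpg
https://example.com/images/dog.png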
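The gist description mentions concurrent.futures, but the file above fetches each URL sequentially. Below is a minimal sketch of how a batch could be downloaded in parallel with concurrent.futures.ThreadPoolExecutor; the helper names (_fetch, get_urls_concurrently), the max_workers value, and the idea of calling this from download() are assumptions, not part of the original module.

# Sketch only: a parallel batch download using concurrent.futures.
# The helper names and max_workers value are illustrative assumptions.
import concurrent.futures
import urllib.request
from urllib.error import ContentTooShortError, URLError


def _fetch(url):
    """Retrieve a single URL to a temporary file and return its local path."""
    local_file, _ = urllib.request.urlretrieve(url)
    return local_file


def get_urls_concurrently(downloader, urls, max_workers=8):
    """Download a batch of URLs in parallel, then move each into output_dir."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Map each pending future back to its URL for logging on failure.
        future_to_url = {executor.submit(_fetch, url): url for url in urls}
        for future in concurrent.futures.as_completed(future_to_url):
            url = future_to_url[future]
            try:
                local_file = future.result()
            except (ContentTooShortError, URLError) as e:
                downloader.logger.warning("Ignoring %s: %s", url, e)
                continue
            downloader._move_file(local_file, downloader.get_original_filename(url))
    urllib.request.urlcleanup()

Threads fit well here because urlretrieve is I/O-bound, so a thread pool overlaps the network waits without the serialization overhead a process pool would add.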