#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
e6collector.
Python 3.3+ script that downloads all images with a specific tag,
saves them as ID-SOME-TAGS.EXTENSION, and makes sure it doesn't
download the same image twice just because its tags changed. All
tags are also written to tags.csv in the same folder. Uses the
Ouroboros API described at https://e621.net/help/api
PROTIP: To view images with a certain tag try this on UNIX-like systems:
gthumb `for f in $(grep -i TAG tags.csv | cut -f 1 -d ","); do echo $f-*; done`
THE AUTHOR DOES NOT TAKE ANY RESPONSIBILITIES, THERE IS NO WARRANTY AND
YOU PROBABLY SHOULDN'T USE THIS IN A NUCLEAR POWER PLANT, JUST SAYING!
License: Public domain, do whatever.
Version 1.0 -- Initial release, if you can call it that
Version 1.0.1 -- Fixed Unicode problem on Windows
Version 1.0.2 -- Fixed API not working
Version 2.0.0 -- Rewritten. Better logging, parallel downloads.
Version 2.0.1 -- Use `before_id` instead of `page` when listing posts
Version 2.0.2 -- Rate limiting! Error handling!
Version 2.0.3 -- Faster. Quieter. Prints stats.
"""
import argparse
import csv
import glob
import logging
import os.path
import re
import time
from collections import defaultdict
from contextlib import contextmanager
from functools import partial, wraps
from multiprocessing import Lock
from multiprocessing.dummy import Pool
from urllib.parse import urlencode
from urllib.request import Request, urlopen
from urllib.error import HTTPError, ContentTooShortError
from xml.etree import ElementTree
_LOG = logging.getLogger('e6collector')
UNSAFE_TAG_CHARS = re.compile(r'[^0-9a-z_-]+')
def escape_tag(tag):
"""Strip unsafe characters from a tag."""
return UNSAFE_TAG_CHARS.sub('', tag)
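# Illustrative behaviour: the regex keeps only lowercase ASCII letters,
# digits, '_' and '-', so uppercase letters are dropped rather than
# lowercased, e.g. escape_tag('Cat_Girl!') == 'at_irl'.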
class RetryLimitReached(Exception):
"""Raised when a retry decorated function hits its limit."""
def __init__(self, fn, limit, exceptions):
"""
:param callable fn:
:param int limit:
:param exceptions:
:type exceptions: list[Exception]
"""
self._fn = fn
self._limit = limit
self._exceptions = exceptions
msg = '{fn} failed {limit:d} times: {exceptions}'.format(
fn=self._fn,
limit=self._limit,
exceptions=self._exceptions)
super().__init__(msg)
def retry(limit, delay=1.0, backoff=2.0, exceptions=(Exception, )):
"""
Function decorator for retrying on exception, with exponential backoff.
    :param int limit:
:param float delay:
:param float backoff:
:param exceptions:
:type exceptions: list[Exception]
"""
exceptions = tuple(exceptions)
def _retry(fn):
@wraps(fn)
        def wrapper(*args, **kwargs):
            _remaining_tries = limit
            _delay = delay
            _exceptions = []
            while True:
                try:
                    return fn(*args, **kwargs)
                except exceptions as ex:
                    _exceptions.append(ex)
                    _remaining_tries -= 1
                    if _remaining_tries <= 0:
                        # Out of attempts: raise with everything we caught,
                        # not the tuple of exception *classes*.
                        raise RetryLimitReached(fn, limit, _exceptions)
                    _LOG.debug(
                        "Exception %r occurred calling %s; sleeping for %.02f",
                        ex, fn, _delay)
                    time.sleep(_delay)
                    _delay *= backoff
return wrapper
return _retry
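# Illustrative use of @retry (hypothetical function, not part of this
# script): retry a flaky HTTP fetch up to three attempts, sleeping 1s,
# then 2s, between failures:
#
#     @retry(limit=3, delay=1.0, backoff=2.0, exceptions=(HTTPError,))
#     def fetch(url):
#         return urlopen(url).read()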
def rate_limit(calls_per_second):
"""
Function decorator that rate-limits calls to a function.
:param float calls_per_second:
"""
lock = Lock()
interval = 1. / calls_per_second
def _rate_limit(fn):
last_call = 0
@wraps(fn)
def wrapper(*args, **kwargs):
lock.acquire()
try:
nonlocal last_call
elapsed = time.perf_counter() - last_call
wait = interval - elapsed
if wait > 0:
_LOG.debug("Rate limiting %s call to %s calls/s; "
"last call was %0.2f so sleep for %0.2f",
fn.__name__, calls_per_second, last_call, wait)
time.sleep(wait)
last_call = time.perf_counter()
finally:
lock.release()
return fn(*args, **kwargs)
return wrapper
return _rate_limit
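# Illustrative pairing (hypothetical function): stacking the decorators
# the way _make_request does below, so a call is retried on failure but
# never issued more than once per second:
#
#     @retry(limit=5, exceptions=(HTTPError,))
#     @rate_limit(1.0)
#     def ping(uri):
#         return urlopen(uri)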
class Counter:
"""Thread-safe counter."""
def __init__(self):
self._lock = Lock()
self._counts = defaultdict(int)
def __repr__(self):
return '<Counter({!r})>'.format(list(self._counts.items()))
def increment(self, key, by=1):
"""
Increment a count.
:param key: counter to increment
:param by: add this many
"""
with self._lock:
self._counts[key] += by
return self._counts[key]
    def get(self, key):
        """
        Get a count.
        :param key: counter to read
        """
        # Use .get() so a read never inserts a new key into the
        # defaultdict behind the lock's back.
        return self._counts.get(key, 0)
__getitem__ = get
def counts(self):
"""
A stable view of the counts.
:rtype: Generator[(any, number)]
:return: Tuples of (key, count)
"""
with self._lock:
yield from self._counts.items()
__iter__ = counts
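# Illustrative use of Counter (the key name is just an example):
#
#     c = Counter()
#     c.increment('post_seen', 5)
#     c['post_seen']   # -> 5
#     dict(c)          # -> {'post_seen': 5}, via the counts() iterator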
class E6Collector:
"""Download tagged posts from e621, and write their tags to a CSV file."""
FILE_PATTERN = '{0:s}-{1:s}.{2:s}'
LIST_ENDPOINT = '/post/index.xml'
REQUEST_HEADERS = {'User-Agent': 'e6collector/2.0.3'}
URI_PREFIX = 'https://e621.net'
def __init__(self, jobs=None):
"""
:param jobs: Number of threads to use
:type jobs: int or None
"""
self._stats = Counter()
self._written_tags = set()
self._written_posts = set()
self.jobs = jobs
@retry(limit=5,
delay=2,
backoff=4,
exceptions=(HTTPError, ContentTooShortError))
@rate_limit(1.0)
def _make_request(self, request):
self._stats.increment('request')
return urlopen(request)
    def _api_request(self, endpoint, query_params=None):
        # Build the URI from the endpoint we were given, not a
        # hard-coded one.
        uri = self.URI_PREFIX + endpoint
if query_params:
if 'tags' in query_params:
tags = query_params['tags']
if isinstance(tags, list):
query_params['tags'] = self._flatten_tags_param(tags)
query_string = urlencode(query_params)
uri += "?" + query_string
return Request(uri, headers=self.REQUEST_HEADERS)
def _flatten_tags_param(self, tags):
return ' '.join(tags)
def tagged_posts(self, tags):
"""
Fetch posts matching tags.
        :param list[str] tags:
:return: Generator of posts matching tags
:rtype: Generator[Element]
"""
post_count = 0
request_limit = 100
total_count = None
before_id = None
        tags = tags + ['order:-id']  # copy, so the caller's list isn't mutated
params = {'tags': tags, 'limit': request_limit, }
_LOG.debug("Finding posts tagged %r", tags)
while True:
_LOG.debug('Fetching posts %d to %d of %s', post_count,
post_count + request_limit - 1, total_count or
'unknown')
request = self._api_request(self.LIST_ENDPOINT, params)
response = self._make_request(request)
xmltree = ElementTree.parse(response)
response.close()
_count = xmltree.getroot().get('count')
if _count:
if not total_count:
total_count = int(_count)
_LOG.info("Found %d posts", total_count)
page_posts = xmltree.findall('post')
self._stats.increment('post_seen', len(page_posts))
_LOG.debug("Got %d posts", len(page_posts))
for post in page_posts:
yield post
postid = int(post.find('id').text)
if before_id is None or postid < before_id:
before_id = postid
if len(page_posts) < request_limit:
_LOG.debug('No more posts')
break
post_count += len(page_posts)
params['before_id'] = before_id
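    # Pagination sketch (illustrative request sequence): the first listing is
    #     /post/index.xml?tags=TAG+order%3A-id&limit=100
    # and every later one adds the smallest post id seen so far, e.g.
    #     /post/index.xml?tags=TAG+order%3A-id&limit=100&before_id=123456
    # walking backwards through ids until a short page signals the end.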
def download_post(self, post):
"""
Fetch the image data for a post.
:param Element post:
:return: image data
:rtype: bytes
"""
self._stats.increment('post_download_start')
postid = post.find('id').text
posturl = post.find('file_url').text
        request = Request(posturl, headers=self.REQUEST_HEADERS)
        response = self._make_request(request)
headers = response.info()
content_length = headers.get('content-length', '0')
content_length = int(content_length) / 1024.0
_LOG.info("Downloading post {id} ({kbytes:.2f}kB)".format(
id=postid,
kbytes=content_length))
imgdata = response.read()
response.close()
self._stats.increment('post_download_success')
return imgdata
def write_post(self, post, imgdata, destination):
"""
Write post image to disk, and ensure post tags are in tags.csv.
:param Element post:
:param bytes imgdata:
:param str destination: Path to folder in which to write
"""
postid = post.find('id').text
postext = post.find('file_ext').text
tags = filter(None, [
escape_tag(tag) for tag in post.find('tags').text.split(' ')
])
fname = self.FILE_PATTERN.format(postid, '-'.join(tags)[:190], postext)
path = os.path.join(destination, fname)
_LOG.debug("Writing post {id} to disk".format(id=postid))
with open(path, 'wb') as imgfile:
wrote = imgfile.write(imgdata)
self._stats.increment('post_write')
return wrote
    def is_tag_written(self, postid):
        """
        Check if the post is already in the tag file.
        :param str postid:
        :return: True if the post is already in the tag file
        :rtype: bool
        """
        return postid in self._written_tags
    @staticmethod
    def _tags_on_disk(destination):
        tagged_posts = set()
        tag_path = os.path.join(destination, 'tags.csv')
        try:
            # Read with the same encoding update_tag_store writes with.
            with open(tag_path, 'r', encoding='utf-8') as tag_file:
                for row in csv.reader(tag_file):
                    postid = str(row[0])
                    tagged_posts.add(postid)
        except FileNotFoundError:
            pass  # first run: no tag file yet
        _LOG.debug("%d posts in tag file", len(tagged_posts))
        return tagged_posts
def update_tag_store(self, post, destination, tag_file_lock=None):
"""
Add post to the tag file.
:param Element post:
:param str destination:
:param tag_file_lock: Lock that gates access to the tag file
:type tag_file_lock: Lock or None
"""
postid = post.find('id').text
path = os.path.join(destination, 'tags.csv')
if self.is_tag_written(postid):
_LOG.debug("Post %s already in tag store", postid)
return
        # findtext() copes with a missing or empty element, unlike
        # .find().text; fall back to the plural 'sources', then "".
        source = post.findtext('source') or post.findtext('sources') or ""
tags = post.find('tags').text.split(' ')
_LOG.debug("Writing tags for {id!r} ({tags}) to tags.csv".format(
id=postid,
tags=', '.join(tags), ))
if tag_file_lock:
tag_file_lock.acquire()
try:
with open(path, 'a', encoding='utf-8') as tagfile:
writer = csv.writer(tagfile, quoting=csv.QUOTE_MINIMAL)
row = [postid, source] + tags
writer.writerow(row)
self._written_tags.add(postid)
finally:
self._stats.increment('tag_file_write')
if tag_file_lock:
tag_file_lock.release()
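    # A tags.csv row (illustrative values) ends up looking like:
    #     123456,https://example.com/original.png,tag_one,tag_two,tag_three
    # i.e. post id, source URL (possibly empty), then one column per tag.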
@staticmethod
def _posts_on_disk(destination):
posts = set()
for name in os.listdir(destination):
postid = name.split('-', 1)[0]
            if postid.isdigit():
posts.add(postid)
_LOG.debug("Found %d posts on disk", len(posts))
return posts
def is_post_written(self, postid):
"""
:return: True if the post is already on disk
:rtype: bool
"""
return postid in self._written_posts
def fetch_post(self, post, destination, tag_file_lock=None):
"""
Download post, and write its tags to tag file.
:param Element post:
:param str destination:
:param tag_file_lock: Lock that gates access to the tag file
:type tag_file_lock: Lock or None
"""
try:
postid = post.find('id').text
poststatus = post.find('status').text
if poststatus == 'deleted':
                _LOG.warning("Post %s was deleted", postid)
self._stats.increment('post_deleted')
return
if self.is_post_written(postid):
_LOG.debug('Post %s already on disk', postid)
self.update_tag_store(post, destination, tag_file_lock)
self._stats.increment('post_cached')
self._stats.increment('post_success')
return
imgdata = self.download_post(post)
disk_size = self.write_post(post, imgdata, destination)
self.update_tag_store(post, destination, tag_file_lock)
_LOG.info("Downloaded post %s (%.2fkB, tagged: %s)", postid,
disk_size / 1024.0,
', '.join(post.find('tags').text.split(' ')))
self._stats.increment('post_success')
except Exception:
_LOG.exception("Failed to fetch post %s", postid)
self._stats.increment('post_failed')
raise
def _print_stats(self):
stats = self._stats
print("{success}/{total} posts synced".format(
success=stats['post_success'],
total=stats['post_seen']),
end='')
details = []
if stats['post_download_success']:
details.append("{new} new".format(new=stats[
'post_download_success']))
if stats['post_cached']:
details.append("{cached} already on disk".format(cached=stats[
'post_cached']))
if stats['post_failed']:
details.append("{failed} failed".format(failed=stats[
'post_failed']))
if details:
print(" ({details})".format(details=', '.join(details)), end='')
print(" in {:.2f}s.".format(time.perf_counter() - self._start_time))
def mirror(self, tags, destination):
"""
Find posts matching tags, and download them to destination.
:param tags:
:type tags: str or list[str]
:param str destination:
"""
self._start_time = time.perf_counter()
_LOG.info("Downloading posts tagged %r to %s", tags, destination)
posts = self.tagged_posts(tags)
os.makedirs(destination, mode=0o700, exist_ok=True)
self._written_posts = self._posts_on_disk(destination)
self._written_tags = self._tags_on_disk(destination)
if self.jobs:
try:
                pool = Pool(self.jobs)
                tag_file_lock = Lock()
fetch_fn = partial(self.fetch_post,
destination=destination,
tag_file_lock=tag_file_lock)
results = pool.imap_unordered(fetch_fn, posts)
pool.close()
                for _ in results:  # om nom nom: consume the iterator so
                    pass           # worker exceptions propagate here
pool.join()
            except Exception:
pool.terminate()
pool.join()
raise
finally:
_LOG.debug("Stats: %r", self._stats)
self._print_stats()
else:
try:
for post in posts:
self.fetch_post(post, destination)
finally:
_LOG.debug("Stats: %r", self._stats)
self._print_stats()
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Download files by tag from e621')
parser.add_argument('destination', help='Directory to store the files in')
parser.add_argument('tags',
help='Tags to look for. Try "fav:yourname"',
nargs='+')
parser.add_argument('--jobs',
'-j',
help='Downloads to run in parallel',
type=int,
default=10)
parser.add_argument('--verbose', '-v', action='store_true')
parser.add_argument('--quiet', '-q', action='store_true')
args = parser.parse_args()
tags = args.tags
jobs = args.jobs
if jobs == 0:
jobs = None
destination = args.destination
log_format = '%(asctime)s %(levelname)s: %(message)s'
debug_log_format = (
'%(asctime)s %(threadName)s(%(filename)s:%(lineno)d in %(funcName)s): '
'[%(levelname)s] %(message)s')
if args.quiet:
logging.basicConfig(level=logging.WARNING, format=log_format)
elif args.verbose:
logging.basicConfig(level=logging.DEBUG, format=debug_log_format)
else:
logging.basicConfig(level=logging.INFO, format=log_format)
e6collector = E6Collector(jobs=jobs)
e6collector.mirror(tags, destination)