Skip to content

Instantly share code, notes, and snippets.

/e6collector.py Secret

Created July 2, 2017 23:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anonymous/8956030a367323d673943868bba3c076 to your computer and use it in GitHub Desktop.
Save anonymous/8956030a367323d673943868bba3c076 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
e6collector.
Python 3.x script that will download all images with a specific
tag, save them named ID-SOME-TAGS.EXTENSION and makes sure it doesn't
download the same image multiple times because the tags changed. Also
all tags will be written to tags.csv in the same folder. Using the
Ouroboros API described on https://e621.net/help/api
PROTIP: To view images with a certain tag try this on UNIX-like systems:
gthumb `for f in $(grep -i TAG tags.csv | cut -f 1 -d ","); do echo $f-*; done`
THE AUTHOR DOES NOT TAKE ANY RESPONSIBILITIES, THERE IS NO WARRANTY AND
YOU PROBABLY SHOULDN'T USE THIS IN A NUCLEAR POWER PLANT, JUST SAYING!
License: Public domain, do whatever.
Version 1.0 -- Initial release, if you can call it that
Version 1.0.1 -- Fixed Unicode problem on Windows
Version 1.0.2 -- Fixed API not working
Version 2.0.0 -- Rewritten. Better logging, parallel downloads.
Version 2.0.1 -- Use `before_id` instead of `page` when listing posts
"""
import argparse
import csv
import glob
import logging
import os.path
import re
from functools import partial
from multiprocessing import Lock
from multiprocessing.dummy import Pool
from urllib.parse import urlencode
from urllib.request import Request, urlopen
from xml.etree import ElementTree
# Module-wide logger for progress and diagnostic output.
_LOG = logging.getLogger('e6collector')

# Anything outside lowercase alphanumerics, underscore and hyphen is
# unsafe for the filename-bound tag list.
UNSAFE_TAG_CHARS = re.compile(r'[^0-9a-z_-]+')


def escape_tag(tag):
    """Return *tag* with every unsafe character stripped out."""
    cleaned = UNSAFE_TAG_CHARS.sub('', tag)
    return cleaned
class E6Collector:
    """Download tagged posts from e621 and record their tags in tags.csv.

    Uses the Ouroboros XML API (https://e621.net/help/api).  Images are
    saved as ID-SOME-TAGS.EXTENSION in the destination folder; tags.csv
    in the same folder maps every post id to its source URL and tags.
    """

    # Filename template: post id, joined tags, file extension.
    FILE_PATTERN = '{0:s}-{1:s}.{2:s}'
    # API endpoint listing posts as XML.
    LIST_ENDPOINT = '/post/index.xml'
    REQUEST_HEADERS = {'User-Agent': 'e6collector/2.0'}
    URI_PREFIX = 'https://e621.net'

    def __init__(self, jobs=None):
        """
        :param jobs: Number of download threads; None (or 0) runs serially
        :type jobs: int or None
        """
        self.jobs = jobs

    def _api_request(self, endpoint, query_params=None):
        """Build a Request for *endpoint* with optional query parameters.

        Bug fix: the original ignored *endpoint* and always queried
        ``LIST_ENDPOINT``.

        :param str endpoint: API path, e.g. '/post/index.xml'
        :param query_params: Query parameters.  A list value under 'tags'
            is flattened to the API's space-separated form (the dict is
            updated in place, matching the original behaviour so paging
            keeps the flattened value).
        :type query_params: dict or None
        :rtype: urllib.request.Request
        """
        uri = self.URI_PREFIX + endpoint
        if query_params:
            tags = query_params.get('tags')
            if isinstance(tags, list):
                query_params['tags'] = self._flatten_tags_param(tags)
            uri += "?" + urlencode(query_params)
        return Request(uri, headers=self.REQUEST_HEADERS)

    def _flatten_tags_param(self, tags):
        """Join a list of tags into the API's space-separated string."""
        return ' '.join(tags)

    def tagged_posts(self, tags):
        """
        Fetch posts matching tags, paging backwards with ``before_id``.

        :param tags: Space-separated string or list of tags
        :return: Generator of post XML elements
        :rtype: Generator[Element]
        """
        post_count = 0
        request_limit = 100
        total_count = None
        before_id = None
        params = {'tags': tags, 'limit': request_limit}
        _LOG.debug("Finding posts tagged %r", tags)
        while True:
            # TODO: handle HTTP errors
            _LOG.debug('Fetching posts %d to %d of %s', post_count,
                       post_count + request_limit - 1,
                       total_count or 'unknown')
            with urlopen(self._api_request(self.LIST_ENDPOINT,
                                           params)) as request:
                xmltree = ElementTree.parse(request)
            total_count = int(xmltree.getroot().get('count'))
            page_posts = xmltree.findall('post')
            _LOG.debug("Got %d posts", len(page_posts))
            for post in page_posts:
                yield post
                # Track the lowest id seen; the next page requests
                # everything older than it.
                postid = int(post.find('id').text)
                if before_id is None or postid < before_id:
                    before_id = postid
            if len(page_posts) < request_limit:
                _LOG.debug('No more posts')
                break
            post_count += len(page_posts)
            params['before_id'] = before_id

    def download_post(self, post):
        """
        Fetch the image data for a post.

        :param Element post:
        :return: image data
        :rtype: bytes
        """
        postid = post.find('id').text
        posturl = post.find('file_url').text
        with urlopen(Request(posturl,
                             headers=self.REQUEST_HEADERS)) as request:
            headers = request.info()
            content_length = int(headers.get('content-length', '0')) / 1024.0
            _LOG.info("Downloading post {id} ({kbytes:.2f}kB)".format(
                id=postid,
                kbytes=content_length))
            imgdata = request.read()
        return imgdata

    def write_post(self, post, imgdata, destination):
        """
        Write post image to disk, and ensure post tags are in tags.csv.

        :param Element post:
        :param bytes imgdata:
        :param str destination: Path to folder in which to write
        :return: Number of bytes written
        :rtype: int
        """
        postid = post.find('id').text
        postext = post.find('file_ext').text
        tags = filter(None, [
            escape_tag(tag) for tag in post.find('tags').text.split(' ')
        ])
        # Cap the tag part at 190 chars to stay under filesystem
        # filename-length limits.
        fname = self.FILE_PATTERN.format(postid, '-'.join(tags)[:190], postext)
        path = os.path.join(destination, fname)
        _LOG.debug("Writing post {id} to disk".format(id=postid))
        with open(path, 'wb') as imgfile:
            return imgfile.write(imgdata)

    def is_tag_written(self, postid, tag_path, tag_file_lock=None):
        """
        Check if the post is already in the tag file.

        :param str postid:
        :param str tag_path:
        :param tag_file_lock: Lock that gates access to the tag file
        :type tag_file_lock: Lock or None
        :rtype: bool
        """
        if not os.path.exists(tag_path):
            return False
        if tag_file_lock:
            tag_file_lock.acquire()
        try:
            # Fix: read with the same utf-8 encoding the writer uses, and
            # newline='' as the csv module documents; skip blank rows so
            # row[0] cannot raise IndexError.
            with open(tag_path, 'r', encoding='utf-8',
                      newline='') as tag_file:
                for row in csv.reader(tag_file):
                    if row and str(row[0]) == postid:
                        return True
            return False
        finally:
            if tag_file_lock:
                tag_file_lock.release()

    def update_tag_store(self, post, destination, tag_file_lock=None):
        """
        Add post to the tag file.

        :param Element post:
        :param str destination:
        :param tag_file_lock: Lock that gates access to the tag file
        :type tag_file_lock: Lock or None
        """
        postid = post.find('id').text
        path = os.path.join(destination, 'tags.csv')
        if self.is_tag_written(postid, path, tag_file_lock):
            _LOG.debug("Post %s already in tag store", postid)
            return
        # Older API responses use <source>, newer ones <sources>;
        # find() returns None for a missing element, hence AttributeError.
        try:
            source = post.find('source').text
        except AttributeError:
            try:
                source = post.find('sources').text
            except AttributeError:
                source = ""
        tags = post.find('tags').text.split(' ')
        _LOG.debug("Writing tags for {id!r} ({tags}) to tags.csv".format(
            id=postid,
            tags=', '.join(tags)))
        try:
            if tag_file_lock:
                tag_file_lock.acquire()
            # Fix: newline='' per csv module docs (avoids doubled CR on
            # Windows).
            with open(path, 'a', encoding='utf-8', newline='') as tagfile:
                writer = csv.writer(tagfile, quoting=csv.QUOTE_MINIMAL)
                row = [postid, source]
                row.extend(tags)
                writer.writerow(row)
        finally:
            if tag_file_lock:
                tag_file_lock.release()

    def fetch_post(self, post, destination, tag_file_lock=None):
        """
        Download post, and write its tags to tag file.

        :param Element post:
        :param str destination:
        :param tag_file_lock: Lock that gates access to the tag file
        :type tag_file_lock: Lock or None
        """
        postid = post.find('id').text
        poststatus = post.find('status').text
        if poststatus == 'deleted':
            _LOG.warning(postid + ' was deleted')
            return
        matching_files = glob.glob(os.path.join(destination, postid + '-*'))
        if matching_files:
            _LOG.debug('Post %s already on disk: %r', postid, matching_files)
            self.update_tag_store(post, destination, tag_file_lock)
            return
        try:
            imgdata = self.download_post(post)
            disk_size = self.write_post(post, imgdata, destination)
            self.update_tag_store(post, destination, tag_file_lock)
        except Exception:
            # Fix: was a bare except, which also swallowed
            # KeyboardInterrupt/SystemExit before re-raising.
            _LOG.exception("Failed to fetch post %s", postid)
            raise
        _LOG.info("Downloaded post %s (%.2fkB, tagged: %s)", postid,
                  disk_size / 1024.0,
                  ', '.join(post.find('tags').text.split(' ')))

    def mirror(self, tags, destination):
        """
        Find posts matching tags, and download them to destination.

        :param tags:
        :type tags: str or list[str]
        :param str destination:
        """
        _LOG.info("Downloading posts tagged %r to %s", tags, destination)
        posts = self.tagged_posts(tags)
        os.makedirs(destination, mode=0o700, exist_ok=True)
        if not self.jobs:
            for post in posts:
                self.fetch_post(post, destination)
            return
        # Fix: create the pool before the try so except/finally can
        # always reference it (previously Pool() raising caused a
        # NameError in the cleanup clauses); dropped the unused _except
        # callback.  Per-post failures are logged inside fetch_post and
        # otherwise ignored -- this is a best-effort mirror.
        pool = Pool(self.jobs)
        tag_file_lock = Lock()
        try:
            pool.imap_unordered(
                partial(self.fetch_post,
                        destination=destination,
                        tag_file_lock=tag_file_lock),
                posts)
            pool.close()
        except BaseException:
            pool.terminate()
            raise
        finally:
            pool.join()
if __name__ == '__main__':
    # Command-line entry point: parse arguments, configure logging,
    # then mirror the requested tags into the destination folder.
    parser = argparse.ArgumentParser(
        description='Download files by tag from e621')
    parser.add_argument('destination', help='Directory to store the files in')
    parser.add_argument('tags',
                        nargs='+',
                        help='Tags to look for. Try "fav:yourname"')
    parser.add_argument('--jobs', '-j',
                        type=int,
                        default=10,
                        help='Downloads to run in parallel')
    parser.add_argument('--verbose', '-v', action='store_true')
    args = parser.parse_args()
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO)
    # --jobs 0 means "no thread pool": run downloads serially.
    worker_count = args.jobs if args.jobs != 0 else None
    E6Collector(jobs=worker_count).mirror(args.tags, args.destination)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment