# Gist by @vvern999 (last active November 7, 2023)
# https://gist.github.com/vvern999/8a02e54d45549b4786e61f0f5dbca924
#!/usr/bin/env python
# A Swiss-Army-Knife sort of script for dealing with Stable Diffusion caption
# files in .txt format.
# from https://github.com/space-nuko/sd-webui-utilities/tree/master
import argparse
#import collections
import os
import re
#import dotenv
import tqdm
#import PIL
from PIL import Image,ImageFile,ImageOps
#from pprint import pp
#import prompt_parser
import glob
#import sys
#import safetensors
import json
#import mmap
#import pprint
import os.path
from sacremoses import MosesPunctNormalizer
import unicodedata
import re
import shutil
import requests
import pandas
import warnings
import numpy as np
import cv2
import onnxruntime as rt
from huggingface_hub import hf_hub_download
import clip
from typing import Iterator, Optional
import random
import math
import imgutils # pip install dghs-imgutils
import imgutils.restore
import scipy.signal
parser = argparse.ArgumentParser()
parser.add_argument('--recursive', '-r', action="store_true", help='Edit caption files recursively')
parser.add_argument('--escape', action="store_true", help='escape slashes')
subparsers = parser.add_subparsers(dest="command", help='sub-command help')
parser_fixup_tags = subparsers.add_parser('fixup', help='Fixup caption files, converting from gallery-dl format and normalizing UTF-8')
parser_fixup_tags.add_argument('path', type=str, help='Path to caption files')
parser_fixup_tags.add_argument('--double', '-d', action="store_true", help='Copy first tag')
parser_fixup_tags.add_argument('--rare', action="store_true", help='Remove rare tags - artists')
parser_fixup_tags.add_argument('--rare_all', action="store_true", help='Remove rare tags - all')
parser_fixup_tags.add_argument('--shorten', action="store_true", help='Consolidate/shorten')
parser_fixup_tags.add_argument('--remove_artists', type=float, default=None, help='Remove artist tags with probability 0-1.0')
parser_fixup_tags.add_argument('--remove_quality', action="store_true", help='Remove quality tags/scores')
parser_fixup_tags.add_argument('--token_len', action="store_true", help='Calculate token length')
parser_fixup_tags.add_argument('--append_filename', action="store_true", help='append filename to tags')
parser_fixup_tags.add_argument('--remove_long', type=float, default=None, help='remove long tags')
parser_add_tags = subparsers.add_parser('add', help='Add tags to captions (delimited by commas)')
parser_add_tags.add_argument('path', type=str, help='Path to caption files')
parser_add_tags.add_argument('tags', type=str, nargs='+', help='Tags to add')
parser_remove_tags = subparsers.add_parser('remove', help='Remove tags from captions (delimited by commas)')
parser_remove_tags.add_argument('path', type=str, help='Path to caption files')
parser_remove_tags.add_argument('tags', type=str, nargs='+', help='Tags to remove')
parser.add_argument('--no-convert-tags', '-n', action="store_false", dest="convert_tags", help='Do not convert tags to NAI format')
parser_replace_tag = subparsers.add_parser('replace', help='Replace tags in captions (delimited by commas)')
parser_replace_tag.add_argument('path', type=str, help='Path to caption files')
parser_replace_tag.add_argument('to_find', type=str, help='Tag to find')
parser_replace_tag.add_argument('to_replace', type=str, help='Tag to replace with')
parser_move_tags_to_front = subparsers.add_parser('move_to_front', help='Move tags in captions (delimited by commas) to front of list')
parser_move_tags_to_front.add_argument('path', type=str, help='Path to caption files')
parser_move_tags_to_front.add_argument('tags', type=str, nargs='+', help='Tags to move')
parser_move_categories_to_front = subparsers.add_parser('move_categories_to_front', help='Move danbooru tag categories in captions (delimited by commas) to front of list')
parser_move_categories_to_front.add_argument('path', type=str, help='Path to caption files')
parser_move_categories_to_front.add_argument('categories', type=str, nargs='+', help='Categories to move')
parser_strip_tag_suffix = subparsers.add_parser('strip_suffix', help='Strips a suffix from a tag ("neptune_(neptune_series)" -> "neptune")')
parser_strip_tag_suffix.add_argument('path', type=str, help='Path to caption files')
parser_strip_tag_suffix.add_argument('suffix', type=str, help='Suffix to find')
parser_validate = subparsers.add_parser('validate', help='Validate a dataset')
parser_validate.add_argument('path', type=str, help='Path to root of dataset folder')
parser_remove_high_ar = subparsers.add_parser('remove_high_ar', help='Remove images with high aspect ratios, low resolutions, etc.')
parser_remove_high_ar.add_argument('path', type=str, help='Path to root of dataset folder')
#parser_remove_high_ar.add_argument('--folder-name', '-n', type=str, help='Name of subfolder')
parser_recalc_multiply_from_folders = subparsers.add_parser('recalc_multiply_from_folders', help='Recalc multipliers, remove extra images ')
parser_recalc_multiply_from_folders.add_argument('path', type=str, help='Path to root of dataset folder')
parser_recalc_multiply_from_folders.add_argument('--move_extra_to', type=str, help='Path to split folder')
parser_recalc_multiply_from_folders.add_argument('--max_images', type=int, default=0, help='max images to keep for move_extra')
parser_gen_multiply = subparsers.add_parser('gen_multiply', help='Sort into folder, generate multiply.txt ')
parser_gen_multiply.add_argument('path', type=str, help='Path to root of dataset folder')
parser_gen_stats = subparsers.add_parser('gen_stats', help='generate tag count')
parser_gen_stats.add_argument('path', type=str, help='Path to root of dataset folder')
parser_name_stats = subparsers.add_parser('name_stats', help='generate tag count for first tag only')
parser_name_stats.add_argument('path', type=str, help='Path to root of dataset folder')
parser_clean = subparsers.add_parser('clean', help='adverse clean')
parser_clean.add_argument('path', type=str, help='Path to root of dataset folder')
parser_remove_blacklisted = subparsers.add_parser('remove_blacklisted', help=' ')
parser_remove_blacklisted.add_argument('path', type=str, help='Path to root of dataset folder')
parser_remove_blacklisted.add_argument('--tags', type=str, default='', nargs='+', help='Tags to remove')
parser_stats = subparsers.add_parser('stats', help='Show dataset image counts/repeats')
parser_stats.add_argument('path', type=str, help='Path to caption files')
parser_organize = subparsers.add_parser('organize', help='Move images with specified tags (delimited by commas) into a subfolder')
parser_organize.add_argument('path', type=str, help='Path to caption files')
parser_organize.add_argument('tags', type=str, nargs='+', help='Tags to move')
parser_organize.add_argument('--folder-name', '-n', type=str, help='Name of subfolder')
parser_organize.add_argument('--split-rest', '-s', action="store_true", help='Move all non-matching images into another folder')
parser.add_argument('--copy', action="store_true", help='copy instead of move')
parser_remove_low_ae_score = subparsers.add_parser('remove_low_ae_score', help=' ')
parser_remove_low_ae_score.add_argument('path', type=str, help='Path to caption files')
parser_remove_low_ae_score.add_argument('--folder-name', '-n', type=str, help='Name of subfolder')
parser_remove_low_ae_score.add_argument('--thr', type=float, default=0.0, help='Score threshold')
parser_remove_low_ae_score.add_argument('--type', type=str, default="wd", help='wd, clip or aa')
args = parser.parse_args()
Image.MAX_IMAGE_PIXELS = 300000000
ImageFile.LOAD_TRUNCATED_IMAGES=True
gallery_dl_txt_re = re.compile(r'^(.*)\.[a-z]{3,4}\.txt')
IMAGE_EXTS = [".png", ".jpg", ".jpeg", ".gif", ".webp", ".avif"]
repeats_folder_re = re.compile(r'^(\d+)_(.*)$')
ignored_tags = ("absurdres","looking at viewer","solo","high-ae","low-ae","simple background","white background","blush","smile","sitting","standing","sfw",
"holding","indoors","outdoors","#",""," ",
"bibi (tokoyami towa)","piyoko (uruha rushia)","bloop (gawr gura)","friend (nanashi mumei)","daifuku (yukihana lamy)","35p (sakura miko)","mr. squeaks (hakos baelz)","kintoki (sakura miko)","takodachi (ninomae ina'nis)",
"haaton (akai haato)","elfriend (shiranui flare)","udin (kureiji ollie)","bubba (watson amelia)","ao-chan (ninomae ina'nis)","don-chan (usada pekora)","nousagi (usada pekora)","shiranui (nakiri ayame)","crow (la+ darknesss)",
"pokobee","pikl (elira pendora)","#","@")
BAD_TAGS_REMOVE_TAG = ("absurdres", "highres", "translation request", "translated", "commentary", "commentary request", "commentary typo", "character request", "bad id", "bad link", "bad pixiv id", "bad twitter id", "bad tumblr id",
"bad deviantart id", "bad nicoseiga id", "md5 mismatch", "cosplay request", "artist request", "wide image", "author request",
"hololive","nijisanji","vshojo","phase connect","third-party source","official alternate hair length","official alternate hairstyle","official alternate costume","alternate costume",
"alternate hairstyle","","adapted costume","adapted_costume","k",""," ","alternate breast size")
BAD_TAGS_REMOVE_IMAGE = ("comic", "panels", "everyone", "sample watermark", "text focus", "tagme", "webtoon", "bad multiple views", "bara", "scat","jpeg artifacts","daz studio","daz3d")
TAG_ALIASES = {"1girls":"1girl","v":"peace sign, v", "ahe gao":"ahegao"}
BAD_GENERAL = ["comic","webtoon","bad multiple views","bara","scat"] # old
BAD_NSFW = ["mosaic censoring","heavily censored","censored"]
BAD_NSFW_EXCLUDE = ["convenient censoring","pointless censoring","heart censor","hair censor"]
HUMAN_TAGS = ["1girl","2girls","3girls","4girls","5girls","6+girls","1boy","2boys","3boys","4boys","5boys","6+boys"]
quality_tags = ("low quality","best quality","masterpiece")
MULTIPLY_LOW_THR = 8
MULTIPLY_MAX_MULTIPLIER = 3
MULTIPLY_LOW_TARGET = 100.0
MULTIPLY_HIGH_THR = 800.0
FILTER_TAG = "virtual youtuber"
BELOW_THR_MULTIPLIER = 1.0
MULTIPLY_MIN_MULTIPLIER = 1.0
RARE_TAG_THR = 2
RARE_ARTIST_TAG_THR = 10
MIN_RES = 576
MIN_PIXELS = 350000
def convert_tag(t):
#if "costume" in t:
# return t.replace(" ", "_").replace(" ", " ").replace("\\", "").replace("(", "").replace(")", "")
if "costume" in t:
return t.replace("_", " ").replace(" ", " ").replace("\\", "").replace("(", "").replace(")", "")
if len(t) <=3:
return t.replace(" ", "_").replace(" ", " ").replace("\\ ", "\\").replace("\\", "")
if args.escape:
if "\\" not in t:
if " " not in t:
return t.replace("_", " ").replace(" ", " ").replace("\\ ", "\\").replace("(", "\(").replace(")", "\)").replace("1girls", "1girl")
return t.replace("(", "\(").replace(")", "\)").replace(" ", " ").replace("1girls", "1girl")
else:
return t.replace(" ", " ").replace("\\ ", "\\").replace("1girls", "1girl")
# return t.replace("_", " ").replace("(", "\(").replace(")", "\)").replace(" ", " ")
#if "\\" not in t:
# return t.replace("(", "\(").replace(")", "\)").replace(" ", " ")
#else:
if " " not in t:
return t.replace("_", " ").replace(" ", " ").replace("\\ ", "\\").replace("\\", "").replace("1girls", "1girl").replace("__", "_")
else:
return t.replace(" ", " ").replace("\ ", "\\").replace("\\", "").replace("1girls", "1girl").replace("__", "_")
def get_caption_file_image(txt):
basename = os.path.splitext(txt)[0]
for ext in IMAGE_EXTS:
path = basename + ext
if os.path.isfile(path):
return path
return None
def get_caption_file_images(txt):
basename = os.path.splitext(txt)[0]
images = []
for ext in IMAGE_EXTS:
path = basename + ext
if os.path.isfile(path):
images.append(path)
return images
def fixup(args):
# rename gallery-dl-format .txt files (<...>.png.txt, etc.)
renamed = 0
total = 0
doubled = 0
long_tags = 0
for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
m = gallery_dl_txt_re.match(txt)
if m:
basename = m.groups(1)[0]
#print("RENAME: " + basename + ".txt")
shutil.move(txt, basename + ".txt")
renamed += 1
total += 1
print(f"Renamed {renamed}/{total} caption files.")
mpn = MosesPunctNormalizer()
if args.token_len:
from transformers import CLIPTokenizer
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
num_edited = 0
token_lengths = {}
print("moving costume tags")
# join newlines, deduplicate tags, fix unicode chars, remove spaces and escape parens
for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
if get_caption_file_image(txt):
with open(txt, "r", encoding="utf-8") as f:
s = f.read()
s = unicodedata.normalize("NFKC", mpn.normalize(s))
s = ", ".join(s.split("\n"))
these_tags = [convert_tag(t.strip().lower()) for t in s.split(",")]
#these_tags.reverse()
these_tags = move_costume_to_front(these_tags, 0)
#if contains_humans(these_tags):
# these_tags.append("human")
# these_tags.append("not furry")
if these_tags[-1].isdigit():
these_tags[-1] = get_quality_tag(int(these_tags[-1]), ("nsfw" in these_tags))
fixed_tags = []
for i in these_tags:
if not (args.remove_long is not None and len(i) > args.remove_long):
#if i.startswith("art by "):
# i = i.replace("art by ","by ")
if i.startswith("by:"):
i = i.replace("by:","by ")
if (i not in fixed_tags) and (i not in BAD_TAGS_REMOVE_TAG) and (not i.startswith("pool:")):
if not (args.remove_artists and i.startswith("by ") and random.random() <= args.remove_artists):
if not (args.remove_quality and i in quality_tags):
if i in TAG_ALIASES.keys():
fixed_tags.append(TAG_ALIASES[i])
else:
fixed_tags.append(i)
if args.shorten:
for short_tag in these_tags:
for long_tag in these_tags:
if long_tag.endswith(" " + short_tag):
try:
index = fixed_tags.index(short_tag)
fixed_tags.pop(index)
num_edited += 1
except:
continue
#print(f"tags: {these_tags}")
#print(f"fixed: {fixed_tags}")
if args.double and (FILTER_TAG in these_tags):
if these_tags[0] != 'nsfw' and these_tags[0] != '1girl' and these_tags[0] not in ignored_tags:
fixed_tags.append(these_tags[0])
doubled += 1
if args.token_len:
input_ids = tokenizer(", ".join(fixed_tags), padding="max_length", truncation=True, max_length=750, return_tensors="pt").input_ids
iids = [i for i in input_ids[0] if i != 49407]
chunk_len = math.ceil(len(iids) / 75)
if chunk_len not in token_lengths:
token_lengths[chunk_len] = 1
else:
token_lengths[chunk_len] += 1
if args.append_filename:
text = os.path.splitext(os.path.basename(txt))[0]
text.replace("1920x1080","")
text = re.sub('\d', '', text)
text = text.replace("-", " ").replace("_", " ").replace(" ", " ").replace("__", "_")
fixed_tags.append(text.strip())
with open(txt, "w", encoding="utf-8") as f:
f.write(", ".join(fixed_tags))
print(f"Shortened {num_edited} tags. Doubled {doubled} tags")
if args.rare or args.rare_all:
remove_rare_tags()
if args.token_len:
token_lengths = dict(sorted(token_lengths.items()))
print("Token lengths:")
for k in token_lengths:
print(f"<{k * 75}: {token_lengths[k]}")
def get_quality_tag(score, is_nsfw):
if is_nsfw:
score = score / 2
if score >= 200:
return "masterpiece"
if score >= 100:
return "best quality"
if score < 5:
return "low quality"
return ""
def add(args):
tags = [convert_tag(t) for t in args.tags]
modified = 0
total = 0
for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
found = False
if get_caption_file_image(txt):
with open(txt, "r", encoding="utf-8") as f:
these_tags = [t.strip().lower() for t in f.read().split(",")]
for to_add in tags:
if to_add not in these_tags:
found = True
these_tags.append(to_add)
with open(txt, "w", encoding="utf-8") as f:
f.write(", ".join(these_tags))
if found:
modified += 1
total += 1
print(f"Updated {modified}/{total} caption files.")
def remove(args):
tags = args.tags
if args.convert_tags:
tags = [convert_tag(t) for t in args.tags]
modified = 0
total = 0
for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
found = False
if get_caption_file_image(txt):
with open(txt, "r", encoding="utf-8") as f:
these_tags = [t.strip().lower() for t in f.read().split(",")]
for to_find in tags:
if to_find in these_tags:
found = True
index = these_tags.index(to_find)
these_tags.pop(index)
#if these_tags[-1] == tags[0]:
# found = True
# index = these_tags.index(tags[0])
# these_tags.pop(index)
with open(txt, "w", encoding="utf-8") as f:
f.write(", ".join(these_tags))
if found:
modified += 1
total += 1
print(f"Updated {modified}/{total} caption files.")
def replace(args):
to_find = convert_tag(args.to_find)
to_replace = convert_tag(args.to_replace)
modified = 0
total = 0
for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
if get_caption_file_image(txt):
with open(txt, "r", encoding="utf-8") as f:
these_tags = [t.strip().lower() for t in f.read().split(",")]
if to_find in these_tags:
assert to_replace not in these_tags
index = these_tags.index(to_find)
these_tags.pop(index)
these_tags.insert(index, to_replace)
with open(txt, "w", encoding="utf-8") as f:
f.write(", ".join(these_tags))
modified += 1
total += 1
print(f"Updated {modified}/{total} caption files.")
def strip_suffix(args):
suffix = convert_tag(args.suffix)
modified = 0
total = 0
stripped_tags = 0
for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
found = False
if get_caption_file_image(txt):
with open(txt, "r", encoding="utf-8") as f:
these_tags = [t.strip().lower() for t in f.read().split(",")]
new_tags = []
for t in these_tags:
if t.endswith(suffix):
found = True
t = t.removesuffix(suffix).strip()
stripped_tags += 1
new_tags.append(t)
with open(txt, "w", encoding="utf-8") as f:
f.write(", ".join(new_tags))
if found:
modified += 1
total += 1
print(f"Updated {modified}/{total} caption files, {stripped_tags} tags stripped.")
def move_to_front(args):
tags = list(reversed([convert_tag(t) for t in args.tags]))
modified = 0
total = 0
for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
found = False
if get_caption_file_image(txt):
with open(txt, "r", encoding="utf-8") as f:
these_tags = [t.strip().lower() for t in f.read().split(",")]
for t in tags:
if t in these_tags:
found = True
these_tags.insert(0, these_tags.pop(these_tags.index(t)))
with open(txt, "w", encoding="utf-8") as f:
f.write(", ".join(these_tags))
if found:
modified += 1
total += 1
print(f"Updated {modified}/{total} caption files.")
CATEGORIES = {
"general": 0,
"artist": 1,
"copyright": 3,
"character": 4,
"meta": 5,
}
def to_danbooru_tag(t):
return t.strip().lower().replace(" ", "_").replace("\(", "(").replace("\)", ")")
def move_categories_to_front(args):
if not os.path.isfile("danbooru.csv"):
print("Downloading danbooru.csv tags list...")
url = "https://github.com/arenatemp/sd-tagging-helper/raw/master/danbooru.csv"
response = requests.get(url, stream=True)
with open("danbooru.csv", "wb") as handle:
for data in tqdm.tqdm(response.iter_content()):
handle.write(data)
print("Loading danbooru.csv tags list...")
danbooru_tags = pandas.read_csv("danbooru.csv", engine="pyarrow").set_axis(["tag", "tag_category"], axis=1).set_index("tag").to_dict("index")
order = [CATEGORIES[cat] for cat in reversed(args.categories)]
print("Done.")
modified = 0
total = 0
for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
found = False
if get_caption_file_image(txt):
with open(txt, "r", encoding="utf-8") as f:
these_tags = list(to_danbooru_tag(t) for t in f.read().split(","))
for tag_category in order:
for t in these_tags:
this_category = danbooru_tags.get(t)
if this_category is not None:
this_category = this_category["tag_category"]
if tag_category == this_category:
found = True
these_tags.insert(0, these_tags.pop(these_tags.index(t)))
these_tags = [convert_tag(t) for t in these_tags]
with open(txt, "w", encoding="utf-8") as f:
f.write(", ".join(these_tags))
if found:
modified += 1
total += 1
print(f"Updated {modified}/{total} caption files.")
def do_move(txt, img, outpath, args=None):
os.makedirs(outpath, exist_ok=True)
basename = os.path.splitext(os.path.basename(txt))[0]
img_ext = os.path.splitext(img)[1]
out_txt = os.path.join(outpath, basename + ".txt")
out_img = os.path.join(outpath, basename + img_ext)
# print(f"{img} -> {out_img}")
try:
if args.copy:
shutil.copy2(img, out_img)
shutil.copy2(txt, out_txt)
else:
shutil.move(img, out_img)
shutil.move(txt, out_txt)
except Exception as ex:
print(f"Exception: {ex}")
def organize(args):
tags = [convert_tag(t) for t in args.tags]
folder_name = args.folder_name or " ".join(args.tags)
outpath = os.path.join(args.path, folder_name.replace(":","_"))
# if os.path.exists(outpath):
# print(f"Error: Folder already exists - {outpath}")
# return 1
if args.split_rest:
split_path = os.path.join(args.path, "(rest)")
# if os.path.exists(split_path):
# print(f"Error: Folder already exists - {split_path}")
# return 1
modified = 0
total = 0
for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
img = get_caption_file_image(txt)
if img:
with open(txt, "r", encoding="utf-8") as f:
these_tags = {t.strip().lower(): True for t in f.read().split(",")}
if any(t in these_tags for t in tags):
do_move(txt, img, outpath, args)
modified += 1
elif args.split_rest:
do_move(txt, img, split_path, args)
modified += 1
total += 1
print(f"Moved {modified}/{total} images and caption files to {outpath}.")
def try_split(str):
tags = []
#s = ''
if str.partition(", @")[1] != '':
s = str.partition(", @")[0].strip()
else:
#print(f"Failed to split: {str}")
s = str.partition(", ")[0].strip()
# s = s + ", "
tags = [t.strip().lower() for t in s.split(",") if t.strip().lower() not in ignored_tags]
#print(tags)
if len(tags) >= 1:
return tags[0]
else:
return None
def remove_rare_tags():
tags_dict = {}
rare_removed = 0
for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
img = get_caption_file_image(txt)
if img:
with open(txt, "r", encoding="utf-8") as f:
these_tags = [t.strip().lower() for t in f.read().split(",")]
for tag in these_tags:
if tag not in tags_dict:
tags_dict[tag] = 1
else:
tags_dict[tag] += 1
if args.rare_all:
tags_to_remove = [tag for tag in tags_dict if tags_dict[tag] < RARE_TAG_THR]
else:
tags_to_remove = [tag for tag in tags_dict if (tags_dict[tag] < RARE_ARTIST_TAG_THR and tag.startswith("by "))]
print("rare tags:", tags_to_remove)
for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
found = False
if get_caption_file_image(txt):
with open(txt, "r", encoding="utf-8") as f:
these_tags = [t.strip().lower() for t in f.read().split(",")]
for to_find in tags_to_remove:
if to_find in these_tags:
found = True
index = these_tags.index(to_find)
these_tags.pop(index)
with open(txt, "w", encoding="utf-8") as f:
f.write(", ".join(these_tags))
if found:
rare_removed += 1
print(f"Removed {rare_removed} rare tags.")
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']
CAPTION_EXTENSIONS = ['.txt', '.caption', '.yaml', '.yml']
def gather_captioned_images(root_dir: str) -> list[tuple[str,Optional[str]]]:
for directory, _, filenames in os.walk(root_dir):
image_filenames = [f for f in filenames if os.path.splitext(f)[1].lower() in IMAGE_EXTENSIONS]
for image_filename in image_filenames:
image_path = os.path.join(directory, image_filename)
image_path_without_extension = os.path.splitext(image_path)[0]
caption_path = None
for caption_extension in CAPTION_EXTENSIONS:
possible_caption_path = image_path_without_extension + caption_extension
if os.path.exists(possible_caption_path):
caption_path = possible_caption_path
break
yield image_path, caption_path
def move_captioned_image(image_caption_pair: tuple[str, Optional[str]], source_root: str, target_root: str):
image_path = image_caption_pair[0]
caption_path = image_caption_pair[1]
# make target folder if necessary
relative_folder = os.path.dirname(os.path.relpath(image_path, source_root))
target_folder = os.path.join(target_root, relative_folder)
os.makedirs(target_folder, exist_ok=True)
# move files
shutil.move(image_path, os.path.join(target_folder, os.path.basename(image_path)))
if caption_path is not None:
shutil.move(caption_path, os.path.join(target_folder, os.path.basename(caption_path)))
def move_extra(args):
max_images = MULTIPLY_HIGH_THR
if args.max_images > 0:
max_images = args.max_images
for dirname in tqdm.tqdm(os.listdir(args.path)):
path = os.path.join(args.path, dirname)
if os.path.isdir(path):
images = list(gather_captioned_images(path))
print(f"Found {len(images)} captioned images in {path}")
# split_count = math.ceil(len(images) * args.split_proportion)
split_count = 0
if len(images) > max_images:
split_count = int(len(images) - max_images)
if split_count > 0:
random.seed(1234)
random.shuffle(images)
val_split = images[0:split_count]
train_split = images[split_count:]
print(f"Split to 'train' set with {len(train_split)} images and 'extra' set with {len(val_split)}")
print(f"Moving 'extra' set to {args.move_extra_to}...")
for v in tqdm.tqdm(val_split):
move_captioned_image(v, args.path, args.move_extra_to)
elif len(images) < (MULTIPLY_LOW_THR / 2):
print(f"Below threshold. Moving to {args.move_extra_to}\\__below...")
for v in tqdm.tqdm(images):
move_captioned_image(v, args.path, os.path.join(args.move_extra_to, "__below"))
def get_ae_scores_with_aa(dir_images):
scores_dict = {k: None for k in dir_images}
print("Scoring images...")
model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
model = rt.InferenceSession(model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
def predict(img):
img = img.astype(np.float32) / 255
s = 768
h, w = img.shape[:-1]
h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
ph, pw = s - h, s - w
img_input = np.zeros([s, s, 3], dtype=np.float32)
img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(img, (w, h))
img_input = np.transpose(img_input, (2, 0, 1))
img_input = img_input[np.newaxis, :]
pred = model.run(None, {"img": img_input})[0].item()
return pred
for img in tqdm.tqdm(dir_images):
pil = Image.open(img[0]).convert("RGB")
img_a = np.array(pil)
try:
sc = predict(img_a)
if sc == None:
scores_dict[img] = 0
else:
scores_dict[img] = sc
except Exception as ex:
print(f"Exception: {ex} - {img[0]}")
scores_dict[img] = 0
return scores_dict
def move_extra_with_ae(args):
max_images = MULTIPLY_HIGH_THR
if args.max_images:
max_images = args.max_images
for dirname in tqdm.tqdm(os.listdir(args.path)):
path = os.path.join(args.path, dirname)
if os.path.isdir(path):
dir_images = list(gather_captioned_images(path))
print(f"Found {len(dir_images)} captioned images in {path}")
split_count = 0
if len(dir_images) > max_images:
split_count = int(len(dir_images) - max_images)
if split_count > 0:
scores_dict = get_ae_scores_with_aa(dir_images)
#print(scores_dict)
sorted_score = [(k, scores_dict[k]) for k in scores_dict]
sorted_score.sort(key=lambda sorted: sorted[1], reverse=False)
sorted_img = [a[0] for a in sorted_score]
print(f"Moving {split_count} images to {args.move_extra_to}. ")
for f in tqdm.tqdm(sorted_img):
if split_count <= 0:
break
move_captioned_image(f, args.path, args.move_extra_to)
split_count -= 1
elif len(dir_images) < (MULTIPLY_LOW_THR / 2):
print(f"Below threshold. Moving to {args.move_extra_to}\\__below...")
for v in tqdm.tqdm(dir_images):
move_captioned_image(v, args.path, os.path.join(args.move_extra_to, "__below"))
def name_stats(args):
tags_dict = {}
for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
img = get_caption_file_image(txt)
if img:
with open(txt, "r", encoding="utf-8") as f:
fts = try_split(f.read().strip().lower())
if fts != "" or fts != " ":
if fts not in tags_dict:
tags_dict[str(fts)] = 1
else:
tags_dict[str(fts)] += 1
sorted_ftags = [(k, tags_dict[k]) for k in tags_dict]
sorted_ftags.sort(key=lambda sorted: sorted[1], reverse=True)
with open(os.path.join(args.path, "first_tags.txt"), "w", encoding='utf-8') as f:
json.dump(sorted_ftags, f, indent=2)
def gen_multiply(args):
tags_dict = {}
for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
img = get_caption_file_image(txt)
if img:
with open(txt, "r", encoding="utf-8") as f:
fts = try_split(f.read().strip().lower())
# fts = [f.read().split(",")[0].strip()]
#for tag in fts:
#if len(fts) > 1:
# if not any(s.startswith(tag) in s for s in (fts[0:fts.index(tag)] + fts[fts.index(tag)+1:])):
# continue
if fts is not None:
if fts not in tags_dict:
tags_dict[fts] = 1
else:
tags_dict[fts] += 1
sorted_ftags = [(k, tags_dict[k]) for k in tags_dict]
sorted_ftags.sort(key=lambda sorted: sorted[1], reverse=True)
with open(os.path.join(args.path, "first_tags-before.txt"), "w", encoding='utf-8') as f:
json.dump(sorted_ftags, f, indent=2)
if "" in tags_dict:
tags_dict.pop("")
if "#" in tags_dict:
tags_dict.pop("#")
if "@" in tags_dict:
tags_dict.pop("#")
if " " in tags_dict:
tags_dict.pop(" ")
#if "group" in tags_dict:
# tags_dict.pop("group")
#if "multiple characters" in tags_dict:
# tags_dict.pop("multiple characters")
if "1girl" in tags_dict:
tags_dict.pop("1girl")
if "1boy" in tags_dict:
tags_dict.pop("1boy")
if "nsfw" in tags_dict:
tags_dict.pop("nsfw")
if "virtual youtuber" in tags_dict:
tags_dict.pop("virtual youtuber")
if "nijisanji" in tags_dict:
tags_dict.pop("nijisanji")
if "hololive" in tags_dict:
tags_dict.pop("hololive")
if "phase connect" in tags_dict:
tags_dict.pop("phase connect")
# td2 = tags_dict.copy()
# for tag in td2:
# if tags_dict[tag] < MULTIPLY_LOW_THR:
# tags_dict.pop(tag)
modified = 0
total = 0
zeroed = 0
# split_path = os.path.join(args.path, "zzz_no_first_tag")
for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
img = get_caption_file_image(txt)
if img:
with open(txt, "r", encoding="utf-8") as f:
these_tags = [t.strip().lower() for t in f.read().split(",")]
for t in tags_dict.keys():
#if t in these_tags:
if t == these_tags[0]:
outpath = os.path.join(args.path, t.replace("\\", "_").replace(":", "_").replace("?", "_").replace("!", "_"))
do_move(txt, img, outpath, args)
mult = 1.0
if FILTER_TAG in these_tags:
if tags_dict[t] < MULTIPLY_LOW_THR:
mult = BELOW_THR_MULTIPLIER
zeroed += 1
elif tags_dict[t] < MULTIPLY_LOW_TARGET:
mult = min(MULTIPLY_LOW_TARGET / tags_dict[t], MULTIPLY_MAX_MULTIPLIER)
elif tags_dict[t] > MULTIPLY_HIGH_THR:
mult = max(MULTIPLY_HIGH_THR / tags_dict[t], MULTIPLY_MIN_MULTIPLIER)
mult_path = os.path.join(outpath, "multiply.txt")
if not os.path.exists(mult_path):
with open(mult_path, "w", encoding="utf-8") as f:
f.write(str(round(mult, 5)))
modified += 1
break
total += 1
print(f"Moved {modified}/{total} images and caption files. {zeroed} images with low count and tag '{FILTER_TAG}'")
def gen_stats(args):
    """Write multiplier-weighted tag counts for a dataset to two JSON files.

    For every caption ``.txt`` that has a matching image, each of its
    (deduplicated, lower-cased) tags is credited with the folder's multiplier
    (read from a sibling ``multiply.txt``, defaulting to 1.0).  Results are
    dumped, sorted by weighted count descending, to
    ``multiplied-tag-count.txt`` and (costume tags only)
    ``multiplied-tag-count-costumes.txt`` under ``args.path``.
    """
    tags_dict = {}
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=True))):
        img = get_caption_file_image(txt)
        if img:
            with open(txt, "r", encoding="utf-8") as f:
                these_tags = {t.strip().lower() for t in f.read().split(",")}
            # Folder-level multiplier; path hoisted so it is built only once.
            mp = os.path.join(os.path.split(txt)[0], "multiply.txt")
            if os.path.exists(mp):
                with open(mp, "r", encoding="utf-8") as f:
                    mult = float(f.read())
            else:
                mult = 1.0
            for tag in these_tags:
                if tag not in tags_dict:
                    tags_dict[tag] = [mult, mult]
                else:
                    # [0] = weighted count; [1] = running blend of multipliers.
                    # NOTE(review): [1] is an exponentially-decaying blend, not
                    # a true mean; it is only consumed by the commented-out dump.
                    tags_dict[tag] = [tags_dict[tag][0] + mult, (tags_dict[tag][1] + mult) / 2]
    #sorted_ftags = [(k, round(tags_dict[k][0], 2), round(tags_dict[k][1], 2)) for k in tags_dict]
    sorted_ftags = [(k, math.ceil(tags_dict[k][0])) for k in tags_dict]
    sorted_costume_ftags = [(k, math.ceil(tags_dict[k][0])) for k in tags_dict if "costume" in k]
    # Sort by weighted count, descending ("item" avoids shadowing builtin sorted()).
    sorted_ftags.sort(key=lambda item: item[1], reverse=True)
    sorted_costume_ftags.sort(key=lambda item: item[1], reverse=True)
    with open(os.path.join(args.path, "multiplied-tag-count.txt"), "w", encoding='utf-8') as f:
        json.dump(sorted_ftags, f, indent=2)
    with open(os.path.join(args.path, "multiplied-tag-count-costumes.txt"), "w", encoding='utf-8') as f:
        json.dump(sorted_costume_ftags, f, indent=2)
def contains_costumes(tags, danbooru_tags):
    """Return True when any tag in *tags* contains the substring "costume".

    *danbooru_tags* is accepted for signature parity with the other filter
    predicates but is not consulted.
    """
    return any("costume" in tag for tag in tags)
def contains_humans(tags, danbooru_tags=None):
    """Heuristically decide whether an image needs a human-related tag added.

    Returns False for furry images and for images already tagged "human" or
    "not furry"; otherwise True when any tag from the module-level HUMAN_TAGS
    list is present.
    """
    if "furry" in tags:
        return False
    # Skip images that already carry an explicit human / not-furry tag.
    if "human" in tags or "not furry" in tags:
        return False
    return any(human_tag in tags for human_tag in HUMAN_TAGS)
def has_matching_costume(tags1, tags2):
    """Return True if *tags1* contains a costume tag that also appears in *tags2*."""
    costume_tags = (t for t in tags1 if "costume" in t)
    return any(t in tags2 for t in costume_tags)
def move_name_to_front(tags, danbooru_tags):
    """Move character tags (Danbooru tag_category 4) to the front of the list.

    Tags are normalized via to_danbooru_tag() for the lookup and converted
    back with convert_tag() before returning.
    """
    these_tags = [to_danbooru_tag(t) for t in tags]
    # BUG FIX: iterate over a snapshot — the loop body mutates these_tags
    # (pop + insert at 0), which previously shifted elements under the live
    # iterator and could skip or re-process entries.
    for t in list(these_tags):
        this_category = danbooru_tags.get(t)
        if this_category is not None and this_category["tag_category"] == 4:
            these_tags.insert(0, these_tags.pop(these_tags.index(t)))
    return [convert_tag(t) for t in these_tags]
def move_costume_to_front(tags, danbooru_tags):
    """Return a copy of *tags* with every costume tag moved to the front.

    *danbooru_tags* is unused; kept for signature parity with the other
    reordering helpers.  When several costume tags are present, each is
    pushed to index 0 in turn, so their relative order ends up reversed.
    """
    new_tags = tags.copy()
    for t in tags:
        # NOTE: the original also tested '"costume)" in t', which is subsumed
        # by the plain substring test below — simplified with no behavior change.
        if "costume" in t:
            new_tags.insert(0, new_tags.pop(new_tags.index(t)))
    return new_tags
def contains_character_tag(tags, danbooru_tags):
    """Return True if any tag resolves to a Danbooru character tag (tag_category 4)."""
    for danbooru_form in (to_danbooru_tag(t) for t in tags):
        entry = danbooru_tags.get(danbooru_form)
        if entry is not None and entry["tag_category"] == 4:
            return True
    return False
def recalc_multiply_from_folders(args):
    """Recompute each folder's ``multiply.txt`` from its captioned-image count.

    For every folder that contains at least one caption with a matching image,
    a repeat multiplier is derived from the number of captioned images using
    the module-level MULTIPLY_* thresholds and written to ``multiply.txt``.
    Optionally runs move_extra_with_ae() first when ``args.move_extra_to`` is set.
    """
    seen_dirs = set()  # set instead of list: O(1) membership per caption file
    modified = 0
    if args.move_extra_to:
        move_extra_with_ae(args)
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=True))):
        img = get_caption_file_image(txt)
        if img:
            folder = os.path.split(txt)[0]
            if folder not in seen_dirs:
                file_count = len(list(gather_captioned_images(folder)))
                if file_count > 0:
                    mult = 1.0
                    if file_count < MULTIPLY_LOW_THR:
                        mult = BELOW_THR_MULTIPLIER
                    elif file_count < MULTIPLY_LOW_TARGET:
                        mult = min(MULTIPLY_LOW_TARGET / file_count, MULTIPLY_MAX_MULTIPLIER)
                    elif file_count > MULTIPLY_HIGH_THR:
                        mult = max(MULTIPLY_HIGH_THR / file_count, MULTIPLY_MIN_MULTIPLIER)
                    mult_path = os.path.join(folder, "multiply.txt")
                    with open(mult_path, "w", encoding="utf-8") as f:
                        f.write(str(round(mult, 4)))
                    modified += 1
                else:
                    # os.remove(mult_path)
                    print("Empty folder: ", folder)
                seen_dirs.add(folder)
    print(f"Changed {modified} multipliers")
def get_jpg_quality(pim: 'Image.Image') -> int:
    """
    Implement quality computation following ImageMagick heuristic algorithm:
    https://github.com/ImageMagick/ImageMagick/blob/7.1.0-57/coders/jpeg.c#L782
    Usage:
    ```
    pim = Img.open(...)
    quality = get_jpg_quality(pim)
    ```
    See also https://stackoverflow.com/questions/4354543/

    Returns the estimated JPEG quality (1-100), or -1 when it cannot be
    determined from the image's quantization tables.
    """
    qsum = 0
    # pim.quantization maps table id -> list of quantization coefficients.
    qdict = pim.quantization
    for i, qtable in qdict.items():  # key `i` is unused; only the total coefficient sum matters
        qsum += sum(qtable)
    if len(qdict) >= 1:
        # Characteristic value from two "key" positions of the luminance table,
        # mirroring ImageMagick's hash of selected coefficients.
        qvalue = qdict[0][2]+qdict[0][53]
        hash, sums = _HASH_1, _SUMS_1  # single-table curves (NOTE: `hash` shadows the builtin)
        if len(qdict) >= 2:
            # A second (chrominance) table is present: add its corner
            # coefficients and switch to the two-table lookup curves.
            qvalue += qdict[1][0]+qdict[1][-1]
            hash, sums = _HASH_2, _SUMS_2
        # Scan quality levels 1..100: keep scanning while both measures are
        # still below the curve; accept the first level where both fit (or
        # any level past the midpoint, i >= 50); otherwise give up.
        for i in range(100):
            if ((qvalue < hash[i]) and (qsum < sums[i])):
                continue
            if (((qvalue <= hash[i]) and (qsum <= sums[i])) or (i >= 50)):
                return i+1
            break
    return -1
# Constants mirrored from ImageMagick's coders/jpeg.c quality heuristic.
# NUM_QUANT_TBLS / DCTSIZE2 match libjpeg's quantization-table layout; they
# are not referenced by the code visible here (kept for parity with the C
# source — TODO confirm they are unused elsewhere in the file).
NUM_QUANT_TBLS = 4
DCTSIZE2 = 64
# Lookup curves indexed by quality-1 (0..100): _HASH_* hold the expected
# "key coefficient" values and _SUMS_* the expected total coefficient sums.
# The *_2 tables apply when two quantization tables are present, *_1 when
# only one is (see get_jpg_quality above).
_HASH_2 = [ 1020, 1015,  932,  848,  780,  735,  702,  679,  660,  645,
            632,  623,  613,  607,  600,  594,  589,  585,  581,  571,
            555,  542,  529,  514,  494,  474,  457,  439,  424,  410,
            397,  386,  373,  364,  351,  341,  334,  324,  317,  309,
            299,  294,  287,  279,  274,  267,  262,  257,  251,  247,
            243,  237,  232,  227,  222,  217,  213,  207,  202,  198,
            192,  188,  183,  177,  173,  168,  163,  157,  153,  148,
            143,  139,  132,  128,  125,  119,  115,  108,  104,  99,
            94,   90,   84,   79,   74,   70,   64,   59,   55,   49,
            45,   40,   34,   30,   25,   20,   15,   11,   6,    4,
            0 ]
_SUMS_2 = [ 32640, 32635, 32266, 31495, 30665, 29804, 29146, 28599, 28104,
            27670, 27225, 26725, 26210, 25716, 25240, 24789, 24373, 23946,
            23572, 22846, 21801, 20842, 19949, 19121, 18386, 17651, 16998,
            16349, 15800, 15247, 14783, 14321, 13859, 13535, 13081, 12702,
            12423, 12056, 11779, 11513, 11135, 10955, 10676, 10392, 10208,
            9928,  9747,  9564,  9369,  9193,  9017,  8822,  8639,  8458,
            8270,  8084,  7896,  7710,  7527,  7347,  7156,  6977,  6788,
            6607,  6422,  6236,  6054,  5867,  5684,  5495,  5305,  5128,
            4945,  4751,  4638,  4442,  4248,  4065,  3888,  3698,  3509,
            3326,  3139,  2957,  2775,  2586,  2405,  2216,  2037,  1846,
            1666,  1483,  1297,  1109,  927,   735,   554,   375,   201,
            128,   0 ]
_HASH_1 = [ 510,  505,  422,  380,  355,  338,  326,  318,  311,  305,
            300,  297,  293,  291,  288,  286,  284,  283,  281,  280,
            279,  278,  277,  273,  262,  251,  243,  233,  225,  218,
            211,  205,  198,  193,  186,  181,  177,  172,  168,  164,
            158,  156,  152,  148,  145,  142,  139,  136,  133,  131,
            129,  126,  123,  120,  118,  115,  113,  110,  107,  105,
            102,  100,  97,   94,   92,   89,   87,   83,   81,   79,
            76,   74,   70,   68,   66,   63,   61,   57,   55,   52,
            50,   48,   44,   42,   39,   37,   34,   31,   29,   26,
            24,   21,   18,   16,   13,   11,   8,    6,    3,    2,
            0 ]
_SUMS_1 = [
    16320, 16315, 15946, 15277, 14655, 14073, 13623, 13230, 12859,
    12560, 12240, 11861, 11456, 11081, 10714, 10360, 10027, 9679,
    9368,  9056,  8680,  8331,  7995,  7668,  7376,  7084,  6823,
    6562,  6345,  6125,  5939,  5756,  5571,  5421,  5240,  5086,
    4976,  4829,  4719,  4616,  4463,  4393,  4280,  4166,  4092,
    3980,  3909,  3835,  3755,  3688,  3621,  3541,  3467,  3396,
    3323,  3247,  3170,  3096,  3021,  2952,  2874,  2804,  2727,
    2657,  2583,  2509,  2437,  2362,  2290,  2211,  2136,  2068,
    1996,  1915,  1858,  1773,  1692,  1620,  1552,  1477,  1398,
    1326,  1251,  1179,  1109,  1031,  961,   884,   814,   736,
    667,   592,   518,   441,   369,   292,   221,   151,   86,
    64,    0 ]
def estimate_noise(I):
    """Estimate the noise standard deviation of a 2-D grayscale image.

    Fast noise-variance estimation (Immerkaer): convolve with a Laplacian
    difference kernel, then scale the summed absolute response by the image
    area (border rows/columns excluded from the normalization).
    """
    height, width = I.shape
    laplacian_kernel = np.array([[1, -2, 1],
                                 [-2, 4, -2],
                                 [1, -2, 1]])
    response = np.absolute(scipy.signal.convolve2d(I, laplacian_kernel))
    scale = math.sqrt(0.5 * math.pi) / (6 * (width - 2) * (height - 2))
    return response.sum() * scale
def clean(args):
    """Denoise noisy images in-place under args.path.

    Images whose estimated noise exceeds a fixed threshold are smoothed with
    an iterated bilateral filter refined by a guided filter, re-saved as PNG,
    and the non-PNG original is deleted.  Prints how many were cleaned.
    """
    from cv2.ximgproc import guidedFilter # pip install opencv-contrib-python
    cleaned = 0
    for ext in IMAGE_EXTS:
        for i in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, f"**/*{ext}"), recursive=True))):
            # imdecode(fromfile(...)) instead of cv2.imread to survive non-ASCII paths.
            img = cv2.imdecode(np.fromfile(i, np.uint8), flags=cv2.IMREAD_COLOR)
            img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            if estimate_noise(img_gray) > 2:
                y = img.copy()
                for _ in range(56): # orig: 64
                    y = cv2.bilateralFilter(y, 5, 8, 8)
                for _ in range(4): # orig: 4
                    y = guidedFilter(img, y, 4, 16)
                # BUG FIX: the original re-encoded `img` here, silently
                # discarding the filtered result; write the denoised `y`.
                is_success, im_buf_arr = cv2.imencode(".png", y.clip(0, 255).astype(np.uint8))
                im_buf_arr.tofile(os.path.splitext(i)[0] + ".png")
                cleaned += 1
                if os.path.splitext(i)[1] != ".png":
                    os.remove(i)
    print(f"Cleaned {cleaned} images.")
def remove_high_ar(args):
    """One-pass dataset hygiene over all images under args.path (recursive).

    - Deletes images with extreme aspect ratio (outside ~0.49..2.04) or fewer
      than MIN_PIXELS total pixels, together with their caption files.
    - Downscales >16-megapixel images to fit within 4000x4000 and re-saves as PNG.
    - Restores small low-quality JPEGs (quality < 80 per get_jpg_quality) with SCUNet.
    - Flattens images carrying a 'transparency' info entry to RGB PNG.
    Non-PNG originals are removed after their PNG replacement is written.
    Prints a summary of all actions taken.
    """
    total = 0
    hist = 0        # low-contrast counter (detection currently commented out below)
    converted = 0
    resized = 0
    lowquality = 0
    small = 0
    bad_ar = 0
    to_remove = []  # removal deferred to a second loop so iteration stays stable
    for ext in IMAGE_EXTS:
        for img in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, f"**/*{ext}"), recursive=True))):
            # txt = os.path.splitext(img)[0] + ".txt"
            try:
                with Image.open(img) as im:
                    ar = im.width / im.height
                    # if ar < 0.34 or ar > 2.99:
                    # if ar < 0.4 or ar > 2.5:
                    # Aspect ratio outside roughly 1:2 .. 2:1 -> drop the image.
                    if ar < 0.49 or ar > 2.04:
                        to_remove.append(img)
                        bad_ar += 1
                        total += 1
                    #elif (im.width < MIN_RES) or (im.height < MIN_RES):
                    elif (im.width * im.height) < MIN_PIXELS:
                        to_remove.append(img)
                        small += 1
                        total += 1
                    else:
                        # Very large image: shrink to fit 4000x4000, re-save as PNG.
                        if (im.width * im.height) > 16000000:
                            if not (im.width == 4000 or im.height == 4000):
                                im = ImageOps.contain(im, (4000, 4000), Image.Resampling.LANCZOS)
                                im.convert("RGB").save(os.path.splitext(img)[0] + ".png", format='PNG', compress_level=4)
                                resized +=1
                            if os.path.splitext(img)[1] != ".png":
                                im.close()
                                os.remove(img)
                            # continue
                        # Small low-quality JPEG: run SCUNet restoration, re-save as PNG.
                        elif im.format == 'JPEG' and (im.width < 1200 or im.height < 1200) and get_jpg_quality(im) < 80: # and (im.width < 1200 or im.height < 1200)
                            #to_remove.append(img)
                            im = imgutils.restore.scunet.restore_with_scunet(im)
                            im.save(os.path.splitext(img)[0] + ".png", format='PNG', compress_level=4)
                            if os.path.splitext(img)[1] != ".png":
                                im.close()
                                os.remove(img)
                            #total += 1
                            lowquality += 1
                        # Image with a transparency entry: flatten to RGB PNG.
                        elif ('transparency' in im.info):
                            #im = im.convert("RGBA")
                            #new_image = Image.new("RGBA", im.size, (128,128,128))
                            #new_image.paste(im, mask=im)
                            im.convert("RGB").save(os.path.splitext(img)[0] + ".png", format="PNG")
                            #im = new_image
                            converted += 1
                            if os.path.splitext(img)[1] != ".png":
                                im.close()
                                os.remove(img)
                            # continue
                        #if im.histogram()[-1] > (0.6 * im.width * im.height):
                        # #new_img = ImageOps.autocontrast(im.convert("RGB"), preserve_tone=True)
                        # #new_img.save(f"{img}", format='JPEG', quality=100, subsampling=0)
                        # #im = new_img
                        # hist += 1
            except Exception as ex:
                # Unreadable/corrupt image: report and keep going (file is kept).
                print(f"Exception: {ex} {img} ")
                # to_remove.append(img)
                # total += 1
                continue
    for img in to_remove:
        try:
            os.remove(img)
            os.remove(os.path.splitext(img)[0] + ".txt")
        except Exception as ex:
            print(f"Can't remove': {ex} {img} ")
    print(f"Removed {total} images. Restored {lowquality} low quality JPEGs, Removed {small} small images, {bad_ar} bad AR images. Found {hist} low contrast images. Converted {converted} transparent images. Resized {resized} images. ")
def normalized(a, axis=-1, order=2):
    """Scale *a* so each vector along *axis* has unit norm of the given order.

    Zero-norm vectors are left unchanged (their divisor is replaced by 1
    instead of 0 to avoid division warnings).
    """
    norms = np.atleast_1d(np.linalg.norm(a, order, axis))
    safe_norms = np.where(norms == 0, 1, norms)
    return a / np.expand_dims(safe_norms, axis)
def append_tag(txt, tag):
    """Append *tag* to the comma-separated caption file *txt* if absent.

    Existing tags are stripped and lower-cased for the comparison; the file
    is rewritten (joined with ", ") only when the tag was actually added.
    """
    with open(txt, "r", encoding="utf-8") as f:
        current = [part.strip().lower() for part in f.read().split(",")]
    if tag in current:
        return
    current.append(tag)
    with open(txt, "w", encoding="utf-8") as f:
        f.write(", ".join(current))
def remove_low_ae_score(args):
    """Score every image aesthetically and move low scorers to a subfolder.

    Backend is selected by ``args.type``:
      * "wd":   cafeai/cafe_aesthetic transformers pipeline (scores 0..1).
      * "clip": CLIP ViT-L/14 embedding fed to a small MLP regressor loaded
                from "sac+logos+ava1-l14-linearMSE.pth" (scores 0..10).
      * "aa":   skytnt/anime-aesthetic ONNX model (scores 0..1).

    When ``args.thr`` is 0, a backend-specific default threshold is used.
    Images scoring below the threshold are moved (with their caption file)
    into ``args.path/args.folder_name`` via do_move().
    """
    # Heavy ML dependencies are imported lazily so other sub-commands stay cheap.
    from transformers import pipeline
    import torch
    import pytorch_lightning as pl
    import torch.nn as nn
    class MLP(pl.LightningModule):
        # Aesthetic-score regression head for CLIP embeddings; the layer
        # sizes must match the published sac+logos+ava1 checkpoint exactly
        # (ReLUs are commented out to match the released weights' behavior).
        def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
            super().__init__()
            self.input_size = input_size
            self.xcol = xcol
            self.ycol = ycol
            self.layers = nn.Sequential(
                nn.Linear(self.input_size, 1024),
                #nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(1024, 128),
                #nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(128, 64),
                #nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(64, 16),
                #nn.ReLU(),
                nn.Linear(16, 1)
            )
        def forward(self, x):
            return self.layers(x)
    total = 0
    high_ae = 0
    low_ae = 0
    device = "cuda"
    # warnings.filterwarnings("ignore", category=UserWarning)
    # --- backend setup; each branch also supplies a default threshold ---
    if args.type == "wd": # scores 0 - 1.0
        pipe_aesthetic = pipeline("image-classification", "cafeai/cafe_aesthetic", device=0, batch_size=2)
        if args.thr == 0:
            args.thr = 0.55
    elif args.type == "clip": # scores 0 - 10.0
        model = MLP(768)
        s = torch.load("sac+logos+ava1-l14-linearMSE.pth")
        model.load_state_dict(s)
        model.to(device)
        model.eval()
        model2, preprocess = clip.load("ViT-L/14", device=device) #RN50x64
        if args.thr == 0:
            args.thr = 4.99
    elif args.type == "aa": # scores 0 - 1.0
        if args.thr == 0:
            args.thr = 0.1
        model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
        model = rt.InferenceSession(model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    def predict(img):
        # Letterbox-pad the image into a 768x768 float32 canvas and run the
        # anime-aesthetic ONNX session (closure over `model`; "aa" backend only).
        img = img.astype(np.float32) / 255
        s = 768
        h, w = img.shape[:-1]
        h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
        ph, pw = s - h, s - w
        img_input = np.zeros([s, s, 3], dtype=np.float32)
        img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(img, (w, h))
        img_input = np.transpose(img_input, (2, 0, 1))
        img_input = img_input[np.newaxis, :]
        pred = model.run(None, {"img": img_input})[0].item()
        return pred
    folder_name = args.folder_name
    outpath = os.path.join(args.path, folder_name)
    for ext in IMAGE_EXTS:
        for img in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, f"**/*{ext}"), recursive=True))):
            txt = os.path.splitext(img)[0] + ".txt"
            pil = Image.open(img)
            if args.type == "wd":
                try:
                    data = pipe_aesthetic(pil, top_k=2)
                except Exception as ex:
                    print(f"Exception: {ex} - {img}")
                final = {}
                for d in data:
                    final[d["label"]] = d["score"]
                sc = final["aesthetic"]
            elif args.type == "aa":
                img_a = np.array(pil.convert("RGB"))
                try:
                    sc = predict(img_a)
                    if sc < 0.2:
                        #append_tag(txt, "low-ae")
                        low_ae += 1
                    elif sc > 0.89:
                        #append_tag(txt, "high-ae")
                        high_ae += 1
                except Exception as ex:
                    print(f"Exception: {ex} - {img}")
            else:
                try:
                    image = preprocess(pil).unsqueeze(0).to(device)
                    if args.type == "clip":
                        with torch.no_grad():
                            image_features = model2.encode_image(image)
                        im_emb_arr = normalized(image_features.cpu().detach().numpy() )
                        prediction = model(torch.from_numpy(im_emb_arr).to(device).type(torch.cuda.FloatTensor))
                        sc = prediction[0][0].item()
                    # print(f"{img} - {sc}")
                except Exception as ex:
                    print(f"Exception: {ex} - {img}")
            # NOTE(review): if scoring raised an exception above, `sc` may be
            # unbound here (first image) or stale from the previous image —
            # worth guarding before relying on it.
            if sc < args.thr:
                pil.close()
                do_move(txt, img, outpath, args)
                total += 1
    print(f"{total} images with low aesthetic score moved. {high_ae} images with high ae score, {low_ae} images with low ae score ")
    warnings.filterwarnings("default")
def remove_blacklisted(args):
    """Delete image/caption pairs whose captions contain blacklisted tags.

    When ``args.tags`` is given, that list (optionally mapped through
    convert_tag when ``args.convert_tags`` is set) is the blacklist.
    Otherwise the module-level BAD_GENERAL list applies, plus an NSFW rule:
    images tagged "nsfw" are removed when they contain a BAD_NSFW tag and no
    BAD_NSFW_EXCLUDE tag.  Matching images and captions are deleted at the
    end; failures are reported and skipped.
    """
    total = 0
    files_to_remove = []
    blacklist = args.tags
    if args.convert_tags:
        blacklist = [convert_tag(t) for t in args.tags]
    print(blacklist)
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
        if not get_caption_file_image(txt):
            continue
        with open(txt, "r", encoding="utf-8") as f:
            these_tags = [t.strip().lower() for t in f.read().split(",")]
        if args.tags:
            if any(bad in these_tags for bad in blacklist):
                files_to_remove.append(txt)
                total += 1
        else:
            if any(bad in these_tags for bad in BAD_GENERAL):
                files_to_remove.append(txt)
                total += 1
            if "nsfw" in these_tags:
                has_bad_nsfw = any(bad in these_tags for bad in BAD_NSFW)
                has_exclusion = any(ok in these_tags for ok in BAD_NSFW_EXCLUDE)
                # Remove only when a bad NSFW tag is present and nothing excludes it.
                if has_bad_nsfw and not has_exclusion:
                    files_to_remove.append(txt)
                    total += 1
    for txt in files_to_remove:
        try:
            os.remove(get_caption_file_image(txt))
            os.remove(txt)
        except Exception as ex:
            print(f"Exception: {ex}")
            continue
    print(f"Removed {total} images with blacklisted tags.")
def validate(args):
    """Sanity-check a dataset folder tree under args.path.

    Checks: folders contain captions; every image has a caption; every
    caption has exactly one image and well-formed, non-empty tag content;
    every cached latent (.npz) has both a caption and an image.  Orphaned
    caption/latent files are renamed to ``*.bak`` rather than deleted.

    Returns 1 (and prints each problem) if anything was found, else 0.
    """
    problems = []
    total = 0
    print("Validating folder names...")
    for dirname in tqdm.tqdm(os.listdir(args.path)):
        path = os.path.join(args.path, dirname)
        if os.path.isdir(path):
            caption_count = len(list(glob.iglob(os.path.join(path, "*.txt"))))
            if caption_count == 0:
                problems.append((path, "Folder contains no captions"))
    print("Validating image files...")
    for ext in IMAGE_EXTS:
        for img in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, f"**/*{ext}"), recursive=True))):
            txt = os.path.splitext(img)[0] + ".txt"
            if not os.path.isfile(txt):
                problems.append((img, "Image file is missing caption"))
    print("Validating captions...")
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=True))):
        total += 1
        images = get_caption_file_images(txt)
        if not images:
            # multiply.txt is per-folder metadata, not a caption file.
            if (os.path.split(txt)[1] != "multiply.txt"):
                problems.append((txt, "Caption file is missing corresponding image"))
                os.rename(txt, txt + ".bak")
            continue
        elif len(images) > 1:
            problems.append((txt, "Caption file has more than one corresponding image"))
            continue
        with open(txt, "r", encoding="utf-8") as f:
            tag_string = f.read().strip()
        if not tag_string:
            problems.append((txt, "Caption file is empty"))
            continue
        if "\n" in tag_string:
            problems.append((txt, "Caption file contains newlines"))
        tags = {t.strip().lower(): True for t in tag_string.split(",")}
        if not tags:
            problems.append((txt, "Caption file has no tags"))
    print("Validating latents...")
    for npz in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.npz"), recursive=True))):
        txt = os.path.splitext(npz)[0] + ".txt"
        images = get_caption_file_images(npz)
        if not os.path.isfile(txt):
            problems.append((npz, "Latent is missing caption"))
            os.rename(npz, npz + ".bak")
        elif not images:
            problems.append((npz, "Latent is missing corresponding image"))
            os.rename(npz, npz + ".bak")
    if problems:
        for filename, problem in problems:
            # BUG FIX: the file path was previously printed as the literal
            # "(unknown)" even though `filename` was unpacked; report it.
            print(f"{filename} - {problem}")
        return 1
    print(f"No problems found for {total} image/caption pairs.")
    return 0
def stats(args):
    """Per-folder repeats/image-count report for the dataset.

    The implementation (parsing "N_concept" folder names and printing a
    formatted table of repeats, image counts, and totals) is currently
    disabled; this command is a no-op kept so dispatch still works.
    """
    return None
def main(args):
    """Dispatch the parsed sub-command to its handler.

    Returns the handler's result (used as the process exit status); prints
    usage and returns 1 when no known sub-command was given.
    """
    handlers = {
        "fixup": fixup,
        "add": add,
        "remove": remove,
        "replace": replace,
        "strip_suffix": strip_suffix,
        "move_to_front": move_to_front,
        "move_categories_to_front": move_categories_to_front,
        "organize": organize,
        "validate": validate,
        "stats": stats,
        "remove_high_ar": remove_high_ar,
        "remove_blacklisted": remove_blacklisted,
        "gen_multiply": gen_multiply,
        "gen_stats": gen_stats,
        "recalc_multiply_from_folders": recalc_multiply_from_folders,
        "remove_low_ae_score": remove_low_ae_score,
        "name_stats": name_stats,
        "clean": clean,
    }
    try:
        handler = handlers[args.command]
    except KeyError:
        parser.print_help()
        return 1
    return handler(args)
# Script entry point: parse CLI arguments and exit with the sub-command's
# return value as the process status code.
if __name__ == "__main__":
    args = parser.parse_args()
    parser.exit(main(args))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment