Last active
November 7, 2023 14:36
-
-
Save vvern999/8a02e54d45549b4786e61f0f5dbca924 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# A Swiss-Army-Knife sort of script for dealing with Stable Diffusion caption | |
# files in .txt format. | |
# from https://github.com/space-nuko/sd-webui-utilities/tree/master | |
import argparse
#import collections
import glob
import json
import math
import os
import os.path
import random
import re
import shutil
import unicodedata
import warnings
from collections.abc import Iterator
from typing import Optional
#import sys
#import mmap
#import pprint

import cv2
import numpy as np
import onnxruntime as rt
import pandas
import requests
import scipy.signal
import tqdm
#import dotenv
#import PIL
from PIL import Image,ImageFile,ImageOps
#from pprint import pp
#import prompt_parser
#import safetensors
from sacremoses import MosesPunctNormalizer
from huggingface_hub import hf_hub_download
import clip
import imgutils # pip install dghs-imgutils
import imgutils.restore
# ---------------------------------------------------------------------------
# Command-line interface.
# Global flags (--recursive, --escape, --no-convert-tags, --copy) apply to
# every sub-command; each sub-command takes a positional `path` plus its own
# options. `args` is parsed at import time because this file is a script.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--recursive', '-r', action="store_true", help='Edit caption files recursively')
parser.add_argument('--escape', action="store_true", help='escape slashes')
subparsers = parser.add_subparsers(dest="command", help='sub-command help')
# fixup: the main caption-normalization pass (see fixup() below)
parser_fixup_tags = subparsers.add_parser('fixup', help='Fixup caption files, converting from gallery-dl format and normalizing UTF-8')
parser_fixup_tags.add_argument('path', type=str, help='Path to caption files')
parser_fixup_tags.add_argument('--double', '-d', action="store_true", help='Copy first tag')
parser_fixup_tags.add_argument('--rare', action="store_true", help='Remove rare tags - artists')
parser_fixup_tags.add_argument('--rare_all', action="store_true", help='Remove rare tags - all')
parser_fixup_tags.add_argument('--shorten', action="store_true", help='Consolidate/shorten')
parser_fixup_tags.add_argument('--remove_artists', type=float, default=None, help='Remove artist tags with probability 0-1.0')
parser_fixup_tags.add_argument('--remove_quality', action="store_true", help='Remove quality tags/scores')
parser_fixup_tags.add_argument('--token_len', action="store_true", help='Calculate token length')
parser_fixup_tags.add_argument('--append_filename', action="store_true", help='append filename to tags')
parser_fixup_tags.add_argument('--remove_long', type=float, default=None, help='remove long tags')
parser_add_tags = subparsers.add_parser('add', help='Add tags to captions (delimited by commas)')
parser_add_tags.add_argument('path', type=str, help='Path to caption files')
parser_add_tags.add_argument('tags', type=str, nargs='+', help='Tags to add')
parser_remove_tags = subparsers.add_parser('remove', help='Remove tags from captions (delimited by commas)')
parser_remove_tags.add_argument('path', type=str, help='Path to caption files')
parser_remove_tags.add_argument('tags', type=str, nargs='+', help='Tags to remove')
# NOTE: global flag declared between subparser definitions; argparse allows this.
parser.add_argument('--no-convert-tags', '-n', action="store_false", dest="convert_tags", help='Do not convert tags to NAI format')
parser_replace_tag = subparsers.add_parser('replace', help='Replace tags in captions (delimited by commas)')
parser_replace_tag.add_argument('path', type=str, help='Path to caption files')
parser_replace_tag.add_argument('to_find', type=str, help='Tag to find')
parser_replace_tag.add_argument('to_replace', type=str, help='Tag to replace with')
parser_move_tags_to_front = subparsers.add_parser('move_to_front', help='Move tags in captions (delimited by commas) to front of list')
parser_move_tags_to_front.add_argument('path', type=str, help='Path to caption files')
parser_move_tags_to_front.add_argument('tags', type=str, nargs='+', help='Tags to move')
parser_move_categories_to_front = subparsers.add_parser('move_categories_to_front', help='Move danbooru tag categories in captions (delimited by commas) to front of list')
parser_move_categories_to_front.add_argument('path', type=str, help='Path to caption files')
parser_move_categories_to_front.add_argument('categories', type=str, nargs='+', help='Categories to move')
parser_strip_tag_suffix = subparsers.add_parser('strip_suffix', help='Strips a suffix from a tag ("neptune_(neptune_series)" -> "neptune")')
parser_strip_tag_suffix.add_argument('path', type=str, help='Path to caption files')
parser_strip_tag_suffix.add_argument('suffix', type=str, help='Suffix to find')
parser_validate = subparsers.add_parser('validate', help='Validate a dataset')
parser_validate.add_argument('path', type=str, help='Path to root of dataset folder')
parser_remove_high_ar = subparsers.add_parser('remove_high_ar', help='Remove images with high aspect ratios, low resolutions, etc.')
parser_remove_high_ar.add_argument('path', type=str, help='Path to root of dataset folder')
#parser_remove_high_ar.add_argument('--folder-name', '-n', type=str, help='Name of subfolder')
parser_recalc_multiply_from_folders = subparsers.add_parser('recalc_multiply_from_folders', help='Recalc multipliers, remove extra images ')
parser_recalc_multiply_from_folders.add_argument('path', type=str, help='Path to root of dataset folder')
parser_recalc_multiply_from_folders.add_argument('--move_extra_to', type=str, help='Path to split folder')
parser_recalc_multiply_from_folders.add_argument('--max_images', type=int, default=0, help='max images to keep for move_extra')
parser_gen_multiply = subparsers.add_parser('gen_multiply', help='Sort into folder, generate multiply.txt ')
parser_gen_multiply.add_argument('path', type=str, help='Path to root of dataset folder')
parser_gen_stats = subparsers.add_parser('gen_stats', help='generate tag count')
parser_gen_stats.add_argument('path', type=str, help='Path to root of dataset folder')
parser_name_stats = subparsers.add_parser('name_stats', help='generate tag count for first tag only')
parser_name_stats.add_argument('path', type=str, help='Path to root of dataset folder')
parser_clean = subparsers.add_parser('clean', help='adverse clean')
parser_clean.add_argument('path', type=str, help='Path to root of dataset folder')
parser_remove_blacklisted = subparsers.add_parser('remove_blacklisted', help=' ')
parser_remove_blacklisted.add_argument('path', type=str, help='Path to root of dataset folder')
parser_remove_blacklisted.add_argument('--tags', type=str, default='', nargs='+', help='Tags to remove')
parser_stats = subparsers.add_parser('stats', help='Show dataset image counts/repeats')
parser_stats.add_argument('path', type=str, help='Path to caption files')
parser_organize = subparsers.add_parser('organize', help='Move images with specified tags (delimited by commas) into a subfolder')
parser_organize.add_argument('path', type=str, help='Path to caption files')
parser_organize.add_argument('tags', type=str, nargs='+', help='Tags to move')
parser_organize.add_argument('--folder-name', '-n', type=str, help='Name of subfolder')
parser_organize.add_argument('--split-rest', '-s', action="store_true", help='Move all non-matching images into another folder')
parser.add_argument('--copy', action="store_true", help='copy instead of move')
parser_remove_low_ae_score = subparsers.add_parser('remove_low_ae_score', help=' ')
parser_remove_low_ae_score.add_argument('path', type=str, help='Path to caption files')
parser_remove_low_ae_score.add_argument('--folder-name', '-n', type=str, help='Name of subfolder')
parser_remove_low_ae_score.add_argument('--thr', type=float, default=0.0, help='Score threshold')
parser_remove_low_ae_score.add_argument('--type', type=str, default="wd", help='wd, clip or aa')
# Parsed immediately: the module is meant to be run, not imported.
args = parser.parse_args()
# Pillow limits: allow very large images and tolerate truncated files.
Image.MAX_IMAGE_PIXELS = 300000000
ImageFile.LOAD_TRUNCATED_IMAGES=True
# gallery-dl writes captions as "<image>.<ext>.txt"; group 1 is the image stem.
gallery_dl_txt_re = re.compile(r'^(.*)\.[a-z]{3,4}\.txt')
# Image extensions paired with caption .txt files (see get_caption_file_image).
IMAGE_EXTS = [".png", ".jpg", ".jpeg", ".gif", ".webp", ".avif"]
# kohya-style dataset folders are named "<repeats>_<name>".
repeats_folder_re = re.compile(r'^(\d+)_(.*)$')
# Tags never treated as a meaningful "first tag" (see try_split); includes
# generic composition tags and vtuber mascot/companion character tags.
ignored_tags = ("absurdres","looking at viewer","solo","high-ae","low-ae","simple background","white background","blush","smile","sitting","standing","sfw",
    "holding","indoors","outdoors","#",""," ",
    "bibi (tokoyami towa)","piyoko (uruha rushia)","bloop (gawr gura)","friend (nanashi mumei)","daifuku (yukihana lamy)","35p (sakura miko)","mr. squeaks (hakos baelz)","kintoki (sakura miko)","takodachi (ninomae ina'nis)",
    "haaton (akai haato)","elfriend (shiranui flare)","udin (kureiji ollie)","bubba (watson amelia)","ao-chan (ninomae ina'nis)","don-chan (usada pekora)","nousagi (usada pekora)","shiranui (nakiri ayame)","crow (la+ darknesss)",
    "pokobee","pikl (elira pendora)","#","@")
# Meta tags stripped from captions during fixup (the image itself is kept).
BAD_TAGS_REMOVE_TAG = ("absurdres", "highres", "translation request", "translated", "commentary", "commentary request", "commentary typo", "character request", "bad id", "bad link", "bad pixiv id", "bad twitter id", "bad tumblr id",
    "bad deviantart id", "bad nicoseiga id", "md5 mismatch", "cosplay request", "artist request", "wide image", "author request",
    "hololive","nijisanji","vshojo","phase connect","third-party source","official alternate hair length","official alternate hairstyle","official alternate costume","alternate costume",
    "alternate hairstyle","","adapted costume","adapted_costume","k",""," ","alternate breast size")
# Tags whose presence disqualifies the whole image from the dataset.
BAD_TAGS_REMOVE_IMAGE = ("comic", "panels", "everyone", "sample watermark", "text focus", "tagme", "webtoon", "bad multiple views", "bara", "scat","jpeg artifacts","daz studio","daz3d")
# Direct tag rewrites applied during fixup.
TAG_ALIASES = {"1girls":"1girl","v":"peace sign, v", "ahe gao":"ahegao"}
BAD_GENERAL = ["comic","webtoon","bad multiple views","bara","scat"] # old
BAD_NSFW = ["mosaic censoring","heavily censored","censored"]
BAD_NSFW_EXCLUDE = ["convenient censoring","pointless censoring","heart censor","hair censor"]
HUMAN_TAGS = ["1girl","2girls","3girls","4girls","5girls","6+girls","1boy","2boys","3boys","4boys","5boys","6+boys"]
quality_tags = ("low quality","best quality","masterpiece")
# Dataset-balancing knobs used by gen_multiply / move_extra:
MULTIPLY_LOW_THR = 8            # below this many images a tag is "too rare"
MULTIPLY_MAX_MULTIPLIER = 3     # cap on the repeat boost for rare tags
MULTIPLY_LOW_TARGET = 100.0     # target effective image count for rare tags
MULTIPLY_HIGH_THR = 800.0       # above this many images a tag is damped
FILTER_TAG = "virtual youtuber" # only captions with this tag get multipliers
BELOW_THR_MULTIPLIER = 1.0
MULTIPLY_MIN_MULTIPLIER = 1.0
RARE_TAG_THR = 2                # min occurrences for any tag (--rare_all)
RARE_ARTIST_TAG_THR = 10        # min occurrences for "by ..." artist tags
MIN_RES = 576
MIN_PIXELS = 350000
def convert_tag(t):
    """Normalize one tag: underscore/space handling, paren escaping, and the
    "1girls" -> "1girl" fixup. Behavior depends on the global --escape flag.

    NOTE(review): the `.replace(" ", " ")` calls below are no-ops as written
    and look like a double-space -> single-space squeeze lost to formatting
    when this file was pasted — verify against the original source.
    """
    #if "costume" in t:
    #    return t.replace(" ", "_").replace(" ", " ").replace("\\", "").replace("(", "").replace(")", "")
    if "costume" in t:
        # costume tags: force spaces, drop backslashes and parentheses entirely
        return t.replace("_", " ").replace(" ", " ").replace("\\", "").replace("(", "").replace(")", "")
    if len(t) <=3:
        # very short tags (emoticons like ":d") keep underscores
        return t.replace(" ", "_").replace(" ", " ").replace("\\ ", "\\").replace("\\", "")
    if args.escape:
        # escaping mode: parentheses are backslash-escaped for webui prompts
        if "\\" not in t:
            if " " not in t:
                return t.replace("_", " ").replace(" ", " ").replace("\\ ", "\\").replace("(", "\(").replace(")", "\)").replace("1girls", "1girl")
            return t.replace("(", "\(").replace(")", "\)").replace(" ", " ").replace("1girls", "1girl")
        else:
            # already contains backslashes: assume it is pre-escaped
            return t.replace(" ", " ").replace("\\ ", "\\").replace("1girls", "1girl")
    # return t.replace("_", " ").replace("(", "\(").replace(")", "\)").replace(" ", " ")
    #if "\\" not in t:
    #    return t.replace("(", "\(").replace(")", "\)").replace(" ", " ")
    #else:
    # default (non-escaping) mode: spaces instead of underscores, no backslashes
    if " " not in t:
        return t.replace("_", " ").replace(" ", " ").replace("\\ ", "\\").replace("\\", "").replace("1girls", "1girl").replace("__", "_")
    else:
        return t.replace(" ", " ").replace("\ ", "\\").replace("\\", "").replace("1girls", "1girl").replace("__", "_")
def get_caption_file_image(txt):
    """Return the path of an existing image sharing txt's basename, or None."""
    stem = os.path.splitext(txt)[0]
    return next((stem + ext for ext in IMAGE_EXTS if os.path.isfile(stem + ext)), None)
def get_caption_file_images(txt):
    """Return every existing image sharing txt's basename (at most one per extension)."""
    stem = os.path.splitext(txt)[0]
    return [stem + ext for ext in IMAGE_EXTS if os.path.isfile(stem + ext)]
def fixup(args):
    """Main caption-normalization pass.

    1) Renames gallery-dl style "<image>.<ext>.txt" caption files to "<image>.txt".
    2) For every caption with a matching image: normalizes Unicode/punctuation,
       joins newlines, converts each tag via convert_tag, drops duplicates and
       blacklisted tags, applies TAG_ALIASES, and writes the file back.
    Optional flags: --shorten, --double, --token_len, --append_filename,
    --remove_artists, --remove_quality, --remove_long, --rare/--rare_all.

    NOTE(review): move_costume_to_front() is called below but is not defined in
    this chunk of the file — confirm it exists elsewhere in the full source.
    """
    # rename gallery-dl-format .txt files (<...>.png.txt, etc.)
    renamed = 0
    total = 0
    doubled = 0
    long_tags = 0
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
        m = gallery_dl_txt_re.match(txt)
        if m:
            basename = m.groups(1)[0]
            #print("RENAME: " + basename + ".txt")
            shutil.move(txt, basename + ".txt")
            renamed += 1
        total += 1
    print(f"Renamed {renamed}/{total} caption files.")
    mpn = MosesPunctNormalizer()
    if args.token_len:
        # heavyweight import is deferred until --token_len is actually requested
        from transformers import CLIPTokenizer
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    num_edited = 0
    token_lengths = {}  # chunk-count -> number of captions of that length
    print("moving costume tags")
    # join newlines, deduplicate tags, fix unicode chars, remove spaces and escape parens
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
        if get_caption_file_image(txt):
            with open(txt, "r", encoding="utf-8") as f:
                s = f.read()
            # normalize punctuation then fold to NFKC so equal tags compare equal
            s = unicodedata.normalize("NFKC", mpn.normalize(s))
            s = ", ".join(s.split("\n"))
            these_tags = [convert_tag(t.strip().lower()) for t in s.split(",")]
            #these_tags.reverse()
            these_tags = move_costume_to_front(these_tags, 0)
            #if contains_humans(these_tags):
            #    these_tags.append("human")
            #    these_tags.append("not furry")
            # a trailing numeric tag is interpreted as a popularity score
            if these_tags[-1].isdigit():
                these_tags[-1] = get_quality_tag(int(these_tags[-1]), ("nsfw" in these_tags))
            fixed_tags = []
            for i in these_tags:
                if not (args.remove_long is not None and len(i) > args.remove_long):
                    #if i.startswith("art by "):
                    #    i = i.replace("art by ","by ")
                    if i.startswith("by:"):
                        i = i.replace("by:","by ")
                    if (i not in fixed_tags) and (i not in BAD_TAGS_REMOVE_TAG) and (not i.startswith("pool:")):
                        # optionally drop artist tags with the given probability
                        if not (args.remove_artists and i.startswith("by ") and random.random() <= args.remove_artists):
                            if not (args.remove_quality and i in quality_tags):
                                if i in TAG_ALIASES.keys():
                                    fixed_tags.append(TAG_ALIASES[i])
                                else:
                                    fixed_tags.append(i)
            if args.shorten:
                # drop a tag when a longer tag ending in it exists
                # (e.g. "hair" is redundant next to "long hair")
                for short_tag in these_tags:
                    for long_tag in these_tags:
                        if long_tag.endswith(" " + short_tag):
                            try:
                                index = fixed_tags.index(short_tag)
                                fixed_tags.pop(index)
                                num_edited += 1
                            except:
                                continue
            #print(f"tags: {these_tags}")
            #print(f"fixed: {fixed_tags}")
            if args.double and (FILTER_TAG in these_tags):
                # repeat the (character) first tag at the end to boost its weight
                if these_tags[0] != 'nsfw' and these_tags[0] != '1girl' and these_tags[0] not in ignored_tags:
                    fixed_tags.append(these_tags[0])
                    doubled += 1
            if args.token_len:
                # histogram of caption lengths in 75-token CLIP chunks
                input_ids = tokenizer(", ".join(fixed_tags), padding="max_length", truncation=True, max_length=750, return_tensors="pt").input_ids
                iids = [i for i in input_ids[0] if i != 49407]  # strip EOS/pad token
                chunk_len = math.ceil(len(iids) / 75)
                if chunk_len not in token_lengths:
                    token_lengths[chunk_len] = 1
                else:
                    token_lengths[chunk_len] += 1
            if args.append_filename:
                text = os.path.splitext(os.path.basename(txt))[0]
                # NOTE(review): result of this replace is discarded (strings are
                # immutable) — "1920x1080" is not actually removed; verify intent.
                text.replace("1920x1080","")
                text = re.sub('\d', '', text)
                text = text.replace("-", " ").replace("_", " ").replace(" ", " ").replace("__", "_")
                fixed_tags.append(text.strip())
            with open(txt, "w", encoding="utf-8") as f:
                f.write(", ".join(fixed_tags))
    print(f"Shortened {num_edited} tags. Doubled {doubled} tags")
    if args.rare or args.rare_all:
        remove_rare_tags()
    if args.token_len:
        token_lengths = dict(sorted(token_lengths.items()))
        print("Token lengths:")
        for k in token_lengths:
            print(f"<{k * 75}: {token_lengths[k]}")
def get_quality_tag(score, is_nsfw):
    """Map a numeric popularity score to a quality tag.

    NSFW scores are halved before thresholding. Returns "" for the
    unremarkable middle band (5 <= effective score < 100).
    """
    effective = score / 2 if is_nsfw else score
    if effective >= 200:
        return "masterpiece"
    elif effective >= 100:
        return "best quality"
    elif effective < 5:
        return "low quality"
    return ""
def add(args):
    """Append the given tags (after convert_tag) to every caption file that
    has a matching image, skipping tags already present."""
    tags = [convert_tag(t) for t in args.tags]
    modified = 0
    total = 0
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
        if get_caption_file_image(txt):
            with open(txt, "r", encoding="utf-8") as f:
                these_tags = [t.strip().lower() for t in f.read().split(",")]
            added = False
            for tag in tags:
                if tag not in these_tags:
                    these_tags.append(tag)
                    added = True
            with open(txt, "w", encoding="utf-8") as f:
                f.write(", ".join(these_tags))
            if added:
                modified += 1
        total += 1
    print(f"Updated {modified}/{total} caption files.")
def remove(args):
    """Delete the given tags from every caption file that has a matching image.

    Tags are run through convert_tag unless --no-convert-tags was passed.
    Only the first occurrence of each tag is removed (captions are deduplicated
    elsewhere, so duplicates are not expected).
    """
    tags = args.tags
    if args.convert_tags:
        tags = [convert_tag(t) for t in args.tags]
    modified = 0
    total = 0
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
        if get_caption_file_image(txt):
            with open(txt, "r", encoding="utf-8") as f:
                these_tags = [t.strip().lower() for t in f.read().split(",")]
            removed = False
            for target in tags:
                if target in these_tags:
                    these_tags.remove(target)
                    removed = True
            #if these_tags[-1] == tags[0]:
            #    found = True
            #    index = these_tags.index(tags[0])
            #    these_tags.pop(index)
            with open(txt, "w", encoding="utf-8") as f:
                f.write(", ".join(these_tags))
            if removed:
                modified += 1
        total += 1
    print(f"Updated {modified}/{total} caption files.")
def replace(args):
    """Replace args.to_find with args.to_replace in each caption, keeping the
    tag's position in the list unchanged."""
    to_find = convert_tag(args.to_find)
    to_replace = convert_tag(args.to_replace)
    modified = 0
    total = 0
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
        if get_caption_file_image(txt):
            with open(txt, "r", encoding="utf-8") as f:
                these_tags = [t.strip().lower() for t in f.read().split(",")]
            if to_find in these_tags:
                # the dataset is expected to never already contain the replacement
                assert to_replace not in these_tags
                these_tags[these_tags.index(to_find)] = to_replace
                with open(txt, "w", encoding="utf-8") as f:
                    f.write(", ".join(these_tags))
                modified += 1
        total += 1
    print(f"Updated {modified}/{total} caption files.")
def strip_suffix(args):
    """Strip a trailing suffix from every tag that ends with it,
    e.g. suffix "(neptune series)" turns "neptune (neptune series)" into "neptune"."""
    suffix = convert_tag(args.suffix)
    modified = 0
    total = 0
    stripped_tags = 0
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
        if get_caption_file_image(txt):
            with open(txt, "r", encoding="utf-8") as f:
                these_tags = [t.strip().lower() for t in f.read().split(",")]
            stripped_any = False
            rewritten = []
            for tag in these_tags:
                if tag.endswith(suffix):
                    stripped_any = True
                    stripped_tags += 1
                    tag = tag.removesuffix(suffix).strip()
                rewritten.append(tag)
            with open(txt, "w", encoding="utf-8") as f:
                f.write(", ".join(rewritten))
            if stripped_any:
                modified += 1
        total += 1
    print(f"Updated {modified}/{total} caption files, {stripped_tags} tags stripped.")
def move_to_front(args):
    """Move the given tags to the start of each caption.

    Tags are processed in reverse so that after repeated insert(0, ...) the
    first tag given on the command line ends up frontmost.
    """
    tags = [convert_tag(t) for t in args.tags][::-1]
    modified = 0
    total = 0
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
        if get_caption_file_image(txt):
            with open(txt, "r", encoding="utf-8") as f:
                these_tags = [t.strip().lower() for t in f.read().split(",")]
            moved = False
            for tag in tags:
                if tag in these_tags:
                    moved = True
                    these_tags.remove(tag)
                    these_tags.insert(0, tag)
            with open(txt, "w", encoding="utf-8") as f:
                f.write(", ".join(these_tags))
            if moved:
                modified += 1
        total += 1
    print(f"Updated {modified}/{total} caption files.")
# Danbooru tag-category ids as used by danbooru.csv (see move_categories_to_front).
# NOTE(review): id 2 is absent — presumably unused by that CSV; confirm.
CATEGORIES = {
    "general": 0,
    "artist": 1,
    "copyright": 3,
    "character": 4,
    "meta": 5,
}
def to_danbooru_tag(t):
    """Convert a caption tag to danbooru form: lowercase, spaces -> underscores,
    and backslash-escaped parentheses unescaped.

    Fix: the backslash patterns are now raw strings; the original "\(" / "\)"
    are invalid escape sequences that raise SyntaxWarning on modern Python.
    """
    return t.strip().lower().replace(" ", "_").replace(r"\(", "(").replace(r"\)", ")")
def move_categories_to_front(args):
    """Reorder each caption so tags in the given danbooru categories come first.

    Downloads danbooru.csv (tag -> category id mapping) into the working
    directory on first use.
    """
    if not os.path.isfile("danbooru.csv"):
        print("Downloading danbooru.csv tags list...")
        url = "https://github.com/arenatemp/sd-tagging-helper/raw/master/danbooru.csv"
        response = requests.get(url, stream=True)
        with open("danbooru.csv", "wb") as handle:
            for data in tqdm.tqdm(response.iter_content()):
                handle.write(data)
    print("Loading danbooru.csv tags list...")
    # dict of tag -> {"tag_category": id}
    danbooru_tags = pandas.read_csv("danbooru.csv", engine="pyarrow").set_axis(["tag", "tag_category"], axis=1).set_index("tag").to_dict("index")
    # categories are processed in reverse so the first-listed category ends up frontmost
    order = [CATEGORIES[cat] for cat in reversed(args.categories)]
    print("Done.")
    modified = 0
    total = 0
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
        found = False
        if get_caption_file_image(txt):
            with open(txt, "r", encoding="utf-8") as f:
                these_tags = list(to_danbooru_tag(t) for t in f.read().split(","))
            for tag_category in order:
                for t in these_tags:
                    this_category = danbooru_tags.get(t)
                    if this_category is not None:
                        this_category = this_category["tag_category"]
                        if tag_category == this_category:
                            found = True
                            # NOTE(review): these_tags is mutated while being
                            # iterated; pop+insert during iteration can skip
                            # elements — confirm this ordering is intended.
                            these_tags.insert(0, these_tags.pop(these_tags.index(t)))
            these_tags = [convert_tag(t) for t in these_tags]
            with open(txt, "w", encoding="utf-8") as f:
                f.write(", ".join(these_tags))
            if found:
                modified += 1
        total += 1
    print(f"Updated {modified}/{total} caption files.")
def do_move(txt, img, outpath, args=None):
    """Move (or copy, when args.copy is truthy) an image and its caption into outpath.

    The caption is renamed to "<image stem>.txt" in the destination folder.
    Errors are reported but not raised, so a single bad file does not abort a run.

    Fix: the original dereferenced args.copy unconditionally, so calling with
    the documented default args=None always raised AttributeError (silently
    swallowed by the except below) and moved nothing. args=None now means "move".
    """
    os.makedirs(outpath, exist_ok=True)
    basename = os.path.splitext(os.path.basename(txt))[0]
    img_ext = os.path.splitext(img)[1]
    out_txt = os.path.join(outpath, basename + ".txt")
    out_img = os.path.join(outpath, basename + img_ext)
    # print(f"{img} -> {out_img}")
    try:
        if args is not None and getattr(args, "copy", False):
            shutil.copy2(img, out_img)
            shutil.copy2(txt, out_txt)
        else:
            shutil.move(img, out_img)
            shutil.move(txt, out_txt)
    except Exception as ex:
        # best-effort: report and continue with the next file
        print(f"Exception: {ex}")
def organize(args):
    """Move captioned images whose tags match args.tags into a subfolder of
    args.path; with --split-rest, non-matching images go to "(rest)"."""
    wanted = [convert_tag(t) for t in args.tags]
    folder = args.folder_name or " ".join(args.tags)
    outpath = os.path.join(args.path, folder.replace(":","_"))
    # if os.path.exists(outpath):
    #     print(f"Error: Folder already exists - {outpath}")
    #     return 1
    if args.split_rest:
        split_path = os.path.join(args.path, "(rest)")
        # if os.path.exists(split_path):
        #     print(f"Error: Folder already exists - {split_path}")
        #     return 1
    moved = 0
    total = 0
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
        img = get_caption_file_image(txt)
        if img:
            with open(txt, "r", encoding="utf-8") as f:
                present = {t.strip().lower() for t in f.read().split(",")}
            if any(t in present for t in wanted):
                do_move(txt, img, outpath, args)
                moved += 1
            elif args.split_rest:
                do_move(txt, img, split_path, args)
                moved += 1
        total += 1
    print(f"Moved {moved}/{total} images and caption files to {outpath}.")
def try_split(str):
    """Return the first meaningful tag of a caption string, or None.

    The caption is cut at ", @" when that marker is present, otherwise at the
    first ", "; any leading tags listed in ignored_tags are filtered out.
    """
    head, marker, _ = str.partition(", @")
    if marker == '':
        #print(f"Failed to split: {str}")
        head = str.partition(", ")[0]
    head = head.strip()
    tags = [t.strip().lower() for t in head.split(",") if t.strip().lower() not in ignored_tags]
    return tags[0] if tags else None
def remove_rare_tags():
    """Count every tag under the global args.path, then strip tags below the
    rarity thresholds: all tags with --rare_all, otherwise only "by ..."
    artist tags. Uses the module-level `args`."""
    counts = {}
    caption_files = list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))
    # pass 1: tally tag frequencies across the whole dataset
    for txt in tqdm.tqdm(caption_files):
        if get_caption_file_image(txt):
            with open(txt, "r", encoding="utf-8") as f:
                for tag in (t.strip().lower() for t in f.read().split(",")):
                    counts[tag] = counts.get(tag, 0) + 1
    if args.rare_all:
        tags_to_remove = [t for t, c in counts.items() if c < RARE_TAG_THR]
    else:
        tags_to_remove = [t for t, c in counts.items() if c < RARE_ARTIST_TAG_THR and t.startswith("by ")]
    print("rare tags:", tags_to_remove)
    # pass 2: rewrite captions without the rare tags
    rare_removed = 0
    for txt in tqdm.tqdm(caption_files):
        if get_caption_file_image(txt):
            with open(txt, "r", encoding="utf-8") as f:
                these_tags = [t.strip().lower() for t in f.read().split(",")]
            removed_any = False
            for target in tags_to_remove:
                if target in these_tags:
                    removed_any = True
                    these_tags.remove(target)
            with open(txt, "w", encoding="utf-8") as f:
                f.write(", ".join(these_tags))
            if removed_any:
                rare_removed += 1
    print(f"Removed {rare_removed} rare tags.")
# Extensions used by gather_captioned_images when pairing images with their
# caption sidecar files (distinct from IMAGE_EXTS used by the .txt tools above).
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']
CAPTION_EXTENSIONS = ['.txt', '.caption', '.yaml', '.yml']
def gather_captioned_images(root_dir: str) -> Iterator[tuple[str, Optional[str]]]:
    """Yield (image_path, caption_path) pairs for every image under root_dir.

    caption_path is None when no sidecar caption file with a matching stem
    exists (checked in CAPTION_EXTENSIONS order; first match wins).

    Fix: this is a generator, but it was annotated as returning a list; the
    return annotation now matches the actual yield-based implementation.
    """
    for directory, _, filenames in os.walk(root_dir):
        image_filenames = [f for f in filenames if os.path.splitext(f)[1].lower() in IMAGE_EXTENSIONS]
        for image_filename in image_filenames:
            image_path = os.path.join(directory, image_filename)
            image_path_without_extension = os.path.splitext(image_path)[0]
            caption_path = None
            for caption_extension in CAPTION_EXTENSIONS:
                possible_caption_path = image_path_without_extension + caption_extension
                if os.path.exists(possible_caption_path):
                    caption_path = possible_caption_path
                    break
            yield image_path, caption_path
def move_captioned_image(image_caption_pair: tuple[str, Optional[str]], source_root: str, target_root: str):
    """Relocate an image (and its caption, when present) under target_root,
    recreating the folder structure it had relative to source_root."""
    image_path, caption_path = image_caption_pair
    # mirror the source's relative directory inside the target root
    rel_dir = os.path.dirname(os.path.relpath(image_path, source_root))
    dest_dir = os.path.join(target_root, rel_dir)
    os.makedirs(dest_dir, exist_ok=True)
    shutil.move(image_path, os.path.join(dest_dir, os.path.basename(image_path)))
    if caption_path is not None:
        shutil.move(caption_path, os.path.join(dest_dir, os.path.basename(caption_path)))
def move_extra(args):
    """For each subfolder of args.path, move randomly-chosen surplus images
    (beyond max_images) into args.move_extra_to; folders with fewer than
    MULTIPLY_LOW_THR/2 images are moved wholesale into an "__below" subfolder."""
    max_images = MULTIPLY_HIGH_THR
    if args.max_images > 0:
        max_images = args.max_images
    for dirname in tqdm.tqdm(os.listdir(args.path)):
        path = os.path.join(args.path, dirname)
        if os.path.isdir(path):
            images = list(gather_captioned_images(path))
            print(f"Found {len(images)} captioned images in {path}")
            # split_count = math.ceil(len(images) * args.split_proportion)
            split_count = 0
            if len(images) > max_images:
                split_count = int(len(images) - max_images)
            if split_count > 0:
                # fixed seed so repeated runs pick the same "extra" subset
                random.seed(1234)
                random.shuffle(images)
                val_split = images[0:split_count]
                train_split = images[split_count:]
                print(f"Split to 'train' set with {len(train_split)} images and 'extra' set with {len(val_split)}")
                print(f"Moving 'extra' set to {args.move_extra_to}...")
                for v in tqdm.tqdm(val_split):
                    move_captioned_image(v, args.path, args.move_extra_to)
            elif len(images) < (MULTIPLY_LOW_THR / 2):
                print(f"Below threshold. Moving to {args.move_extra_to}\\__below...")
                for v in tqdm.tqdm(images):
                    move_captioned_image(v, args.path, os.path.join(args.move_extra_to, "__below"))
def get_ae_scores_with_aa(dir_images):
    """Score (image_path, caption_path) pairs with the skytnt/anime-aesthetic
    ONNX model.

    Returns a dict mapping each input pair to its aesthetic score; images that
    fail to load or score get 0 so callers can still sort the full set.

    Fix: the None check on the prediction now uses `is None` instead of
    `== None` (identity comparison; `==` can be overridden by array types).
    """
    scores_dict = {k: None for k in dir_images}
    print("Scoring images...")
    model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
    model = rt.InferenceSession(model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    def predict(img):
        # letterbox the image into a 768x768 square, then NCHW batch of 1
        img = img.astype(np.float32) / 255
        s = 768
        h, w = img.shape[:-1]
        h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
        ph, pw = s - h, s - w
        img_input = np.zeros([s, s, 3], dtype=np.float32)
        img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(img, (w, h))
        img_input = np.transpose(img_input, (2, 0, 1))
        img_input = img_input[np.newaxis, :]
        pred = model.run(None, {"img": img_input})[0].item()
        return pred
    for img in tqdm.tqdm(dir_images):
        pil = Image.open(img[0]).convert("RGB")
        img_a = np.array(pil)
        try:
            sc = predict(img_a)
            if sc is None:
                scores_dict[img] = 0
            else:
                scores_dict[img] = sc
        except Exception as ex:
            # unreadable/corrupt image: score 0 and keep going
            print(f"Exception: {ex} - {img[0]}")
            scores_dict[img] = 0
    return scores_dict
def move_extra_with_ae(args):
    """Like move_extra, but chooses which surplus images to move out by
    aesthetic score: the lowest-scored images leave first."""
    max_images = MULTIPLY_HIGH_THR
    if args.max_images:
        max_images = args.max_images
    for dirname in tqdm.tqdm(os.listdir(args.path)):
        path = os.path.join(args.path, dirname)
        if os.path.isdir(path):
            dir_images = list(gather_captioned_images(path))
            print(f"Found {len(dir_images)} captioned images in {path}")
            split_count = 0
            if len(dir_images) > max_images:
                split_count = int(len(dir_images) - max_images)
            if split_count > 0:
                scores_dict = get_ae_scores_with_aa(dir_images)
                #print(scores_dict)
                # sort ascending so the worst-scored pairs come first
                sorted_score = [(k, scores_dict[k]) for k in scores_dict]
                sorted_score.sort(key=lambda sorted: sorted[1], reverse=False)
                sorted_img = [a[0] for a in sorted_score]
                print(f"Moving {split_count} images to {args.move_extra_to}. ")
                for f in tqdm.tqdm(sorted_img):
                    if split_count <= 0:
                        break
                    move_captioned_image(f, args.path, args.move_extra_to)
                    split_count -= 1
            elif len(dir_images) < (MULTIPLY_LOW_THR / 2):
                print(f"Below threshold. Moving to {args.move_extra_to}\\__below...")
                for v in tqdm.tqdm(dir_images):
                    move_captioned_image(v, args.path, os.path.join(args.move_extra_to, "__below"))
def name_stats(args):
    """Count how often each first tag occurs across caption files and dump the
    counts (sorted descending) to first_tags.txt in args.path.

    Fix: the original guard `if fts != "" or fts != " "` was always true
    (should have been `and`), and it tested raw `fts` for membership while
    storing `str(fts)` keys, so a None first tag was re-set to 1 forever.
    Blank/None first tags are now skipped and counting uses consistent keys.
    """
    tags_dict = {}
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
        img = get_caption_file_image(txt)
        if img:
            with open(txt, "r", encoding="utf-8") as f:
                fts = try_split(f.read().strip().lower())
            if fts not in (None, "", " "):
                tags_dict[fts] = tags_dict.get(fts, 0) + 1
    sorted_ftags = sorted(tags_dict.items(), key=lambda kv: kv[1], reverse=True)
    with open(os.path.join(args.path, "first_tags.txt"), "w", encoding='utf-8') as f:
        json.dump(sorted_ftags, f, indent=2)
def gen_multiply(args):
    """Sort captioned images into per-first-tag folders and write multiply.txt.

    Pass 1 counts how often each first tag occurs (dumped to
    first_tags-before.txt), then junk/overly-generic keys are dropped.
    Pass 2 moves each image whose first tag survived into a folder named
    after that tag and, for captions carrying FILTER_TAG, writes a repeat
    multiplier so rare tags are over-sampled and common ones under-sampled.
    """
    tags_dict = {}
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
        img = get_caption_file_image(txt)
        if img:
            with open(txt, "r", encoding="utf-8") as f:
                fts = try_split(f.read().strip().lower())
            if fts is not None:
                tags_dict[fts] = tags_dict.get(fts, 0) + 1
    sorted_ftags = sorted(tags_dict.items(), key=lambda kv: kv[1], reverse=True)
    with open(os.path.join(args.path, "first_tags-before.txt"), "w", encoding='utf-8') as f:
        json.dump(sorted_ftags, f, indent=2)
    # Drop junk keys and tags too generic to name a character folder.
    # BUG FIX: the original popped "#" (again) when "@" was present, which
    # raised KeyError if "#" had already been removed and never removed "@".
    for junk in ("", "#", "@", " ", "1girl", "1boy", "nsfw", "virtual youtuber",
                 "nijisanji", "hololive", "phase connect"):
        tags_dict.pop(junk, None)
    modified = 0
    total = 0
    zeroed = 0
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
        img = get_caption_file_image(txt)
        if img:
            with open(txt, "r", encoding="utf-8") as f:
                these_tags = [t.strip().lower() for t in f.read().split(",")]
            for t in tags_dict.keys():
                if t == these_tags[0]:
                    # Folder name: the tag with filesystem-unsafe characters replaced.
                    outpath = os.path.join(args.path, t.replace("\\", "_").replace(":", "_").replace("?", "_").replace("!", "_"))
                    do_move(txt, img, outpath, args)
                    mult = 1.0
                    # NOTE(review): the multiplier scaling is assumed to apply
                    # only when FILTER_TAG is present (matches original nesting).
                    if FILTER_TAG in these_tags:
                        if tags_dict[t] < MULTIPLY_LOW_THR:
                            mult = BELOW_THR_MULTIPLIER
                            zeroed += 1
                        elif tags_dict[t] < MULTIPLY_LOW_TARGET:
                            mult = min(MULTIPLY_LOW_TARGET / tags_dict[t], MULTIPLY_MAX_MULTIPLIER)
                        elif tags_dict[t] > MULTIPLY_HIGH_THR:
                            mult = max(MULTIPLY_HIGH_THR / tags_dict[t], MULTIPLY_MIN_MULTIPLIER)
                    mult_path = os.path.join(outpath, "multiply.txt")
                    # Keep an existing multiplier file from a previous run.
                    if not os.path.exists(mult_path):
                        with open(mult_path, "w", encoding="utf-8") as f:
                            f.write(str(round(mult, 5)))
                    modified += 1
                    break
            total += 1
    print(f"Moved {modified}/{total} images and caption files. {zeroed} images with low count and tag '{FILTER_TAG}'")
def gen_stats(args):
    """Write multiplier-weighted tag counts for every caption under args.path.

    Each caption's tags are weighted by the multiply.txt value of its folder
    (1.0 when absent). Writes multiplied-tag-count.txt for all tags and
    multiplied-tag-count-costumes.txt for tags containing "costume".
    """
    weighted = {}
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=True))):
        if not get_caption_file_image(txt):
            continue
        with open(txt, "r", encoding="utf-8") as fh:
            tag_set = set(part.strip().lower() for part in fh.read().split(","))
        mult = 1.0
        mult_file = os.path.join(os.path.split(txt)[0], "multiply.txt")
        if os.path.exists(mult_file):
            with open(mult_file, "r", encoding="utf-8") as fh:
                mult = float(fh.read())
        for tag in tag_set:
            if tag in weighted:
                prev_total, prev_avg = weighted[tag]
                # Second slot is a running "average" (halved blend each hit);
                # only the first slot (weighted total) is reported below.
                weighted[tag] = [prev_total + mult, (prev_avg + mult) / 2]
            else:
                weighted[tag] = [mult, mult]
    all_counts = sorted(((tag, math.ceil(vals[0])) for tag, vals in weighted.items()),
                        key=lambda pair: pair[1], reverse=True)
    costume_counts = sorted(((tag, math.ceil(vals[0])) for tag, vals in weighted.items() if "costume" in tag),
                            key=lambda pair: pair[1], reverse=True)
    with open(os.path.join(args.path, "multiplied-tag-count.txt"), "w", encoding='utf-8') as fh:
        json.dump(all_counts, fh, indent=2)
    with open(os.path.join(args.path, "multiplied-tag-count-costumes.txt"), "w", encoding='utf-8') as fh:
        json.dump(costume_counts, fh, indent=2)
def contains_costumes(tags, danbooru_tags):
    """Return True when any tag mentions "costume".

    *danbooru_tags* is unused; kept for signature parity with the other
    tag-predicate helpers.
    """
    return any("costume" in tag for tag in tags)
def contains_humans(tags, danbooru_tags=None):
    """Return True when the tag list suggests a human subject.

    Captions tagged "furry" are treated as non-human; captions already
    carrying "human" or "not furry" return False so the tag is not added
    twice. Otherwise checks for any marker from the module-level HUMAN_TAGS.
    *danbooru_tags* is unused.
    """
    if "furry" in tags:
        return False
    if "human" in tags or "not furry" in tags:
        return False  # already tagged
    return any(marker in tags for marker in HUMAN_TAGS)
def has_matching_costume(tags1, tags2):
    """Return True when *tags1* contains a costume tag that also appears in *tags2*."""
    return any("costume" in tag and tag in tags2 for tag in tags1)
def move_name_to_front(tags, danbooru_tags):
    """Move character tags (danbooru tag_category == 4) to the front.

    Tags are normalized with to_danbooru_tag() for the category lookup and
    converted back with convert_tag() before returning.
    """
    these_tags = list(to_danbooru_tag(t) for t in tags)
    # BUG FIX: iterate over a snapshot — the original mutated `these_tags`
    # (pop + insert) while iterating it, which can skip tags after a move.
    for t in list(these_tags):
        this_category = danbooru_tags.get(t)
        if this_category is not None:
            if this_category["tag_category"] == 4:
                these_tags.insert(0, these_tags.pop(these_tags.index(t)))
    these_tags = [convert_tag(t) for t in these_tags]
    return these_tags
def move_costume_to_front(tags, danbooru_tags):
    """Return a copy of *tags* with costume tags moved to the front.

    When several costume tags are present, the one latest in the original
    order ends up first. *danbooru_tags* is unused; kept for signature
    parity with move_name_to_front().
    """
    new_tags = tags.copy()
    for t in tags:
        # FIX: dropped the redundant `"costume)" in t` test — it is strictly
        # implied by the `"costume" in t` check that was also performed.
        if "costume" in t:
            new_tags.insert(0, new_tags.pop(new_tags.index(t)))
    return new_tags
def contains_character_tag(tags, danbooru_tags):
    """Return True if any tag maps to a danbooru character entry (tag_category == 4)."""
    for raw_tag in tags:
        entry = danbooru_tags.get(to_danbooru_tag(raw_tag))
        if entry is not None and entry["tag_category"] == 4:
            return True
    return False
def recalc_multiply_from_folders(args):
    """Rewrite multiply.txt in each folder based on its captioned-image count.

    Folders below MULTIPLY_LOW_THR get BELOW_THR_MULTIPLIER, folders below
    MULTIPLY_LOW_TARGET are scaled up (capped at MULTIPLY_MAX_MULTIPLIER),
    folders above MULTIPLY_HIGH_THR are scaled down (floored at
    MULTIPLY_MIN_MULTIPLIER). Optionally runs move_extra_with_ae() first
    when --move-extra-to is given.
    """
    # PERF: set instead of list — membership checks were O(n) per caption.
    seen_dirs = set()
    modified = 0
    if args.move_extra_to:
        move_extra_with_ae(args)
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=True))):
        img = get_caption_file_image(txt)
        if img:
            folder = os.path.split(txt)[0]  # renamed from `dir` (shadowed builtin)
            if folder not in seen_dirs:
                file_count = len(list(gather_captioned_images(folder)))
                if file_count > 0:
                    mult = 1.0
                    if file_count < MULTIPLY_LOW_THR:
                        mult = BELOW_THR_MULTIPLIER
                    elif file_count < MULTIPLY_LOW_TARGET:
                        mult = min(MULTIPLY_LOW_TARGET / file_count, MULTIPLY_MAX_MULTIPLIER)
                    elif file_count > MULTIPLY_HIGH_THR:
                        mult = max(MULTIPLY_HIGH_THR / file_count, MULTIPLY_MIN_MULTIPLIER)
                    mult_path = os.path.join(folder, "multiply.txt")
                    with open(mult_path, "w", encoding="utf-8") as f:
                        f.write(str(round(mult, 4)))
                    modified += 1
                else:
                    # os.remove(mult_path)
                    print("Empty folder: ", folder)
                seen_dirs.add(folder)
    print(f"Changed {modified} multipliers")
def get_jpg_quality(pim: 'Image.Image') -> int:
    """
    Implement quality computation following ImageMagick heuristic algorithm:
    https://github.com/ImageMagick/ImageMagick/blob/7.1.0-57/coders/jpeg.c#L782
    Usage:
    ```
    pim = Img.open(...)
    quality = get_jpg_quality(pim)
    ```
    See also https://stackoverflow.com/questions/4354543/

    Returns the estimated quality (1-100), or -1 when it cannot be determined.
    """
    qdict = pim.quantization
    # ROBUSTNESS FIX: the original fell through with qvalue/table unbound
    # (NameError) when no quantization tables were present.
    if not qdict:
        return -1
    qsum = sum(sum(qtable) for qtable in qdict.values())
    # Key coefficients from the luma table (plus chroma when present) select
    # which reference table to compare against.
    qvalue = qdict[0][2] + qdict[0][53]
    table, sums = _HASH_1, _SUMS_1  # renamed from `hash` (shadowed builtin)
    if len(qdict) >= 2:
        qvalue += qdict[1][0] + qdict[1][-1]
        table, sums = _HASH_2, _SUMS_2
    for i in range(100):
        # Skip candidate qualities whose reference values are still too large.
        if (qvalue < table[i]) and (qsum < sums[i]):
            continue
        if ((qvalue <= table[i]) and (qsum <= sums[i])) or (i >= 50):
            return i + 1
        break  # heuristic gives up: values straddle the reference tables
    return -1
# Reference tables for get_jpg_quality(), ported from ImageMagick's
# coders/jpeg.c quality-estimation heuristic. Index i corresponds to
# candidate quality i+1: _HASH_* holds expected key quantization
# coefficients, _SUMS_* the expected total of all table entries.
# *_1 tables are for single-table JPEGs, *_2 for two-table (luma+chroma).
NUM_QUANT_TBLS = 4  # mirrors libjpeg's NUM_QUANT_TBLS; appears unused in the visible code
DCTSIZE2 = 64       # entries in one 8x8 quantization table; appears unused in the visible code
_HASH_2 = [ 1020, 1015, 932, 848, 780, 735, 702, 679, 660, 645,
632, 623, 613, 607, 600, 594, 589, 585, 581, 571,
555, 542, 529, 514, 494, 474, 457, 439, 424, 410,
397, 386, 373, 364, 351, 341, 334, 324, 317, 309,
299, 294, 287, 279, 274, 267, 262, 257, 251, 247,
243, 237, 232, 227, 222, 217, 213, 207, 202, 198,
192, 188, 183, 177, 173, 168, 163, 157, 153, 148,
143, 139, 132, 128, 125, 119, 115, 108, 104, 99,
94, 90, 84, 79, 74, 70, 64, 59, 55, 49,
45, 40, 34, 30, 25, 20, 15, 11, 6, 4,
0 ]
_SUMS_2 = [ 32640, 32635, 32266, 31495, 30665, 29804, 29146, 28599, 28104,
27670, 27225, 26725, 26210, 25716, 25240, 24789, 24373, 23946,
23572, 22846, 21801, 20842, 19949, 19121, 18386, 17651, 16998,
16349, 15800, 15247, 14783, 14321, 13859, 13535, 13081, 12702,
12423, 12056, 11779, 11513, 11135, 10955, 10676, 10392, 10208,
9928, 9747, 9564, 9369, 9193, 9017, 8822, 8639, 8458,
8270, 8084, 7896, 7710, 7527, 7347, 7156, 6977, 6788,
6607, 6422, 6236, 6054, 5867, 5684, 5495, 5305, 5128,
4945, 4751, 4638, 4442, 4248, 4065, 3888, 3698, 3509,
3326, 3139, 2957, 2775, 2586, 2405, 2216, 2037, 1846,
1666, 1483, 1297, 1109, 927, 735, 554, 375, 201,
128, 0 ]
_HASH_1 = [ 510, 505, 422, 380, 355, 338, 326, 318, 311, 305,
300, 297, 293, 291, 288, 286, 284, 283, 281, 280,
279, 278, 277, 273, 262, 251, 243, 233, 225, 218,
211, 205, 198, 193, 186, 181, 177, 172, 168, 164,
158, 156, 152, 148, 145, 142, 139, 136, 133, 131,
129, 126, 123, 120, 118, 115, 113, 110, 107, 105,
102, 100, 97, 94, 92, 89, 87, 83, 81, 79,
76, 74, 70, 68, 66, 63, 61, 57, 55, 52,
50, 48, 44, 42, 39, 37, 34, 31, 29, 26,
24, 21, 18, 16, 13, 11, 8, 6, 3, 2,
0 ]
_SUMS_1 = [
16320, 16315, 15946, 15277, 14655, 14073, 13623, 13230, 12859,
12560, 12240, 11861, 11456, 11081, 10714, 10360, 10027, 9679,
9368, 9056, 8680, 8331, 7995, 7668, 7376, 7084, 6823,
6562, 6345, 6125, 5939, 5756, 5571, 5421, 5240, 5086,
4976, 4829, 4719, 4616, 4463, 4393, 4280, 4166, 4092,
3980, 3909, 3835, 3755, 3688, 3621, 3541, 3467, 3396,
3323, 3247, 3170, 3096, 3021, 2952, 2874, 2804, 2727,
2657, 2583, 2509, 2437, 2362, 2290, 2211, 2136, 2068,
1996, 1915, 1858, 1773, 1692, 1620, 1552, 1477, 1398,
1326, 1251, 1179, 1109, 1031, 961, 884, 814, 736,
667, 592, 518, 441, 369, 292, 221, 151, 86,
64, 0 ]
def estimate_noise(I):
    """Estimate the Gaussian noise standard deviation of a 2-D grayscale image.

    Uses Immerkaer's fast noise-variance estimation: convolve with a 3x3
    Laplacian-difference kernel and average the absolute response. *I* must
    be a 2-D array with both dimensions > 2 (the normalization divides by
    (W-2)*(H-2)).
    """
    H, W = I.shape
    M = [[1, -2, 1],
         [-2, 4, -2],
         [1, -2, 1]]
    # FIX: single np.sum — the original wrapped it in a second, redundant np.sum.
    sigma = np.sum(np.absolute(scipy.signal.convolve2d(I, M)))
    sigma = sigma * math.sqrt(0.5 * math.pi) / (6 * (W - 2) * (H - 2))
    return sigma
def clean(args):
    """Denoise noisy images under args.path in place.

    Images whose estimated noise exceeds the threshold are smoothed with
    repeated bilateral filtering followed by guided filtering, then saved
    as PNG (the original non-PNG file is deleted).
    """
    from cv2.ximgproc import guidedFilter  # pip install opencv-contrib-python
    cleaned = 0
    for ext in IMAGE_EXTS:
        for i in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, f"**/*{ext}"), recursive=True))):
            # imdecode(fromfile(...)) instead of imread to handle non-ASCII paths.
            img = cv2.imdecode(np.fromfile(i, np.uint8), flags=cv2.IMREAD_COLOR)
            img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            if estimate_noise(img_gray) > 2:
                y = img.copy()
                for _ in range(56):  # orig: 64
                    y = cv2.bilateralFilter(y, 5, 8, 8)
                for _ in range(4):  # orig: 4
                    y = guidedFilter(img, y, 4, 16)
                # BUG FIX: the original encoded `img` (the untouched input),
                # silently discarding the filtered result `y`.
                is_success, im_buf_arr = cv2.imencode(".png", y.clip(0, 255).astype(np.uint8))
                im_buf_arr.tofile(os.path.splitext(i)[0] + ".png")
                cleaned += 1
                if os.path.splitext(i)[1] != ".png":
                    os.remove(i)
    print(f"Cleaned {cleaned} images.")
def remove_high_ar(args):
    """Prune and normalize images under args.path.

    Removes images with extreme aspect ratios or too few pixels (plus their
    caption files), downscales very large images to fit within 4000x4000,
    restores small low-quality JPEGs with SCUNet, and flattens transparent
    images to opaque RGB PNG. Prints a summary of all actions.
    """
    total = 0        # images queued for removal (bad AR + too small)
    hist = 0         # low-contrast counter (detection currently commented out)
    converted = 0    # transparent images flattened to RGB PNG
    resized = 0      # oversized images downscaled
    lowquality = 0   # low-quality JPEGs restored with SCUNet
    small = 0        # images below the MIN_PIXELS threshold
    bad_ar = 0       # images with aspect ratio outside [0.49, 2.04]
    to_remove = []
    for ext in IMAGE_EXTS:
        for img in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, f"**/*{ext}"), recursive=True))):
            # txt = os.path.splitext(img)[0] + ".txt"
            try:
                with Image.open(img) as im:
                    ar = im.width / im.height
                    # Earlier, looser thresholds kept for reference:
                    # if ar < 0.34 or ar > 2.99:
                    # if ar < 0.4 or ar > 2.5:
                    if ar < 0.49 or ar > 2.04:
                        to_remove.append(img)
                        bad_ar += 1
                        total += 1
                    #elif (im.width < MIN_RES) or (im.height < MIN_RES):
                    elif (im.width * im.height) < MIN_PIXELS:
                        to_remove.append(img)
                        small += 1
                        total += 1
                    else:
                        # Over ~16 megapixels: downscale so the longest side is 4000px.
                        if (im.width * im.height) > 16000000:
                            # NOTE(review): a side of exactly 4000 is taken to mean
                            # "already processed by a previous run" — confirm.
                            if not (im.width == 4000 or im.height == 4000):
                                im = ImageOps.contain(im, (4000, 4000), Image.Resampling.LANCZOS)
                                im.convert("RGB").save(os.path.splitext(img)[0] + ".png", format='PNG', compress_level=4)
                                resized +=1
                                if os.path.splitext(img)[1] != ".png":
                                    im.close()
                                    os.remove(img)
                            # continue
                        elif im.format == 'JPEG' and (im.width < 1200 or im.height < 1200) and get_jpg_quality(im) < 80: # and (im.width < 1200 or im.height < 1200)
                            #to_remove.append(img)
                            # Small, heavily-compressed JPEG: restore with SCUNet, re-save as PNG.
                            im = imgutils.restore.scunet.restore_with_scunet(im)
                            im.save(os.path.splitext(img)[0] + ".png", format='PNG', compress_level=4)
                            if os.path.splitext(img)[1] != ".png":
                                im.close()
                                os.remove(img)
                            #total += 1
                            lowquality += 1
                        elif ('transparency' in im.info):
                            # Flatten paletted/transparent images to opaque RGB PNG.
                            #im = im.convert("RGBA")
                            #new_image = Image.new("RGBA", im.size, (128,128,128))
                            #new_image.paste(im, mask=im)
                            im.convert("RGB").save(os.path.splitext(img)[0] + ".png", format="PNG")
                            #im = new_image
                            converted += 1
                            if os.path.splitext(img)[1] != ".png":
                                im.close()
                                os.remove(img)
                            # continue
                        # Low-contrast detection, currently disabled:
                        #if im.histogram()[-1] > (0.6 * im.width * im.height):
                        #    #new_img = ImageOps.autocontrast(im.convert("RGB"), preserve_tone=True)
                        #    #new_img.save(f"{img}", format='JPEG', quality=100, subsampling=0)
                        #    #im = new_img
                        #    hist += 1
            except Exception as ex:
                print(f"Exception: {ex} {img} ")
                # to_remove.append(img)
                # total += 1
                continue
    # Delete the queued images together with their caption files.
    for img in to_remove:
        try:
            os.remove(img)
            os.remove(os.path.splitext(img)[0] + ".txt")
        except Exception as ex:
            print(f"Can't remove': {ex} {img} ")
    print(f"Removed {total} images. Restored {lowquality} low quality JPEGs, Removed {small} small images, {bad_ar} bad AR images. Found {hist} low contrast images. Converted {converted} transparent images. Resized {resized} images. ")
def normalized(a, axis=-1, order=2):
    """Scale *a* to unit norm along *axis*; all-zero vectors are returned unchanged."""
    norms = np.atleast_1d(np.linalg.norm(a, ord=order, axis=axis))
    safe_norms = np.where(norms == 0, 1, norms)
    return a / np.expand_dims(safe_norms, axis)
def append_tag(txt, tag):
    """Append *tag* to the comma-separated caption file *txt* unless present.

    The file is rewritten either way, with all tags stripped and lowercased.
    """
    with open(txt, "r", encoding="utf-8") as fh:
        existing = [part.strip().lower() for part in fh.read().split(",")]
    if tag not in existing:
        existing.append(tag)
    with open(txt, "w", encoding="utf-8") as fh:
        fh.write(", ".join(existing))
def remove_low_ae_score(args):
    """Score every image aesthetically and move low scorers to a subfolder.

    args.type selects the scorer backend:
      * "wd"   -- cafeai/cafe_aesthetic transformers pipeline (0-1, default thr 0.55)
      * "clip" -- CLIP ViT-L/14 embedding fed to a linear MLP head (0-10, default thr 4.99)
      * "aa"   -- skytnt/anime-aesthetic ONNX model (0-1, default thr 0.1)
    Images scoring below args.thr are moved together with their caption files
    into args.path/args.folder_name via do_move().
    """
    # Heavy ML dependencies are imported lazily so other sub-commands
    # do not pay the import cost.
    from transformers import pipeline
    import torch
    import pytorch_lightning as pl
    import torch.nn as nn
    class MLP(pl.LightningModule):
        # Linear regression head mapping a CLIP image embedding to a single
        # aesthetic score (architecture of the sac+logos+ava1-l14-linearMSE
        # checkpoint; the ReLUs are commented out to match that checkpoint).
        def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
            super().__init__()
            self.input_size = input_size
            self.xcol = xcol
            self.ycol = ycol
            self.layers = nn.Sequential(
                nn.Linear(self.input_size, 1024),
                #nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(1024, 128),
                #nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(128, 64),
                #nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(64, 16),
                #nn.ReLU(),
                nn.Linear(16, 1)
            )
        def forward(self, x):
            return self.layers(x)
    total = 0      # images moved for scoring below args.thr
    high_ae = 0    # "aa" only: images scoring above 0.89
    low_ae = 0     # "aa" only: images scoring below 0.2
    device = "cuda"
    # warnings.filterwarnings("ignore", category=UserWarning)
    if args.type == "wd": # scores 0 - 1.0
        pipe_aesthetic = pipeline("image-classification", "cafeai/cafe_aesthetic", device=0, batch_size=2)
        if args.thr == 0:
            args.thr = 0.55
    elif args.type == "clip": # scores 0 - 10.0
        model = MLP(768)  # 768 = CLIP ViT-L/14 embedding size
        # Checkpoint is expected in the current working directory.
        s = torch.load("sac+logos+ava1-l14-linearMSE.pth")
        model.load_state_dict(s)
        model.to(device)
        model.eval()
        model2, preprocess = clip.load("ViT-L/14", device=device) #RN50x64
        if args.thr == 0:
            args.thr = 4.99
    elif args.type == "aa": # scores 0 - 1.0
        if args.thr == 0:
            args.thr = 0.1
        model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
        model = rt.InferenceSession(model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        def predict(img):
            # Letterbox the image into a 768x768 float32 square (HWC -> CHW,
            # batch dim added), then run the ONNX session; returns a scalar.
            img = img.astype(np.float32) / 255
            s = 768
            h, w = img.shape[:-1]
            h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
            ph, pw = s - h, s - w
            img_input = np.zeros([s, s, 3], dtype=np.float32)
            img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(img, (w, h))
            img_input = np.transpose(img_input, (2, 0, 1))
            img_input = img_input[np.newaxis, :]
            pred = model.run(None, {"img": img_input})[0].item()
            return pred
    folder_name = args.folder_name
    outpath = os.path.join(args.path, folder_name)
    for ext in IMAGE_EXTS:
        for img in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, f"**/*{ext}"), recursive=True))):
            txt = os.path.splitext(img)[0] + ".txt"
            pil = Image.open(img)
            if args.type == "wd":
                try:
                    data = pipe_aesthetic(pil, top_k=2)
                except Exception as ex:
                    print(f"Exception: {ex} - {img}")
                final = {}
                for d in data:
                    final[d["label"]] = d["score"]
                sc = final["aesthetic"]
            elif args.type == "aa":
                img_a = np.array(pil.convert("RGB"))
                try:
                    sc = predict(img_a)
                    if sc < 0.2:
                        #append_tag(txt, "low-ae")
                        low_ae += 1
                    elif sc > 0.89:
                        #append_tag(txt, "high-ae")
                        high_ae += 1
                except Exception as ex:
                    print(f"Exception: {ex} - {img}")
            else:
                try:
                    image = preprocess(pil).unsqueeze(0).to(device)
                    if args.type == "clip":
                        with torch.no_grad():
                            image_features = model2.encode_image(image)
                        im_emb_arr = normalized(image_features.cpu().detach().numpy() )
                        prediction = model(torch.from_numpy(im_emb_arr).to(device).type(torch.cuda.FloatTensor))
                        sc = prediction[0][0].item()
                        # print(f"{img} - {sc}")
                except Exception as ex:
                    print(f"Exception: {ex} - {img}")
            # NOTE(review): if the scorer above raised, `sc` may be unbound
            # (NameError on the first image) or stale from the previous
            # iteration -- confirm whether this best-effort flow is intended.
            if sc < args.thr:
                pil.close()
                do_move(txt, img, outpath, args)
                total += 1
    print(f"{total} images with low aesthetic score moved. {high_ae} images with high ae score, {low_ae} images with low ae score ")
    # Restore default warning behavior after the scoring run.
    warnings.filterwarnings("default")
def remove_blacklisted(args):
    """Delete image/caption pairs whose captions contain blacklisted tags.

    With --tags, only those (optionally convert_tag-normalized) tags are
    checked. Otherwise the module-level BAD_GENERAL blacklist applies, plus
    an nsfw rule: captions tagged "nsfw" are removed unless they contain a
    BAD_NSFW tag that is not overridden by a BAD_NSFW_EXCLUDE tag.
    """
    total = 0
    files_to_remove = []
    ttags = args.tags
    if args.convert_tags:
        ttags = [convert_tag(t) for t in args.tags]
        print(ttags)
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=args.recursive))):
        if get_caption_file_image(txt):
            with open(txt, "r", encoding="utf-8") as f:
                these_tags = [t.strip().lower() for t in f.read().split(",")]
            if args.tags:
                if any(btag in these_tags for btag in ttags):
                    files_to_remove.append(txt)
                    total += 1
            else:
                matched = any(btag in these_tags for btag in BAD_GENERAL)
                # BUG FIX: the original could append the same file twice (once
                # via BAD_GENERAL, once via the nsfw rule), inflating `total`
                # and making the second os.remove() fail.
                if not matched and "nsfw" in these_tags:
                    exclude = True
                    for btag in BAD_NSFW:
                        if btag in these_tags:
                            exclude = False
                    for itag in BAD_NSFW_EXCLUDE:
                        if itag in these_tags:
                            exclude = True
                    matched = not exclude
                if matched:
                    files_to_remove.append(txt)
                    total += 1
    for txt in files_to_remove:
        try:
            os.remove(get_caption_file_image(txt))
            os.remove(txt)
        except Exception as ex:
            print(f"Exception: {ex}")
            continue
    print(f"Removed {total} images with blacklisted tags.")
def validate(args):
    """Validate the dataset: folders, image/caption pairing, captions, latents.

    Orphan captions and latents are renamed to *.bak. Prints every problem
    found and returns 1, or 0 when the dataset is clean.
    """
    problems = []
    total = 0
    print("Validating folder names...")
    for dirname in tqdm.tqdm(os.listdir(args.path)):
        path = os.path.join(args.path, dirname)
        if os.path.isdir(path):
            # m = repeats_folder_re.match(dirname)
            # if not m:
            #     problems.add((path, "Folder is not in \"5_concept\" format"))
            #     continue
            img_count = len(list(glob.iglob(os.path.join(path, "*.txt"))))
            if img_count == 0:
                problems.append((path, "Folder contains no captions"))
    print("Validating image files...")
    for ext in IMAGE_EXTS:
        for img in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, f"**/*{ext}"), recursive=True))):
            txt = os.path.splitext(img)[0] + ".txt"
            if not os.path.isfile(txt):
                problems.append((img, "Image file is missing caption"))
    print("Validating captions...")
    for txt in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.txt"), recursive=True))):
        total += 1
        images = get_caption_file_images(txt)
        if not images:
            if (os.path.split(txt)[1] != "multiply.txt"):
                problems.append((txt, "Caption file is missing corresponding image"))
                os.rename(txt, txt + ".bak")
            continue
        elif len(images) > 1:
            problems.append((txt, "Caption file has more than one corresponding image"))
            continue
        with open(txt, "r", encoding="utf-8") as f:
            tag_string = f.read().strip()
        if not tag_string:
            problems.append((txt, "Caption file is empty"))
            continue
        if "\n" in tag_string:
            problems.append((txt, "Caption file contains newlines"))
        tags = {t.strip().lower(): True for t in tag_string.split(",")}
        if not tags:
            problems.append((txt, "Caption file has no tags"))
    print("Validating latents...")
    for npz in tqdm.tqdm(list(glob.iglob(os.path.join(args.path, "**/*.npz"), recursive=True))):
        txt = os.path.splitext(npz)[0] + ".txt"
        images = get_caption_file_images(npz)
        if not os.path.isfile(txt):
            problems.append((npz, "Latent is missing caption"))
            os.rename(npz, npz + ".bak")
        elif not images:
            problems.append((npz, "Latent is missing corresponding image"))
            os.rename(npz, npz + ".bak")
    if problems:
        for filename, problem in problems:
            # BUG FIX: the original printed the literal string "(unknown)"
            # instead of the offending file's path.
            print(f"{filename} - {problem}")
        return 1
    print(f"No problems found for {total} image/caption pairs.")
    return 0
def stats(args):
    """Placeholder for the per-folder repeat statistics report.

    Intentionally a no-op: the previous table-printing implementation was
    disabled, and this stub keeps the `stats` sub-command callable.
    """
    return None
def main(args):
    """Dispatch the parsed sub-command to its handler.

    Returns the handler's result (used as the process exit status), or
    prints usage and returns 1 for an unknown/missing sub-command.
    """
    # IDIOM: dispatch table instead of the original 18-branch if/elif chain.
    handlers = {
        "fixup": fixup,
        "add": add,
        "remove": remove,
        "replace": replace,
        "strip_suffix": strip_suffix,
        "move_to_front": move_to_front,
        "move_categories_to_front": move_categories_to_front,
        "organize": organize,
        "validate": validate,
        "stats": stats,
        "remove_high_ar": remove_high_ar,
        "remove_blacklisted": remove_blacklisted,
        "gen_multiply": gen_multiply,
        "gen_stats": gen_stats,
        "recalc_multiply_from_folders": recalc_multiply_from_folders,
        "remove_low_ae_score": remove_low_ae_score,
        "name_stats": name_stats,
        "clean": clean,
    }
    handler = handlers.get(args.command)
    if handler is None:
        parser.print_help()
        return 1
    return handler(args)
# Script entry point: parse the command line, run the selected sub-command,
# and exit with its status (parser.exit forwards the code to sys.exit).
if __name__ == "__main__":
    args = parser.parse_args()
    parser.exit(main(args))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment