Created
January 3, 2023 23:04
-
-
Save r7vz9h3/008e86bed8a4cbe4079e644ac71ccfc3 to your computer and use it in GitHub Desktop.
Dumb script to index and search prompts of Stable Diffusion images from A1111's WebUI. Supports PNG, JPG, and WEBP.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
import argparse | |
import glob | |
import re | |
import piexif | |
import piexif.helper | |
import warnings | |
import pickle | |
from PIL import Image | |
from tqdm import tqdm | |
# !!!!!!!!!!!!!!!! #
# HARD-CODED PATHS #
# Both paths are joined by plain string concatenation further down, so they
# must keep their trailing slash.
output_path = "/path/to/Stable Diffusion/outputs/"
pickle_folder = "/path/where/to/save/pickles/"
# !!!!!!!!!!!!!!!! #

# Command line: optional flags plus any number of search tags.
parser = argparse.ArgumentParser(
    description='Find images with prompts matching all parts of the search phrase.',
)
parser.add_argument(
    '-r', '--rebuild-index',
    dest='rebuild_index',
    action='store_true',
    help='rebuild the search index',
)
parser.add_argument(
    '-b', '--backslashes',
    action='store_true',
    help='use "\\" instead of "/" in displayed paths',
)
parser.add_argument(
    'search_phrase',
    type=str,
    nargs='*',
    help='search phrase: tags separated with spaces, underscores will be interpreted as literal spaces',
)
args = parser.parse_args()

# There is nothing to do unless we rebuild the index or have something to search.
if not (args.rebuild_index or args.search_phrase):
    raise RuntimeError("No arguments")

# Normalize the tags the same way prompts are normalized: underscores stand
# for spaces and matching is case-insensitive.
search_elems = [tag.replace('_', ' ').lower() for tag in args.search_phrase]
### | |
# copied from | |
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/extras.py | |
def run_pnginfo(image):
    """Extract the generation parameters and a text dump of an image's metadata.

    Args:
        image: An object exposing an ``info`` mapping (a PIL image in this
            script), or ``None``.

    Returns:
        A 3-tuple ``('', geninfo, info)``:

        * ``geninfo`` is the ``'parameters'`` metadata entry when present,
          otherwise the EXIF UserComment, otherwise ``''``.
        * ``info`` lists every remaining metadata entry as a key line followed
          by its value, or a short "nothing found" HTML snippet when there
          are no entries.

        When ``image`` is ``None`` all three elements are empty strings.
    """
    if image is None:
        return '', '', ''

    # Work on a copy: entries are added and removed below, and those changes
    # must not leak into the caller's image.info.
    items = dict(image.info or {})
    geninfo = ''

    if "exif" in items:
        exif = piexif.load(items["exif"])
        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
        try:
            exif_comment = piexif.helper.UserComment.load(exif_comment)
        except ValueError:
            # Not a well-formed UserComment: fall back to a lenient decode.
            exif_comment = exif_comment.decode('utf8', errors="ignore")

        items['exif comment'] = exif_comment
        geninfo = exif_comment

        # Drop container/bookkeeping fields so they do not show up in the dump.
        for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
                      'loop', 'background', 'timestamp', 'duration']:
            items.pop(field, None)

    # An explicit 'parameters' entry takes precedence over the EXIF comment.
    geninfo = items.get('parameters', geninfo)

    info = "".join(
        f"""{str(key)}\n{str(text)}""".strip() + "\n"
        for key, text in items.items()
    )

    if len(info) == 0:
        message = "Nothing found in the image."
        info = f"<div><p>{message}<p></div>"

    return '', geninfo, info
### | |
# The important part that reads image info, also for WEBP | |
def get_prompt(image_path):
    """Return the normalized positive prompt stored in an image file.

    Args:
        image_path: Path of the image. The extension (matched
            case-insensitively) selects how the metadata is read:
            ``.webp`` files are scanned as raw bytes, ``.png``/``.jpg``/
            ``.jpeg`` files are opened with PIL and passed to ``run_pnginfo``.

    Returns:
        The prompt text preceding ``"Negative prompt:"`` (or, when that marker
        is absent, everything before the last line), lower-cased, with
        newlines turned into spaces and both attention weights such as
        ``:1.2`` and the characters ``()[]{}:,`` removed. ``None`` when the
        extension is not supported.
    """
    lowered_path = image_path.lower()
    if lowered_path.endswith('.webp'):
        with open(image_path, 'rb') as imfile:
            raw = imfile.read()
        # Keep the text after the last 'UNICODE' marker and drop NUL bytes.
        info = raw.decode('utf-8', 'ignore').rpartition('UNICODE')[2].replace('\x00', '')
    elif lowered_path.endswith(('.png', '.jpg', '.jpeg')):
        # Context manager so the file handle is released for every image
        # while indexing large folders.
        with Image.open(image_path) as image:
            info = run_pnginfo(image)[1]
    else:
        # Unsupported file types are skipped silently.
        return None

    if "Negative prompt:" in info:
        prompt = info.partition("Negative prompt:")[0]
    else:
        # No negative prompt: everything before the last line is the prompt.
        prompt = info.rpartition('\n')[0]

    prompt = prompt.replace('\n', ' ')
    # Strip attention weights like ":1.2)"; raw string avoids invalid-escape warnings.
    prompt = re.sub(r":[0-9\.]+([\), ]|$)", "", prompt)
    prompt = prompt.translate(str.maketrans('', '', '()[]{}:,'))
    return prompt.lower()
### | |
# Either rebuild the prompt index from the images on disk, or load the index
# saved by a previous rebuild.
pickle_path = pickle_folder+"prompts_list paths_list.pickle"

if args.rebuild_index:
    prompts_list, paths_list, error_list = [], [], []
    images_glob = glob.glob(output_path+"*-images/*/*")
    output_path_noslash = output_path.rstrip("/")
    if not images_glob:
        raise RuntimeError("No images found in "+output_path)

    # Set mirror of prompts_list: constant-time duplicate checks instead of
    # scanning the whole list for every image.
    seen_prompts = set()
    for file_path in tqdm(images_glob):
        prompt = get_prompt(file_path)
        # Unsupported files yield None, images without metadata yield ''.
        if not prompt:
            continue
        if len(prompt) > 10240:
            # Very long prompts are left out of the index; the path is
            # recorded in error_list.
            error_list.append(file_path)
            continue
        # Only the first image found for each distinct prompt is indexed.
        if prompt in seen_prompts:
            continue
        seen_prompts.add(prompt)
        # Store paths relative to output_path, always with forward slashes.
        paths_list.append(file_path.replace('\\','/').replace(output_path_noslash, '', 1))
        prompts_list.append(prompt)

    with open(pickle_path, "wb") as picklefile:
        pickle.dump([prompts_list, paths_list], picklefile)
    print("Pickle saved: "+pickle_path)
else:
    # NOTE: pickle.load can execute arbitrary code; only point pickle_folder
    # at a location whose contents were written by this script.
    with open(pickle_path, "rb") as picklefile:
        prompts_list, paths_list = pickle.load(picklefile)
### | |
# Search the index: a prompt matches when it contains every search element.
if search_elems:
    matches = []
    match_folder_dict = {}
    for prompt, path in zip(prompts_list, paths_list):
        if not all(elem in prompt for elem in search_elems):
            continue
        matches.append([path, prompt])
        # Group the matching file names by the folder they live in.
        folder, _, filename = path.rpartition('/')
        match_folder_dict.setdefault(folder, []).append(filename)
    print(f"{len(matches)} matches found!")

    # Print each folder followed by its files, wrapped into tab-indented
    # lines that are flushed once they grow past 80 characters.
    for folder, files in match_folder_dict.items():
        if not args.backslashes:
            print(f"\n{folder}/")
        else:
            folder_bs = folder.replace('/','\\')
            print(f"\n{folder_bs}"+"\\")

        line = "\t"
        for name in files:
            line += f"{name}, "
            if len(line) > 80:
                print(line.rstrip(", "))
                line = "\t"
        # Flush whatever is left on the last, partially filled line.
        if line != "\t":
            print(line.rstrip(", "))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment