Skip to content

Instantly share code, notes, and snippets.

@mcxiaoke
Last active December 13, 2023 06:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mcxiaoke/42faaf3baa77870f31df386d150c710c to your computer and use it in GitHub Desktop.
Save mcxiaoke/42faaf3baa77870f31df386d150c710c to your computer and use it in GitHub Desktop.
test checkpoints, a custom script for AUTOMATIC1111 / stable-diffusion-webui.

Test Checkpoints

Introduction

Test your checkpoints by creating images with each selected checkpoint.

This is a custom script for AUTOMATIC1111 / stable-diffusion-webui.

How to use:

Download this script and put it into the `scripts` folder of your stable-diffusion-webui installation.

"""
Copyright 2023 -- Zhang Xiaoke github@mcxiaoke.com
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
test your checkpoints
creating images for selected checkpoints
a custom script for AUTOMATIC1111/stable-diffusion-webui
Script: https://gist.github.com/mcxiaoke/42faaf3baa77870f31df386d150c710c
Version: 1.0.0
Created at 2023.12.12
Created by https://github.com/mcxiaoke
"""
import os
import sys
import pathlib
from datetime import datetime
from modules.processing import Processed, process_images, images
from modules import sd_models, processing, shared
import modules.scripts as scripts
import gradio as gr
from collections import namedtuple
from random import randint
import itertools
import operator
import functools
import random
# Checkpoint root used by this script: the UI lists its sub-folders and the
# selected folders are resolved relative to it.
model_path = sd_models.model_path
# Name shown for this script in the webui; also prefixes its log/error messages.
UI_TITLE = "Test Checkpoints"
# File extensions treated as checkpoints when scanning the selected folders.
MODEL_EXT = [".ckpt", ".safetensors"]
# https://realpython.com/python-flatten-list/
def flatten_concatenation(matrix):
    """Flatten one level of nesting.

    Returns a new list containing, in order, every element of every row of
    *matrix*. Rows may be any iterable; deeper nesting is left untouched.
    """
    return [element for row in matrix for element in row]
def flatten_extend(matrix):
    """Flatten one level of nesting into a fresh list.

    Every row of *matrix* is consumed in order and its items are appended to
    the result; rows may be any iterable.
    """
    return list(itertools.chain.from_iterable(matrix))
def flatten_reduce_iconcat(matrix):
    """Flatten one level of nesting by in-place concatenation.

    Starts from an empty list and concatenates each row of *matrix* onto it
    with ``operator.iconcat`` (the functional form of ``acc += row``).
    """
    accumulator = []
    for row in matrix:
        # iconcat mutates and returns the same list object.
        accumulator = operator.iconcat(accumulator, row)
    return accumulator
def flatten_list_iter(nested_list):
    """Lazily yield the leaves of an arbitrarily nested list.

    Only ``list`` instances are descended into; every other item (tuples,
    strings, ...) is yielded unchanged.
    """
    for item in nested_list:
        if isinstance(item, list):
            # Recurse into this same generator. The previous code called an
            # undefined ``flatten_list`` and raised NameError on nested lists.
            yield from flatten_list_iter(item)
        else:
            yield item
def get_files(paths):
    """Return the full path of every file below one or more directory trees.

    *paths* is either a single directory or a ``list`` of directories.
    Symlinked directories are followed. Each result is the walked directory
    joined with the file name, in ``os.walk`` (top-down) order.
    """
    roots = paths if isinstance(paths, list) else [paths]
    return [
        os.path.join(directory, name)
        for root in roots
        for directory, _subdirs, names in os.walk(root, followlinks=True)
        for name in names
    ]
def get_all_files(root_dir):
    """Return every file below *root_dir* as a path relative to *root_dir*.

    Sub-directories are descended recursively. A symlinked directory is
    resolved and its contents are reported under the link's own name.

    NOTE(review): there is no protection against symlink cycles; such a cycle
    recurses until ``RecursionError``.
    """
    files = []
    for entry in pathlib.Path(root_dir).iterdir():
        if entry.is_file():
            files.append(os.path.relpath(entry, start=root_dir))
        elif entry.is_dir():
            target = entry.resolve() if entry.is_symlink() else entry
            prefix = os.path.relpath(entry, start=root_dir)
            # The recursive call returns paths relative to ``target``; re-root
            # them under this entry so they stay relative to ``root_dir``.
            # Previously the sub-directory part of the path was dropped.
            files.extend(os.path.join(prefix, name) for name in get_all_files(target))
    return files
def get_subdirectories_w(root_dir):
    """Lazily yield every sub-directory below *root_dir*, relative to it.

    The tree is walked top-down with ``os.walk``, following symlinked
    directories.
    """
    for parent, child_dirs, _files in os.walk(root_dir, followlinks=True):
        yield from (
            os.path.relpath(os.path.join(parent, child), start=root_dir)
            for child in child_dirs
        )
def get_subdirectories_p(root_dir):
    """Lazily yield the immediate sub-directories of *root_dir*, relative to it.

    ``pathlib.Path.is_dir`` follows symlinks, so a symlink pointing at a
    directory is reported like a regular sub-directory (under the link's
    name). Broken symlinks and symlinks to non-directories are skipped.
    Each yielded directory's absolute path is also printed with a ``++`` prefix.
    """
    for entry in pathlib.Path(root_dir).iterdir():
        if entry.is_dir():
            print("++", os.path.abspath(entry))
            yield os.path.relpath(entry, start=root_dir)
        # No separate symlink branch: any link whose target is a directory was
        # already handled above. The old branch only ran for broken links or
        # links to non-directories. It then called ``.is_dir()`` on the
        # ``str`` returned by ``os.path.realpath`` and raised AttributeError.
def get_model_name(filename):
    """Derive a flat checkpoint name from a checkpoint file path.

    Files under ``model_path`` keep their sub-folder path with ``model_path``
    removed. Other files use just their base name. Any single leading path
    separator is dropped, remaining ``/`` and ``\\`` separators become ``_``,
    and the extension is stripped.

    NOTE(review): the result is passed to
    ``sd_models.get_closet_checkpoint_match``; presumably it must mirror how
    the webui names checkpoints, so keep the derivation (including the
    ``str.replace`` prefix removal) unchanged — confirm before altering.
    """
    full_path = os.path.abspath(filename)
    under_root = full_path.startswith(model_path)
    relative = full_path.replace(model_path, "") if under_root else os.path.basename(filename)
    if relative.startswith(("\\", "/")):
        relative = relative[1:]
    flattened = relative.replace("/", "_").replace("\\", "_")
    stem, _extension = os.path.splitext(flattened)
    return stem
def get_model_filename(filename):
    """Return the base name of *filename* with its last extension removed."""
    stem, _extension = os.path.splitext(os.path.basename(filename))
    return stem
def get_model_list(selected_dirs):
    """Resolve the checkpoints found in the selected folders.

    Args:
        selected_dirs: folder names relative to ``model_path`` (as offered by
            the UI's checkbox group). The entry ``"All"`` selects the whole
            ``model_path`` tree.

    Returns:
        Checkpoint objects returned by ``sd_models.get_closet_checkpoint_match``
        for every file with an extension in ``MODEL_EXT``. The list has no
        duplicates and is sorted by ``name``. Files without a match are skipped.
        Returns ``[]`` for an empty or ``None`` selection.
    """
    if not selected_dirs:
        return []
    # Honour "All" wherever it appears in the selection. Checking only index 0
    # meant "All" elsewhere was joined as a literal "All" sub-folder.
    if "All" in selected_dirs:
        selected_dirs = ["."]
    roots = [os.path.join(model_path, folder) for folder in selected_dirs]
    filenames = itertools.chain.from_iterable(
        shared.walk_files(root, MODEL_EXT) for root in roots
    )
    # Overlapping selections (e.g. "a/" and "a/b/") walk the same files more
    # than once. Key by name so each checkpoint is tested only once.
    models = {}
    for filename in filenames:
        name = get_model_name(os.path.abspath(filename))
        model = sd_models.get_closet_checkpoint_match(name)
        if model is not None:
            models.setdefault(model.name, model)
    return sorted(models.values(), key=lambda m: m.name)
class Script(scripts.Script):
    """webui custom script: render the current prompt with every checkpoint
    found in the selected folders."""

    def title(self):
        """Return the name shown for this script in the webui."""
        return UI_TITLE

    def ui(self, is_img2img):
        """Build the Gradio controls for this script.

        Returns, in order, the components whose values are passed to ``run``:
        the checkpoint-folder selection, the per-checkpoint batch size and the
        "random seed for every image" flag.
        """
        # Folders are shown with forward slashes and a trailing "/".
        model_dirs = [
            subdir.replace("\\", "/") + "/"
            for subdir in get_subdirectories_w(model_path)
        ]
        model_dirs.insert(0, "All")
        selected_dirs = gr.CheckboxGroup(
            choices=model_dirs, label="Choose checkpoint folders"
        )
        batch_size = gr.Number(value=1, label="Batch size for every checkpoint")
        random_seed = gr.Checkbox(label="All Random Seed", info="Use random seed for every test?")
        return [selected_dirs, batch_size, random_seed]

    def run(self, p, selected_dirs, batch_size, random_seed):
        """Generate images with every checkpoint in *selected_dirs*.

        Each checkpoint gets ``batch_size`` jobs. With ``random_seed`` every job
        uses a fresh random seed. Otherwise the i-th job of every checkpoint
        uses ``initial_seed + i``, so the checkpoints are compared on identical
        seeds. ``initial_seed`` is ``p.seed``, or a random seed if that is -1.

        Raises:
            ValueError: empty prompt, no folder selected, no checkpoint found,
                or a batch size below 1.

        Returns:
            A single ``Processed`` holding the images and infotexts of all jobs
            in generation order. If interrupted, it holds the partial results.
        """
        positive_prompt = (
            (p.prompt[0] if p.prompt else "") if isinstance(p.prompt, list) else p.prompt
        )
        if not positive_prompt:
            raise ValueError(f"{UI_TITLE}: Empty positive prompt!")
        if not selected_dirs:
            raise ValueError(f"{UI_TITLE}: No checkpoint folders selected!")
        models = get_model_list(selected_dirs)
        if not models:
            raise ValueError(f"{UI_TITLE}: No checkpoints found!")
        # The UI field is a free-form number. Reject empty or < 1 values up
        # front; otherwise no job runs and the result would be undefined.
        b_size = int(batch_size or 0)
        if b_size < 1:
            raise ValueError(f"{UI_TITLE}: Batch size must be at least 1!")

        initial_seed = p.seed
        if initial_seed == -1:
            initial_seed = random.randrange(4294967294)

        model_names = [m.name for m in models]
        all_model_names = []
        all_seeds = []
        for name in model_names:
            for i in range(b_size):
                all_model_names.append(name)
                all_seeds.append(
                    random.randrange(4294967294) if random_seed else initial_seed + i
                )
        total_count = len(all_model_names)
        print(
            f"{UI_TITLE}: total {len(model_names)} checkpoints in folder: {selected_dirs}."
        )
        print(f"{UI_TITLE}: create {total_count} images in {len(model_names)} batches.")
        if shared.state.job_count == -1:
            shared.state.job_count = total_count

        processed = None
        for i, (model_name, seed) in enumerate(zip(all_model_names, all_seeds)):
            if shared.state.interrupted:
                break
            shared.state.job = f"{UI_TITLE} job {i+1} out of {total_count}"
            p.override_settings["sd_model_checkpoint"] = model_name
            p.seed = seed
            p.do_not_save_grid = True
            print(
                f"{UI_TITLE}: processing model:{model_name} seed:{seed} ({i+1}/{total_count})"
            )
            result = process_images(p)
            if processed is None:
                # The first result is the container returned to the webui.
                processed = result
            else:
                # Keep every image of each job (not only the first), appended
                # in generation order, with infotexts kept in step with images.
                processed.images.extend(result.images)
                processed.infotexts.extend(result.infotexts)

        if processed is None:
            # Interrupted before the first job ran: return an empty result
            # instead of failing on an unbound local.
            return Processed(p, [])
        return processed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment