Skip to content

Instantly share code, notes, and snippets.

@kohya-ss
Created May 17, 2024 12:57
Show Gist options
  • Save kohya-ss/4de9ab8cd3f9056ccd59957d87fe8882 to your computer and use it in GitHub Desktop.
Save kohya-ss/4de9ab8cd3f9056ccd59957d87fe8882 to your computer and use it in GitHub Desktop.
WD14 Taggerでタグごとの確信度を取得する
import argparse
import csv
import glob
import os
from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
import onnx
import onnxruntime as ort
from huggingface_hub import hf_hub_download
# from wd14 tagger
# Side length, in pixels, of the square image produced by preprocess_image();
# it is also the threshold that decides which resize interpolation is used.
IMAGE_SIZE = 448
# Hugging Face repository used when --repo_id is not given on the command line.
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-vit-tagger-v3"
def preprocess_image(image):
    """Turn an image into a square float32 array sized for the tagger.

    The input is converted with ``np.array`` and must yield a 3-D array
    (rows, columns, channels); the caller in this file passes a PIL image
    already converted to RGB. The channel axis is reversed (RGB -> BGR), the
    shorter side is padded with the value 255 so the picture becomes a
    centred square, and the result is resized to IMAGE_SIZE x IMAGE_SIZE.

    Returns:
        The resized image as a ``np.float32`` array. Pixel values are only
        cast, not rescaled.
    """
    bgr = np.array(image)[:, :, ::-1]  # reverse the channel axis: RGB -> BGR

    height, width = bgr.shape[:2]
    side = max(height, width)

    # Split the missing rows/columns between both edges; when the amount is
    # odd the extra pixel goes to the bottom/right.
    extra_rows = side - height
    extra_cols = side - width
    top = extra_rows // 2
    left = extra_cols // 2
    padding = ((top, extra_rows - top), (left, extra_cols - left), (0, 0))
    squared = np.pad(bgr, padding, mode="constant", constant_values=255)

    # Area interpolation when shrinking, Lanczos otherwise.
    if side > IMAGE_SIZE:
        interpolation = cv2.INTER_AREA
    else:
        interpolation = cv2.INTER_LANCZOS4
    resized = cv2.resize(squared, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interpolation)

    return resized.astype(np.float32)
def _create_inference_session(onnx_path):
    """Create an onnxruntime session for ``onnx_path``.

    Provider preference is OpenVINO, then CUDA, then ROCm, then CPU; the first
    one reported by ``ort.get_available_providers()`` is used.
    """
    available = ort.get_available_providers()

    if "OpenVINOExecutionProvider" in available:
        # requires provider options for gpu support
        # fp16 causes nonsense outputs
        return ort.InferenceSession(
            onnx_path,
            providers=["OpenVINOExecutionProvider"],
            provider_options=[{"device_type": "GPU_FP32"}],
        )

    for provider in ("CUDAExecutionProvider", "ROCMExecutionProvider"):
        if provider in available:
            return ort.InferenceSession(onnx_path, providers=[provider])
    return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])


def _load_tag_names(csv_path):
    """Read the tag CSV and return ``(rating, general, character)`` name lists.

    The file is parsed with the stdlib ``csv`` module (pandas is deliberately
    not used, to avoid an extra dependency). Rows are split by their
    ``category`` column: "9" is rating, "0" is general, "4" is character.

    Raises:
        ValueError: if the first three header columns are not
            ``tag_id,name,category``.
    """
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        table = list(csv.reader(f))

    header = table[0] if table else []  # tag_id,name,category,count
    # Raise explicitly instead of using `assert`, which is stripped under -O.
    if header[:3] != ["tag_id", "name", "category"]:
        raise ValueError(f"unexpected csv format: {header}")

    rows = table[1:]
    rating_tags = [row[1] for row in rows if row[2] == "9"]
    general_tags = [row[1] for row in rows if row[2] == "0"]
    character_tags = [row[1] for row in rows if row[2] == "4"]
    return rating_tags, general_tags, character_tags


def main(args):
    """Tag every image in a directory and write per-image confidence CSVs.

    Downloads ``model.onnx`` and ``selected_tags.csv`` from the Hugging Face
    repo ``args.repo_id``, runs the model over the ``*.jpg``, ``*.jpeg``,
    ``*.png`` and ``*.webp`` files directly inside ``args.train_data_dir``, and
    for each image writes ``<image stem>.csv`` into ``args.output_dir`` with a
    ``tag,confidence`` header. Every rating tag is always written; other tags
    are written only when their confidence is >= ``args.thresh``.

    Args:
        args: parsed command-line namespace providing ``repo_id``,
            ``train_data_dir``, ``batch_size``, ``thresh`` and ``output_dir``.
            ``args.batch_size`` is overwritten when the model declares a fixed
            positive batch size that differs from it.

    Raises:
        ValueError: if ``selected_tags.csv`` does not have the expected header.
    """
    print("Loading wd14 tagger from Hugging Face")
    repo_id = args.repo_id
    onnx_path = hf_hub_download(repo_id, "model.onnx")
    csv_path = hf_hub_download(repo_id, "selected_tags.csv")

    print("Running wd14 tagger with onnx")
    print(f"loading onnx model: {onnx_path}")

    # Load the graph only to inspect the input name and batch dimension.
    model = onnx.load(onnx_path)
    input_name = model.graph.input[0].name
    try:
        batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
    except Exception:
        batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param

    if args.batch_size != batch_size and not isinstance(batch_size, str) and batch_size > 0:
        # some rebatch model may use 'N' as dynamic axes
        print(f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}")
        args.batch_size = batch_size

    del model

    ort_sess = _create_inference_session(onnx_path)

    rating_tags, general_tags, character_tags = _load_tag_names(csv_path)
    # The model output is read as: ratings first, then general tags, then
    # character tags. Deriving the rating count from the CSV keeps both loops
    # below aligned (it is 4 for the usual WD14 tag files).
    num_ratings = len(rating_tags)
    non_rating_tags = general_tags + character_tags

    # Collect the images. The directory is escaped so that names containing
    # glob metacharacters such as "[" or "*" are matched literally.
    image_dir = glob.escape(args.train_data_dir)
    image_paths = []
    for extension in ("jpg", "jpeg", "png", "webp"):
        image_paths += glob.glob(os.path.join(image_dir, f"*.{extension}"))
    print(f"found {len(image_paths)} images.")

    os.makedirs(args.output_dir, exist_ok=True)

    def run_batch(path_imgs):
        """Run inference on ``[(path, array), ...]`` and write one CSV per image."""
        imgs = np.array([im for _, im in path_imgs])

        probs = ort_sess.run(None, {input_name: imgs})[0]  # onnx output numpy
        probs = probs[: len(path_imgs)]

        for (image_path, _), prob in zip(path_imgs, probs):
            # Rating tags are always reported, regardless of the threshold.
            tag_confidences = {name: prob[i] for i, name in enumerate(rating_tags)}

            # The remaining outputs are tags: keep those whose confidence
            # reaches the threshold.
            for i, p in enumerate(prob[num_ratings:]):
                if p >= args.thresh:
                    tag_confidences[non_rating_tags[i]] = p

            caption_file = os.path.splitext(image_path)[0] + ".csv"
            caption_file = os.path.join(args.output_dir, os.path.basename(caption_file))
            with open(caption_file, "wt", encoding="utf-8") as f:
                writer = csv.writer(f, lineterminator="\n")
                writer.writerow(["tag", "confidence"])
                for tag, confidence in tag_confidences.items():
                    writer.writerow([tag, confidence])

    b_imgs = []
    for image_path in tqdm(image_paths, smoothing=0.0):
        try:
            # The context manager releases the file handle even if decoding
            # or preprocessing fails part-way through.
            with Image.open(image_path) as pil_image:
                rgb_image = pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")
                image = preprocess_image(rgb_image)
        except Exception as e:
            # Best effort: report the unreadable file and carry on.
            print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
            continue

        b_imgs.append((image_path, image))

        if len(b_imgs) >= args.batch_size:
            run_batch(b_imgs)
            b_imgs.clear()

    # Flush the final, possibly partial, batch.
    if b_imgs:
        run_batch(b_imgs)

    print("done!")
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Arguments defined:
        train_data_dir: positional, directory containing the images.
        --repo_id: Hugging Face repo of the tagger
            (default DEFAULT_WD14_TAGGER_REPO).
        --batch_size: inference batch size (default 16).
        --thresh: minimum confidence for a tag to be written (default 0.35).
        --output_dir: directory receiving the CSV files (default ".").
    """
    cli = argparse.ArgumentParser()

    cli.add_argument(
        "train_data_dir",
        type=str,
        help="directory for train images / 学習画像データのディレクトリ",
    )

    # The default repo id is appended so that it shows up in --help.
    repo_help = (
        "repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID, default: "
        + DEFAULT_WD14_TAGGER_REPO
    )
    cli.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, help=repo_help)

    cli.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="batch size in inference / 推論時のバッチサイズ",
    )

    thresh_help = "threshold of confidence to add a tag / タグを追加するか判定する閾値, default: 0.35"
    cli.add_argument("--thresh", type=float, default=0.35, help=thresh_help)

    output_help = "output directory for tag confidence csv files / タグの確信度のCSVファイルの出力ディレクトリ"
    cli.add_argument("--output_dir", type=str, default=".", help=output_help)

    return cli
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the tagger.
    cli_args = setup_parser().parse_args()
    main(cli_args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment