Skip to content

Instantly share code, notes, and snippets.

@tori29umai0123
Last active June 11, 2024 13:02
Show Gist options
  • Save tori29umai0123/615eb806832fba83025912cfc82008bb to your computer and use it in GitHub Desktop.
Save tori29umai0123/615eb806832fba83025912cfc82008bb to your computer and use it in GitHub Desktop.
nsfw_filter.py
import argparse
import csv
import glob
import os
from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
import onnx
import onnxruntime as ort
from huggingface_hub import hf_hub_download
import shutil
# Side length, in pixels, of the square image that preprocess_image() produces.
IMAGE_SIZE = 448
def preprocess_image(image):
    """Turn an image into a square IMAGE_SIZE x IMAGE_SIZE float32 array.

    The input is converted to a NumPy array and its last axis is reversed.
    The caller in this file passes a PIL image in RGB mode, so the result is
    in BGR channel order (presumably what the tagger model expects -- confirm
    against the model card). The array is then padded with white (255) to a
    square, centred, and resized to IMAGE_SIZE.

    Args:
        image: Anything ``np.array`` turns into an (H, W, C) array.

    Returns:
        A ``float32`` array of shape (IMAGE_SIZE, IMAGE_SIZE, C). Values are
        only cast to float; no rescaling is applied here.
    """
    # Reverse the channel axis (RGB -> BGR for an RGB input).
    pixels = np.array(image)[:, :, ::-1]

    height, width = pixels.shape[:2]
    side = max(height, width)

    # Split the missing rows/columns as evenly as possible between both sides.
    extra_w = side - width
    extra_h = side - height
    left = extra_w // 2
    top = extra_h // 2
    padding = ((top, extra_h - top), (left, extra_w - left), (0, 0))
    squared = np.pad(pixels, padding, mode="constant", constant_values=255)

    # INTER_AREA when shrinking, INTER_LANCZOS4 otherwise.
    if side > IMAGE_SIZE:
        interp = cv2.INTER_AREA
    else:
        interp = cv2.INTER_LANCZOS4
    resized = cv2.resize(squared, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)

    return resized.astype(np.float32)
def run_batch(path_imgs, input_name, ort_sess, rating_tags, thresh, nsfw_dir, sfw_dir):
    """Classify one batch of images and copy each file to the NSFW or SFW folder.

    An image goes to ``nsfw_dir`` when the larger of its "questionable" and
    "explicit" scores is strictly greater than its "general" score; otherwise
    it goes to ``sfw_dir``. Copy failures are reported and do not abort the
    batch.

    Args:
        path_imgs: List of ``(image_path, preprocessed_array)`` pairs.
        input_name: Name of the model input the stacked batch is fed to.
        ort_sess: Inference session; ``ort_sess.run(None, {input_name: batch})``
            must return a sequence whose first item holds one score row per
            image.
        rating_tags: Rating tag names. ``rating_tags[i]`` is paired with
            column ``i`` of each score row.
            NOTE(review): this assumes the rating tags occupy the leading
            model outputs in this order -- confirm against selected_tags.csv.
        thresh: Currently unused; kept so existing callers keep working.
        nsfw_dir: Destination directory for images judged NSFW.
        sfw_dir: Destination directory for all other images.
    """
    if not path_imgs:
        # Nothing to classify; do not send an empty batch to the model.
        return

    imgs = np.array([im for _, im in path_imgs])
    probs = ort_sess.run(None, {input_name: imgs})[0]  # first ONNX output
    probs = probs[: len(path_imgs)]

    for (image_path, _), prob in zip(path_imgs, probs):
        tag_confidences = dict(zip(rating_tags, prob))
        max_nsfw_score = max(
            tag_confidences.get("questionable", 0),
            tag_confidences.get("explicit", 0),
        )
        max_sfw_score = tag_confidences.get("general", 0)
        destination = nsfw_dir if max_nsfw_score > max_sfw_score else sfw_dir

        # Copy the image into the chosen folder, creating the folder on demand
        # so a missing output directory does not make every copy fail.
        try:
            os.makedirs(destination, exist_ok=True)
            shutil.copy(image_path, os.path.join(destination, os.path.basename(image_path)))
        except OSError as e:
            print(f"{image_path} を {destination} にコピーできませんでした。エラー: {e}")
        else:
            print(f"{image_path} を {destination} にコピーしました。")
def main():
    """Download the tagger, classify every file in ``input_dir`` and sort copies.

    Reads its configuration from module-level names set by the script entry
    point: ``MODEL_ID``, ``input_dir``, ``sfw_dir``, ``nsfw_dir``,
    ``batch_size`` and ``thresh``.

    Steps:
        1. Fetch ``model.onnx`` and ``selected_tags.csv`` for ``MODEL_ID``.
        2. Collect the names of the rating tags (category "9") from the CSV.
        3. Preprocess each file matching ``input_dir/*.*`` and hand the images
           to ``run_batch`` in groups of ``batch_size``.

    Files that cannot be opened or preprocessed are reported and skipped.

    Raises:
        ValueError: If ``selected_tags.csv`` does not have the expected header.
    """
    print("Hugging Faceからwd14 taggerをロード中")
    onnx_path = hf_hub_download(MODEL_ID, "model.onnx")
    csv_path = hf_hub_download(MODEL_ID, "selected_tags.csv")

    print("wd14 taggerでonnxを実行")
    print(f"onnxモデルをロード中: {onnx_path}")
    ort_sess = ort.InferenceSession(onnx_path)
    # The input name does not change between batches; look it up once.
    input_name = ort_sess.get_inputs()[0].name

    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        header = next(reader, None)  # header row (None for an empty file)
        rows = list(reader)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if header != ["tag_id", "name", "category", "count"]:
        raise ValueError(f"予期しないCSVフォーマット: {header}")

    # NOTE(review): run_batch pairs rating_tags[i] with model output i, which
    # is only right if the category-"9" rows are the first rows of the CSV --
    # confirm for the model in use.
    rating_tags = [row[1] for row in rows if row[2] == "9"]
    # Character tags (category "4") could be collected the same way if a
    # copyright/character filter is added later.

    # Make sure both output folders exist before any copy is attempted.
    for out_dir in (sfw_dir, nsfw_dir):
        os.makedirs(out_dir, exist_ok=True)

    image_paths = glob.glob(os.path.join(input_dir, "*.*"))
    b_imgs = []
    for image_path in tqdm(image_paths, smoothing=0.0):
        try:
            # The context manager closes the underlying file handle once the
            # pixel data has been turned into an array.
            with Image.open(image_path) as img:
                rgb = img if img.mode == "RGB" else img.convert("RGB")
                image = preprocess_image(rgb)
        except Exception as e:
            # Best effort: the glob also matches non-image files, so report
            # the problem and move on to the next path.
            print(f"画像を読み込めません: {image_path}, エラー: {e}")
            continue

        b_imgs.append((image_path, image))
        if len(b_imgs) >= batch_size:
            run_batch(b_imgs, input_name, ort_sess, rating_tags, thresh, nsfw_dir, sfw_dir)
            b_imgs = []

    # Flush the final, partially filled batch.
    if b_imgs:
        run_batch(b_imgs, input_name, ort_sess, rating_tags, thresh, nsfw_dir, sfw_dir)

    print("処理完了!")
if __name__ == "__main__":
    # Command-line options; every default equals the value that used to be
    # hard-coded here, so running the script without arguments is unchanged.
    parser = argparse.ArgumentParser(
        description="Copy images into SFW/NSFW folders based on tagger rating scores."
    )
    parser.add_argument(
        "--model-id",
        default="SmilingWolf/wd-vit-tagger-v3",
        help="Hugging Face repository that provides model.onnx and selected_tags.csv",
    )
    parser.add_argument(
        "--input-dir",
        default="E:/desktop/dart/test",
        help="directory containing the images to classify",
    )
    parser.add_argument(
        "--sfw-dir",
        default="E:/desktop/dart/test_sfw",
        help="directory that receives images judged SFW",
    )
    parser.add_argument(
        "--nsfw-dir",
        default="E:/desktop/dart/test_nsfw",
        help="directory that receives images judged NSFW",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help="number of images sent to the model per inference call",
    )
    parser.add_argument(
        "--thresh",
        type=float,
        default=0.35,
        help="value forwarded to run_batch as its thresh argument",
    )
    args = parser.parse_args()

    # main() reads these as module-level globals, so bind them under the
    # exact names it expects.
    MODEL_ID = args.model_id
    input_dir = args.input_dir
    sfw_dir = args.sfw_dir
    nsfw_dir = args.nsfw_dir
    batch_size = args.batch_size
    thresh = args.thresh
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment