Skip to content

Instantly share code, notes, and snippets.

@tori29umai0123
Last active June 14, 2024 02:36
Show Gist options
  • Save tori29umai0123/f2c08d5c0a1dffb1b38cce8185651d2b to your computer and use it in GitHub Desktop.
Save tori29umai0123/f2c08d5c0a1dffb1b38cce8185651d2b to your computer and use it in GitHub Desktop.
ContentSafetyAnalyzer.py
import csv
import os
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
import onnxruntime as ort
from huggingface_hub import hf_hub_download
# Side length, in pixels, of the square image that preprocess_image
# resizes every input to.
IMAGE_SIZE = 448
def preprocess_image(image):
    """Turn an image into a square float32 array of IMAGE_SIZE x IMAGE_SIZE.

    The input is converted to a NumPy array, its channel axis is reversed,
    it is padded with the value 255 to a square, and then resized.

    Args:
        image: Anything ``np.array`` can turn into a 3-D (height, width,
            channels) array. ``process_image`` passes an RGB PIL image.

    Returns:
        The resized array cast to ``np.float32``.
    """
    # Reverse the channel axis. The caller supplies RGB data, so the result
    # is in BGR order (presumably what the ONNX model expects -- confirm
    # against the model card).
    pixels = np.array(image)[:, :, ::-1]

    # Pad the shorter dimension so the image becomes square. Any odd
    # leftover pixel goes to the bottom/right edge.
    height, width = pixels.shape[0:2]
    side = max(height, width)
    extra_rows = side - height
    extra_cols = side - width
    top = extra_rows // 2
    left = extra_cols // 2
    padding = ((top, extra_rows - top), (left, extra_cols - left), (0, 0))
    squared = np.pad(pixels, padding, mode="constant", constant_values=255)

    # Area interpolation when shrinking, Lanczos otherwise.
    if side > IMAGE_SIZE:
        interpolation = cv2.INTER_AREA
    else:
        interpolation = cv2.INTER_LANCZOS4
    resized = cv2.resize(squared, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interpolation)

    return resized.astype(np.float32)
def process_image(image_path, input_name, ort_sess, rating_tags, character_tags, general_tags, thresh):
    """Run the tagger on one image and print rating and character verdicts.

    The model output is indexed as: rating tags first, then general tags,
    then character tags. This mirrors the order in which the tag lists are
    consumed here; NOTE(review): confirm it matches the row order of the
    model's ``selected_tags.csv``.

    Args:
        image_path: Path of the image file to analyze.
        input_name: Name of the ONNX model input to feed the image to.
        ort_sess: Inference session exposing ``run(output_names, feeds)``.
        rating_tags: Names of the rating tags, in model output order.
        character_tags: Names of the character tags, in model output order.
        general_tags: Names of the general tags; only their count is used,
            to locate where the character tags start.
        thresh: Minimum probability for a character tag to be reported.

    Returns:
        None. Results are printed. If the image cannot be loaded or
        preprocessed, an error message is printed and the function returns
        early.
    """
    try:
        # Use a context manager so the file handle is released as soon as
        # the pixel data has been copied into a NumPy array.
        with Image.open(image_path) as source:
            rgb = source if source.mode == "RGB" else source.convert("RGB")
            image = preprocess_image(rgb)
    except Exception as e:
        print(f"画像を読み込めません: {image_path}, エラー: {e}")
        return

    # Add a leading batch axis, then take the first output of the first
    # (and only) batch element.
    img = np.array([image])
    prob = ort_sess.run(None, {input_name: img})[0][0]

    # NSFW/SFW decision: compare the stronger of "questionable"/"explicit"
    # against "general". Missing ratings count as 0.
    tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
    max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
    max_sfw_score = tag_confidences.get("general", 0)
    if max_nsfw_score > max_sfw_score:
        print("NSFWの可能性が高いです")
    else:
        print("SFWの可能性が高いです")

    # Character tags sit after the rating and general tags. The offset is
    # derived from the tag lists instead of a hard-coded rating count.
    character_start = len(rating_tags) + len(general_tags)
    character_probs = prob[character_start:character_start + len(character_tags)]
    character_tags_with_probs = [
        (tag_name, f"{round(p * 100, 2)}%")  # probability as a percentage
        for tag_name, p in zip(character_tags, character_probs)
        if p >= thresh
    ]

    if character_tags_with_probs:
        print(f"版権キャラクター: {character_tags_with_probs}の可能性があります")
    else:
        print("版権キャラクターの可能性が低いと思われます")
def main(MODEL_ID, image_path, thresh):
    """Download the tagger model and its tag list, then analyze one image.

    Args:
        MODEL_ID: Hugging Face repository id that provides ``model.onnx``
            and ``selected_tags.csv``.
        image_path: Path of the image file to analyze.
        thresh: Probability threshold forwarded to ``process_image``.

    Raises:
        AssertionError: If the CSV header is not
            ``tag_id,name,category,count``.
    """
    print("Hugging Faceからモデルをダウンロード中")
    onnx_path = hf_hub_download(MODEL_ID, "model.onnx")
    csv_path = hf_hub_download(MODEL_ID, "selected_tags.csv")

    print("ONNXモデルを実行中")
    print(f"ONNXモデルのパス: {onnx_path}")
    ort_sess = ort.InferenceSession(onnx_path)

    # newline="" lets the csv module handle line endings itself, as the
    # csv documentation recommends.
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        header = next(reader)  # consume the header row
        rows = list(reader)    # data rows only

    # NOTE: kept as an assert so the raised exception type is unchanged;
    # be aware it is skipped when Python runs with -O.
    assert header == ["tag_id", "name", "category", "count"], f"CSVフォーマットが期待と異なります: {header}"

    # Split tag names by the category column. All three lists are built from
    # the same data rows; the header has already been removed above, so no
    # further row is skipped.
    rating_tags = [row[1] for row in rows if row[2] == "9"]
    character_tags = [row[1] for row in rows if row[2] == "4"]
    general_tags = [row[1] for row in rows if row[2] == "0"]

    process_image(image_path, ort_sess.get_inputs()[0].name, ort_sess, rating_tags, character_tags, general_tags, thresh)
    print("処理完了!")
if __name__ == "__main__":
    # Script entry point: analyze a single hard-coded image.
    thresh = 0.35  # probability threshold
    image_path = "E:/desktop/test.jpg"  # path of the image to analyze
    MODEL_ID = "SmilingWolf/wd-swinv2-tagger-v3"
    main(MODEL_ID=MODEL_ID, image_path=image_path, thresh=thresh)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment